load model from file and stream for caffe and pytorch

load model from file and stream for caffe and pytorch

Guide

caffe

load from file

1
2
3
4
5
6
// Build the network from its definition file in TEST phase, then copy the
// trained weights into it.
const caffe::Phase phase = caffe::TEST;

std::string proto_filepath = "yolov3.prototxt";
std::string weight_filepath = "yolov3.caffemodel";
// Direct-initialize: caffe::Net disables copy and assignment, so
// `Net<float> net = Net<float>(...)` only compiles under C++17 guaranteed elision.
caffe::Net<float> net(proto_filepath, phase);
net.CopyTrainedLayersFrom(weight_filepath);

load from stream

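Caffe has no method for loading a model directly from a stream.
Instead, we can modify ReadProtoFromTextFile and ReadProtoFromBinaryFile in src/caffe/util/io.cpp so that they parse in-memory data (here, decrypted file contents) rather than reading the files directly.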

Replace

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// Original text-proto reader from src/caffe/util/io.cpp: parses the
// text-format protobuf file `filename` (e.g. a .prototxt) into *proto.
// Returns whether TextFormat::Parse succeeded; the CHECK_NE fails if the
// file cannot be opened.
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
  int fd = open(filename, O_RDONLY);
  CHECK_NE(fd, -1) << "File not found: " << filename;
  FileInputStream* input = new FileInputStream(fd);
  bool success = google::protobuf::TextFormat::Parse(input, proto);
  // Release the stream wrapper before closing the descriptor it reads from.
  delete input;
  close(fd);
  return success;
}

// Original binary-proto reader from src/caffe/util/io.cpp: parses the binary
// protobuf file `filename` (e.g. a .caffemodel) into *proto.
// Returns the result of ParseFromCodedStream; the CHECK_NE fails if the file
// cannot be opened.
bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
int fd = open(filename, O_RDONLY);
CHECK_NE(fd, -1) << "File not found: " << filename;
ZeroCopyInputStream* raw_input = new FileInputStream(fd);
// Wrap the file stream in a CodedInputStream so a total byte limit can be set.
CodedInputStream* coded_input = new CodedInputStream(raw_input);
// kProtoReadBytesLimit is defined elsewhere in io.cpp (its value is not shown here).
// NOTE(review): this two-argument SetTotalBytesLimit overload (second argument
// presumably a warning threshold, 536870912 bytes = 512 MiB) is deprecated in
// newer protobuf releases — confirm against the installed protobuf version.
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);

bool success = proto->ParseFromCodedStream(coded_input);

// Tear down in reverse order of construction: coded_input wraps raw_input,
// and raw_input reads from fd, so the descriptor is closed last.
delete coded_input;
delete raw_input;
close(fd);
return success;
}

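(The functions above load the plain demo.prototxt and demo.caffemodel.)

with the following versions, which decrypt the files in memory before parsing: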

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
bool ReadProtoFromTextFile(const char *filename, Message *proto) {
Encryption encryption;
string res = encryption.decryptTextFile(filename); // demo.prototxt
istringstream ss(res);

IstreamInputStream *input = new IstreamInputStream(&ss);

bool success = google::protobuf::TextFormat::Parse(input, proto);
delete input;
return success;
}


bool ReadProtoFromBinaryFile(const char *filename, Message *proto) {
Encryption encryption;
string res = encryption.decryptModelFile(filename); // demo.caffemodel
istringstream ss(res);

IstreamInputStream *input = new IstreamInputStream(&ss);
CodedInputStream *coded_input = new CodedInputStream(input);
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);

bool success = proto->ParseFromCodedStream(coded_input);

delete coded_input;
delete input;
return success;
}

(The replacement functions load the encrypted demo_encrypt.prototxt and demo_encrypt.caffemodel.)

pytorch

  • torch::jit::script::Module load(const std::string& filename, ...);
  • torch::jit::script::Module load(std::istream& in, ...);

load from file

1
2
3
std::string model_path = "model.libpt";
torch::jit::script::Module net = torch::jit::load(model_path);
assert(net != nullptr);

load from stream

1
2
3
4
std::string model_content = ""; // read from file
std::istringstream ss(model_content);
torch::jit::script::Module net = torch::jit::load(ss);
assert(net != nullptr);

Reference

History

  • 20191014: created.
坚持技术分享,您的支持将鼓励我继续创作!