diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index e7e682f59be48..eb9e7cd96f4c4 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -162,7 +162,7 @@ OperatorDef OnnxifiTransformer::BuildOnnxifiOp( NetDef OnnxifiTransformer::SubnetToOnnxifiOp( const caffe2::NetDef& net, - const Workspace& mapped_ws, + const std::unordered_set& weights_in_ws, Workspace* ws, onnx::OnnxExporter* exporter, std::unordered_map* shape_hints) { @@ -240,13 +240,6 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOp( } // Convert inputs and figure out weights - std::unordered_set weights; - const std::vector& ws_blobs = mapped_ws.Blobs(); - for (const auto& s : ws_blobs) { - VLOG(2) << "Add weights: " << s; - weights.emplace(s); - } - std::unordered_set total_inputs; std::unordered_set initialization_list; std::vector total_inputs_vec; @@ -266,11 +259,11 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOp( for (const auto& op : net.op()) { for (const auto& input : op.input()) { - if (total_inputs.emplace(input).second && weights.count(input)) { + if (total_inputs.emplace(input).second && weights_in_ws.count(input)) { // We add weights as inputs too total_inputs_vec.emplace_back(input); initialization_list.emplace(input); - VLOG(2) << "Add input weights: " << input; + VLOG(2) << "Add weights: " << input; } else if (boundary_inputs.count(input)) { VLOG(2) << "Adding boundary input: " << input; total_inputs_vec.emplace_back(input); @@ -308,10 +301,9 @@ CaffeMap OnnxifiTransformer::SsaRewriteAndMapNames( NetDef* pred_net, const std::unordered_map& input_shape_hints) { input_mapping_ = onnx::SsaRewrite(nullptr, pred_net); - std::unordered_map input_reverse_mapping; std::vector external_inputs; for (const auto kv : input_mapping_) { - input_reverse_mapping.emplace(kv.second, kv.first); + reverse_input_mapping_.emplace(kv.second, kv.first); if (!ws->HasBlob(kv.second)) { external_inputs.emplace_back(kv.first); } @@ -321,8 +313,8 @@ CaffeMap OnnxifiTransformer::SsaRewriteAndMapNames( } CaffeMap shape_hints_ordered; for (const auto& kv : input_shape_hints) { - const auto it = input_reverse_mapping.find(kv.first); - if (it != input_reverse_mapping.end()) { + const auto it = reverse_input_mapping_.find(kv.first); + if (it != reverse_input_mapping_.end()) { shape_hints_ordered.emplace(it->second, kv.second); } else { shape_hints_ordered.emplace(kv.first, kv.second); @@ -336,6 +328,7 @@ CaffeMap OnnxifiTransformer::SsaRewriteAndMapNames( void OnnxifiTransformer::Transform( Workspace* ws, NetDef* pred_net, + const std::vector& external_inputs, const std::unordered_map& input_shape_hints) { CAFFE_ENFORCE(ws); auto shape_hints_ordered = @@ -404,9 +397,23 @@ void OnnxifiTransformer::Transform( // same exporter throughout the process to avoid duplicated dummy name // generation onnx::OnnxExporter exporter2(nullptr); - auto trt_converter = [this, ws, &mapped_ws, &shape_hints, &exporter2]( + std::unordered_set weights; + std::unordered_set input_set; + for (const auto& i : external_inputs) { + const auto it = reverse_input_mapping_.find(i); + if (it != reverse_input_mapping_.end()) { + input_set.emplace(it->second); + } + } + const std::vector& ws_blobs = mapped_ws.Blobs(); + for (const auto& s : ws_blobs) { + if (!input_set.count(s)) { + weights.emplace(s); + } + } + auto trt_converter = [this, ws, &weights, &shape_hints, &exporter2]( const caffe2::NetDef& net) mutable { - return SubnetToOnnxifiOp(net, mapped_ws, ws, &exporter2, &shape_hints); + return SubnetToOnnxifiOp(net, weights, ws, &exporter2, &shape_hints); }; NetDef net_opt = opt::OptimizeForBackend(*pred_net, supports, trt_converter); diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h index 671cc3fef1c2b..d2fd9c0b396bd 100644 --- a/caffe2/opt/onnxifi_transformer.h +++ b/caffe2/opt/onnxifi_transformer.h @@ -25,15 +25,24 @@ class CAFFE2_API OnnxifiTransformer { void Transform( Workspace* ws, NetDef* pred_net, + const std::vector& external_inputs, const std::unordered_map& shape_hints); + const std::unordered_map& input_mapping() const { + return input_mapping_; + } + + const std::unordered_map& reverse_input_mapping() + const { + return reverse_input_mapping_; + } + private: - // Note that we have two workspaces here as inputs. The first mapped_ws is - // used to mapped SSA names back to c2 original names. The second one is - // actually used to inject more weights into the original workspace + // Since we create new tensors during the conversion process, we actually need + // into inject them into the original workspace caffe2::NetDef SubnetToOnnxifiOp( const caffe2::NetDef& net, - const Workspace& mapped_ws, + const std::unordered_set& weights_in_ws, Workspace* ws, onnx::OnnxExporter* exporter, std::unordered_map* shape_hints); @@ -64,7 +73,11 @@ class CAFFE2_API OnnxifiTransformer { // Backned IDs std::vector backend_ids_; - // Input mapping + + // Input mapping of input name -> original input name std::unordered_map input_mapping_; + + // Input mapping of orignal input name -> input name + std::unordered_map reverse_input_mapping_; }; } // namespace caffe2 diff --git a/caffe2/python/onnx/onnxifi.py b/caffe2/python/onnx/onnxifi.py index be482e5969ad1..07a86390eae9d 100644 --- a/caffe2/python/onnx/onnxifi.py +++ b/caffe2/python/onnx/onnxifi.py @@ -27,6 +27,7 @@ def onnxifi_caffe2_net( # Inject an fake input tensor to help popluate the shape if we # do not do shape inference shape_hints = {} + external_inputs = [] if not infer_shapes: for k, v in input_shapes.items(): need_input_tensor = True @@ -36,10 +37,12 @@ def onnxifi_caffe2_net( need_input_tensor = False if need_input_tensor: workspace.FeedBlob(k, np.random.randn(*v).astype(np.float32)) + external_inputs.append(k) for k, v in input_shapes.items(): shape_hints[k] = v pred_net_str = C.onnxifi(pred_net.SerializeToString(), + external_inputs, shape_hints, infer_shapes, debug) diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 61474fe7f60ba..cd8549f86682c 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -1633,6 +1633,7 @@ void addGlobalMethods(py::module& m) { m.def( "onnxifi", [](const py::bytes& pred_net_str, + const std::vector& external_inputs, const std::unordered_map>& shapes, bool infer_shapes, bool debug_builder) -> py::bytes { @@ -1647,7 +1648,8 @@ void addGlobalMethods(py::module& m) { it.first, CreateTensorShape(it.second, TensorProto::FLOAT)); } OnnxifiTransformer ts(infer_shapes, debug_builder); - ts.Transform(GetCurrentWorkspace(), &pred_net, tensor_shapes); + ts.Transform( + GetCurrentWorkspace(), &pred_net, external_inputs, tensor_shapes); std::string pred_net_str2; pred_net.SerializeToString(&pred_net_str2); return py::bytes(pred_net_str2);