Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
MXNet-TRT: Add PrePartition param caching - move init_tensorrt_param…
Browse files Browse the repository at this point in the history
…s logic (#18490)

* Update to TRT 7 API

Signed-off-by: Serge Panev <[email protected]>

* Add PrePartition param caching - move init_tensorrt_params logic

Signed-off-by: Serge Panev <[email protected]>

* Handle node with no defined input

Signed-off-by: Serge Panev <[email protected]>

* Remove tmp comment

Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L authored Aug 5, 2020
1 parent 59e200a commit 7b7cef5
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/operator/subgraph/build_subgraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry
}

/*!
* \brief Given a subgraph, find the output entries of a subgraph.
* \brief Given a subgraph, find the input entries of a subgraph.
* \param g pointer to the whole graph
* \param simple_nods vector of simple nodes in top sorted order
* \param subgraph_nodes vector of pointers of simples of a subgraph.
Expand Down
4 changes: 3 additions & 1 deletion src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ std::tuple<unique_ptr<nvinfer1::ICudaEngine>,

auto trt_logger = std::unique_ptr<TRT_Logger>(new TRT_Logger(verbosity));
auto trt_builder = InferObject(nvinfer1::createInferBuilder(*trt_logger));
auto trt_network = InferObject(trt_builder->createNetwork());
const auto explicitBatch = 1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto trt_network = InferObject(trt_builder->createNetworkV2(explicitBatch));
auto trt_parser = InferObject(nvonnxparser::createParser(*trt_network, *trt_logger));
::ONNX_NAMESPACE::ModelProto parsed_model;
// We check for a valid parse, but the main effect is the side effect
Expand Down
48 changes: 42 additions & 6 deletions src/operator/subgraph/tensorrt/tensorrt-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,23 @@ class TensorrtProperty : public SubgraphProperty {
return std::make_shared<TensorrtProperty>();
}

void PrePartition(const nnvm::Graph& g,
const std::vector<std::pair<std::string, std::string>>& options_map) override {
auto& in_arg_names = g.GetAttr<std::vector<std::string>>("in_arg_names");
auto& in_aux_names = g.GetAttr<std::vector<std::string>>("in_aux_names");
NDArray **in_args_ptr = g.GetAttr<NDArray**>("in_args");
NDArray **in_aux_ptr = g.GetAttr<NDArray**>("in_aux");
in_args_dict.clear();
in_aux_dict.clear();
// we trust the Python API, len(in_arg_names) == len(in_args_ptr)
for (unsigned i = 0; i < in_arg_names.size(); ++i) {
in_args_dict[in_arg_names[i]] = in_args_ptr[i];
}
for (unsigned i = 0; i < in_aux_names.size(); ++i) {
in_aux_dict[in_aux_names[i]] = in_aux_ptr[i];
}
}

nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
const int subgraph_id) const override {
nnvm::ObjectPtr n = nnvm::Node::Create();
Expand All @@ -280,16 +297,33 @@ class TensorrtProperty : public SubgraphProperty {
n->attrs.op = Op::Get("_TensorRT");
CHECK(n->attrs.op);
n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));

// Mapping subgraph params with NDArrays
TRTParam param;
std::ostringstream params_oss;
for (auto &e : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
params_oss << e << ";";
for (auto &param_name : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
NDArray *cache;
auto it_args = in_args_dict.find(param_name);
if (it_args != in_args_dict.end()) {
cache = it_args->second;
} else {
auto it_aux = in_aux_dict.find(param_name);
if (it_aux != in_aux_dict.end()) {
cache = it_aux->second;
}
}
if (cache != nullptr) {
param.params_map.emplace(param_name, cache->Copy(Context()));
param.params_map[param_name].WaitToRead();
params_oss << param_name << ";";
}
}
auto tensorrt_params_names = params_oss.str();
tensorrt_params_names.pop_back();
n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
TRTParam param;
if (!tensorrt_params_names.empty()) {
tensorrt_params_names.pop_back();
}
n->attrs.parsed = param;
n->op()->attr_parser(&(n->attrs));
n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
return n;
}

Expand Down Expand Up @@ -328,6 +362,8 @@ class TensorrtProperty : public SubgraphProperty {
}
subgraph_node->attrs.parsed = std::move(_params);
}

std::unordered_map<std::string, NDArray*> in_args_dict, in_aux_dict;
};


Expand Down
6 changes: 3 additions & 3 deletions src/operator/subgraph/tensorrt/tensorrt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
param.bindings->at(i) = outputs[p.first].dptr_;
}
}
const int batch_size = static_cast<int>(inputs[0].shape_[0]);
param.trt_executor->enqueue(batch_size, param.bindings->data(), cuda_s, nullptr);
param.trt_executor->enqueueV2(param.bindings->data(), cuda_s, nullptr);
}

NNVM_REGISTER_OP(_TensorRT)
.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute);
.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

} // namespace op
} // namespace mxnet
Expand Down

0 comments on commit 7b7cef5

Please sign in to comment.