Skip to content

Commit

Permalink
[VirtualMachine] new method allowing to set one input tensor by its i…
Browse files Browse the repository at this point in the history
…ndex or name (#10293)

* set_input_with_index was implemented for VM

* clean code

* add getInputIndexFromName. add function descriptions. lint fix

* fix lint

* transfer comparison of parameter names number and assigned devices number to VMFunction constructor

* add GetVMFunctionWithName to Executable API

* clean code

* add SetInputWithName (set_input_with_name) to VM API

* join SetInputWithIndex and SetInputWithName to SetOneInputTensor (set_one_input) to VM API, the joined methods were removed

* fix lint

* some fixes after review

* add set_one_input method to python API of VirtualMachine

* pytests for set_input and set_one_input methods of VirtualMachine were implemented and checked

* CI restart

* construct simple model for pytests by relay instead of onnx tools (need for correct CI)

Co-authored-by: Valery Chernov <[email protected]>
  • Loading branch information
vvchernov and Valery Chernov authored Feb 25, 2022
1 parent 308d320 commit d62a364
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 54 deletions.
7 changes: 7 additions & 0 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ class Executable : public ModuleNode {
*/
void SetLib(const runtime::Module& lib);

/*!
* \brief Get VMFunction.
* \param func_name The function's name.
* \return VMFunction.
*/
const VMFunction& GetVMFunctionWithName(const std::string& func_name) const;

/*!
* \brief Get the arity of the VMFunction.
* \param func Function name.
Expand Down
55 changes: 54 additions & 1 deletion include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ struct VMFunction {
params(std::move(params)),
instructions(std::move(instructions)),
register_file_size(register_file_size),
param_device_indexes(std::move(param_device_indexes)) {}
param_device_indexes(std::move(param_device_indexes)) {
ICHECK_EQ(params.size(), param_device_indexes.size());
}

VMFunction() = default;

Expand Down Expand Up @@ -270,6 +272,15 @@ class VirtualMachine : public runtime::ModuleNode {
*/
void SetInput(std::string name, TVMArgs args, int offset);

/*!
* \brief Set one input tensor with index or name to a function.
* \param name The function name.
* \param tag The index or name of the input tensor.
* \param tensor the input tensor. If the tensor is not of the correct device for the function,
* they will be copied to the device.
*/
void SetOneInput(std::string name, const TVMArgValue& tag, const TVMArgValue& tensor);

/*!
* \brief Internal hook for profiling the start of an op.
*
Expand All @@ -286,6 +297,48 @@ class VirtualMachine : public runtime::ModuleNode {
*/
virtual void OpStopHook();

private:
/*!
* \brief Get index of input tensor from its name.
* \param func_name The function's name.
* \param input_name The input tensor name.
* \return The input tensor index.
*/
int64_t GetInputIndexFromVMFunction(const std::string& func_name,
const std::string& input_name) const;

/*!
* \brief Get index of input tensor from its name.
* \param params parameter names.
* \param input_name The input tensor name.
* \return The input tensor index.
*/
int64_t GetInputIndexFromName(const std::vector<std::string>& params,
const std::string& input_name) const;

/*!
* \brief Check executable exists and get VM function from it.
* \param func_name The function's name.
* \return VM function.
*/
const VMFunction& CheckAndGetVMFunction(const std::string& func_name) const;

/*!
* \brief Creates the inputs_ field; if it already exists, checks its size.
* \param func_name The function's name.
* \param size Expected size of the inputs_ field.
*/
void CreateInputsOrCheckSize(const std::string& func_name, size_t size);

/*!
* \brief Set one input tensor with a given index into the set of input tensors,
* copying it to the given device if needed.
* \param tensors The destination set of input tensors.
* \param tensor The input tensor (not necessarily a DLTensor).
* \param index The input tensor index.
* \param dev The device to copy to if needed.
*/
void SetInputTensorWithIndex(std::vector<ObjectRef>& tensors, // NOLINT(*)
const TVMArgValue& tensor, int index, Device dev);

protected:
/*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs_;
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def __init__(self, exe, device, memory_cfg=None):
self._get_num_outputs = self.module["get_num_outputs"]
self._get_input_index = self.module["get_input_index"]
self._set_input = self.module["set_input"]
self._set_one_input = self.module["set_one_input"]
self._setup_device(device, memory_cfg)

def _setup_device(self, dev, memory_cfg):
Expand Down Expand Up @@ -450,6 +451,30 @@ def set_input(self, func_name, *args, **kwargs):
cargs = convert(args)
self._set_input(func_name, *cargs)

def set_one_input(self, func_name, *args, **kwargs):
    """Set one input tensor, identified by index or name, for a function.

    Parameters
    ----------
    func_name : str
        The name of the function.
    args : [str or int, tvm.runtime.NDArray]
        Positional form: the tensor's name or index followed by the tensor.
    kwargs : dict of str to tvm.runtime.NDArray
        Keyword form: a single name=tensor pair.

    Exactly one of args or kwargs must be provided.

    Raises
    ------
    ValueError
        If both args and kwargs are given.
    """
    # Enforce the documented contract instead of silently ignoring args
    # when kwargs are also supplied.
    if args and kwargs:
        raise ValueError("Only args or kwargs should exist, not both")
    if kwargs:
        assert len(kwargs) == 1
        tag = next(iter(kwargs))
        if isinstance(tag, str):
            # Validate the name against the function's declared parameters.
            func_params = self._exec.get_function_params(func_name)
            assert tag in func_params
        self._set_one_input(func_name, tag, kwargs[tag])
    else:
        # Positional form: exactly (tag, tensor).
        assert len(args) == 2
        self._set_one_input(func_name, args[0], args[1])

def invoke(self, func_name, *args, **kwargs):
"""Invoke a function.
Expand Down
25 changes: 9 additions & 16 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,27 +109,20 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
}
}

/*!
 * \brief Look up a VMFunction in this executable by its global name.
 * \param func_name The function's name.
 * \return A reference to the VMFunction; aborts via ICHECK if the name is unknown.
 */
const VMFunction& Executable::GetVMFunctionWithName(const std::string& func_name) const {
  // global_map maps function name -> index into the `functions` table.
  auto it = global_map.find(func_name);
  ICHECK(it != global_map.end()) << "Cannot find function " << func_name << " in executable";
  return functions[it->second];
}

/*!
 * \brief Report how many parameters the named VM function declares.
 * \param func_name The function's name.
 * \return The parameter count.
 */
int Executable::GetFunctionArity(std::string func_name) const {
  return static_cast<int>(GetVMFunctionWithName(func_name).params.size());
}

/*!
 * \brief Fetch the name of the index-th parameter of the named VM function.
 * \param func_name The function's name.
 * \param index Zero-based parameter index.
 * \return The parameter's name; aborts via ICHECK on an unknown function or
 *         an out-of-range index.
 */
std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const {
  // Existence of the function itself is checked inside GetVMFunctionWithName.
  const auto& func = GetVMFunctionWithName(func_name);
  ICHECK_LT(index, func.params.size()) << "Invalid parameter index";
  return func.params[index];
}

Expand Down
122 changes: 85 additions & 37 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
} else if (name == "get_input_index") {
return TypedPackedFunc<int64_t(std::string, std::string)>(
[this](std::string input_name, std::string func_name) {
auto gvit = exec_->global_map.find(func_name);
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
const auto& vm_func = exec_->functions[func_index];
const auto& param_names = vm_func.params;
for (uint64_t i = 0; i < param_names.size(); i++) {
if (input_name == param_names[i]) {
return static_cast<int64_t>(i);
}
}
return static_cast<int64_t>(-1);
return GetInputIndexFromVMFunction(func_name, input_name);
});
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Expand All @@ -221,6 +211,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
} else if (name == "set_input") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); });
} else if (name == "set_one_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 3) << "The expected number of arguments is 3 "
<< "(func_name, index or name, tensor)";
SetOneInput(args[0], args[1], args[2]);
});
} else if (name == "load_late_bound_consts") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1);
Expand All @@ -234,39 +230,91 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
}

void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
ICHECK(exec_) << "The executable is not created yet.";
auto gvit = exec_->global_map.find(func_name);
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
const auto& vm_func = exec_->functions[func_index];
const auto& param_names = vm_func.params;
ICHECK_EQ(args.size() - offset, param_names.size())
const auto& vm_func = CheckAndGetVMFunction(func_name);
size_t params_num = vm_func.params.size();
ICHECK_EQ(args.size() - offset, params_num)
<< "The number of provided parameters doesn't match the number of arguments";
ICHECK_EQ(param_names.size(), vm_func.param_device_indexes.size())
<< "The number of provided parameters doesn't match the number of assigned devices";
std::vector<ObjectRef> func_args(param_names.size());
std::vector<ObjectRef> func_args(params_num);
for (int i = offset; i < args.size(); ++i) {
Device dev = GetDevice(vm_func.param_device_indexes[i - offset]);

if (args[i].type_code() == kTVMDLTensorHandle) {
// Automatically convert input DLTensors to NDArray
DLTensor* tensor = args[i];
std::vector<int64_t> shape;
for (int64_t i = 0; i < tensor->ndim; i++) {
shape.push_back(tensor->shape[i]);
}
NDArray ary = NDArray::Empty(shape, tensor->dtype, dev);
ary.CopyFrom(tensor);
func_args[i - offset] = ary;
} else {
ObjectRef obj = CopyTo(args[i], dev);
func_args[i - offset] = obj;
}
int index = i - offset;
Device dev = GetDevice(vm_func.param_device_indexes[index]);
SetInputTensorWithIndex(func_args, args[i], index, dev);
}
inputs_.erase(func_name);
inputs_.emplace(func_name, func_args);
}

void VirtualMachine::SetOneInput(std::string func_name, const TVMArgValue& tag,
const TVMArgValue& tensor) {
const auto& vm_func = CheckAndGetVMFunction(func_name);
size_t params_num = vm_func.params.size();

int inp_index;
if (tag.type_code() == kTVMArgInt) {
inp_index = tag;
} else if (tag.type_code() == kTVMStr) {
inp_index = static_cast<int>(GetInputIndexFromName(vm_func.params, tag));
} else {
LOG(FATAL) << "The type of input tensor tag (" << tag.type_code()
<< ") doesn't match integer or string";
}
ICHECK_LT(inp_index, params_num);

CreateInputsOrCheckSize(func_name, params_num);
Device dev = GetDevice(vm_func.param_device_indexes[inp_index]);
SetInputTensorWithIndex(inputs_[func_name], tensor, inp_index, dev);
}

/*!
 * \brief Resolve an input tensor's index from its name for a given function.
 * \param func_name The function's name.
 * \param input_name The input tensor name.
 * \return The input tensor index, or -1 if the name is not a parameter.
 */
int64_t VirtualMachine::GetInputIndexFromVMFunction(const std::string& func_name,
                                                    const std::string& input_name) const {
  // Resolve the VM function first, then scan its parameter list.
  return GetInputIndexFromName(CheckAndGetVMFunction(func_name).params, input_name);
}

/*!
 * \brief Find the position of a parameter name in a parameter list.
 * \param params The parameter names.
 * \param input_name The name to search for.
 * \return The zero-based position, or -1 when the name is absent.
 */
int64_t VirtualMachine::GetInputIndexFromName(const std::vector<std::string>& params,
                                              const std::string& input_name) const {
  // Plain linear scan; parameter lists are short, so no index structure is needed.
  int64_t position = 0;
  for (const auto& param : params) {
    if (param == input_name) {
      return position;
    }
    ++position;
  }
  return -1;
}

// Guard against use before the executable has been loaded (e.g. before "init"),
// then delegate the name-to-function lookup to the executable itself.
const VMFunction& VirtualMachine::CheckAndGetVMFunction(const std::string& func_name) const {
  ICHECK(exec_) << "The executable is not created yet.";
  return exec_->GetVMFunctionWithName(func_name);
}

/*!
 * \brief Ensure inputs_ has an entry of the given size for a function.
 * \param func_name The function's name.
 * \param size Expected number of input slots.
 *
 * Creates the entry (default-initialized slots) when absent; otherwise checks
 * that the existing entry already has exactly `size` slots.
 */
void VirtualMachine::CreateInputsOrCheckSize(const std::string& func_name, size_t size) {
  // Single lookup instead of count() followed by operator[].
  auto it = inputs_.find(func_name);
  if (it != inputs_.end()) {
    // Note the space after "function" so the message reads "function foo".
    ICHECK_EQ(it->second.size(), size)
        << "The size of function " << func_name
        << " doesn't match the number of provided parameters";
  } else {
    inputs_.emplace(func_name, std::vector<ObjectRef>(size));
  }
}

void VirtualMachine::SetInputTensorWithIndex(std::vector<ObjectRef>& tensors,
const TVMArgValue& inp_tensor, int index, Device dev) {
if (inp_tensor.type_code() == kTVMDLTensorHandle) {
// Automatically convert input DLTensors to NDArray
DLTensor* tensor = inp_tensor;
std::vector<int64_t> shape;
for (int64_t i = 0; i < tensor->ndim; i++) {
shape.push_back(tensor->shape[i]);
}
NDArray ary = NDArray::Empty(shape, tensor->dtype, dev);
ary.CopyFrom(tensor);
tensors[index] = ary;
} else {
tensors[index] = CopyTo(inp_tensor, dev);
}
}

inline Device VirtualMachine::GetDevice(Index device_index) const {
ICHECK_GE(devices_.size(), device_index) << "invalid device index: " << device_index;
return devices_[device_index];
Expand Down
Loading

0 comments on commit d62a364

Please sign in to comment.