From 0dfdcafba03a50c524a471e037023956bcfa3371 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 23 Apr 2019 14:09:59 -0700 Subject: [PATCH] [Relay] C++ Build module --- include/tvm/build_module.h | 4 +- src/codegen/build_module.cc | 9 +- src/relay/backend/build_module.cc | 680 ++++++++++++++++++++ tests/python/relay/test_cpp_build_module.py | 77 +++ 4 files changed, 768 insertions(+), 2 deletions(-) create mode 100644 src/relay/backend/build_module.cc create mode 100644 tests/python/relay/test_cpp_build_module.py diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 3c136444229b6..33f246717d905 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -351,12 +351,14 @@ TVM_DLL Array lower(Schedule sch, * \param target The target device to build for. * \param target_host The target for building host code. To use the default, pass Target() * \param config The build configuration. +* \param (optional) returned host functions * \return The built module. */ TVM_DLL runtime::Module build(const Array& funcs, const Target& target, const Target& target_host, - const BuildConfig& config); + const BuildConfig& config, + Array* fhost_ret = nullptr); class GenericFuncNode; diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 65542dd508109..867d7c2ee6e71 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -425,7 +425,8 @@ Array lower(Schedule sch, runtime::Module build(const Array& funcs, const Target& target, const Target& target_host, - const BuildConfig& config) { + const BuildConfig& config, + Array* fhost_ret) { std::unordered_set all_names; for (const auto &x : funcs) { CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name; @@ -464,6 +465,12 @@ runtime::Module build(const Array& funcs, } } + if (fhost_ret != nullptr) { + for (auto f : fhost) { + fhost_ret->push_back(f); + } + } + auto keys = target->keys(); bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc new file mode 100644 index 0000000000000..7f037e9fce546 --- /dev/null +++ b/src/relay/backend/build_module.cc @@ -0,0 +1,680 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file relay/backend/build_module.cc + * \brief Graph runtime codegen + */ + + + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +namespace tvm { +namespace relay { +namespace backend { + +/*! + * \brief Context name / index + */ +struct ContextMap { + static const std::unordered_map mask2str; + static const std::unordered_map str2mask; + static std::string Mask2Str(int mask) { + CHECK_GT(mask2str.count(mask), 0) << "Unknown mask."; + return mask2str.at(mask); + } + static int Str2Mask(const std::string& str) { + CHECK_GT(str2mask.count(str), 0) << "Unknown context."; + return str2mask.at(str); + } + static std::unordered_map _declare_mask2str() { + std::unordered_map ret; + ret[1] = "cpu"; + ret[2] = "gpu"; + ret[4] = "opencl"; + ret[5] = "aocl"; + ret[6] = "sdaccel"; + ret[7] = "valkan"; + ret[8] = "metal"; + ret[9] = "vpi"; + ret[10] = "rocm"; + ret[11] = "opengl"; + ret[12] = "ext_dev"; + return ret; + } + static std::unordered_map _declare_str2mask() { + std::unordered_map ret; + ret["llvm"] = 1; + ret["stackvm"] = 1; + ret["cpu"] = 1; + ret["c"] = 1; + ret["gpu"] = 2; + ret["cuda"] = 2; + ret["nvptx"] = 2; + ret["cl"] = 4; + ret["opencl"] = 4; + ret["aocl"] = 5; + ret["aocl_sw_emu"] = 5; + ret["sdaccel"] = 6; + ret["vulkan"] = 7; + ret["metal"] = 8; + ret["vpi"] = 9; + ret["rocm"] = 10; + ret["opengl"] = 11; + ret["ext_dev"] = 12; + return ret; + } +}; + +const std::unordered_map ContextMap::mask2str = + ContextMap::_declare_mask2str(); +const std::unordered_map ContextMap::str2mask = + ContextMap::_declare_str2mask(); + +/*! \brief Optimization pass level */ +struct OptPassLevel { + static const std::unordered_map _data; + static std::unordered_map _declare_opt_level() { + std::unordered_map ret; + ret["SimplifyInference"] = 0; + ret["OpFusion"] = 1; + ret["FoldConstant"] = 2; + ret["CombineParallelConv2D"] = 3; + ret["FoldScaleAxis"] = 3; + ret["AlterOpLayout"] = 3; + ret["CanonicalizeOps"] = 3; + ret["EliminateCommonSubexpr"] = 3; + return ret; + } + int operator[](const std::string& key) const { + auto it = _data.find(key); + if (it == _data.end()) { + return -1; + } + return it->second; + } +}; + +const std::unordered_map OptPassLevel::_data = + OptPassLevel::_declare_opt_level(); + +/*! \brief Output of function building */ +struct BuildOutput { + std::string graph_json; + runtime::Module mod; + std::unordered_map params; +}; + +/*! \brief Relay Building configuration */ +struct RelayBuildConfig { + int opt_level{2}; + std::string fall_back_device{"llvm"}; + std::unordered_set add_pass; + std::unordered_set disabled_pass; + OptPassLevel OPT_PASS_LEVEL; + inline bool pass_enabled(std::string pass_name) const { + if (disabled_pass.count(pass_name)) { + return false; + } + if (add_pass.count(pass_name)) { + return true; + } + return opt_level >= OPT_PASS_LEVEL[pass_name]; + } +}; + +/*! \brief GraphCodegen wrapper */ +struct GraphCodegen { + public: + GraphCodegen() { + auto pf = GetPackedFunc("relay.build_module._GraphRuntimeCodegen"); + mod = (*pf)(); + } + ~GraphCodegen() {} + + void Init(runtime::Module* m, + Map targets) { + Array tgts; + for (auto kv : targets) { + tgts.push_back(kv.first); + tgts.push_back(kv.second); + } + _CallFunc("init", m, tgts); + } + + void Codegen(Function func) { + _CallFunc("codegen", func); + } + + std::string GetJSON() { + return _CallFunc("get_graph_json", nullptr); + } + + Map > GetLoweredFunc() { + return _CallFunc > >("get_lowered_funcs", nullptr); + } + + std::unordered_map GetParams() { + std::unordered_map ret; + auto names = _CallFunc >("list_params_name", nullptr); + for (auto expr : names) { + auto key = expr.as()->value; + ret[key] = _CallFunc("get_param_by_name", key); + } + return ret; + } + + protected: + tvm::runtime::Module mod; + template + R _CallFunc(const std::string &name, Args... args) { + auto pf = mod.GetFunction(name, false); + return pf(std::forward(args)...); + } + template + void _CallFunc(const std::string &name, Args... args) { + auto pf = mod.GetFunction(name, false); + pf(std::forward(args)...); + return; + } +}; + +template +R _CallPacked(const std::string &name, Args... args) { + auto pf = GetPackedFunc(name); + return (*pf)(std::forward(args)...); +} + +template +Function _CallPacked(const std::string &name, Args... args) { + auto pf = GetPackedFunc(name); + return (*pf)(std::forward(args)...); +} + + +class RelayBuildModule : public runtime::ModuleNode { + public: + /*! + * \brief Get member function to front-end + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + virtual PackedFunc GetFunction(const std::string& name, + const std::shared_ptr& sptr_to_self) { + if (name == "get_graph_json") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->_GetGraphJSON(); + }); + } else if (name == "get_module") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->_GetModule(); + }); + } else if (name == "build") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 3); + Array tmp = args[1]; + std::unordered_map targets; + for (size_t i = 0; i < tmp.size(); i += 2) { + auto k = tmp[i].as()->value; + auto v = tmp[i + 1].as()->value; + targets[k] = v; + } + this->_Build(args[0], targets, args[2]); + }); + } else if (name == "list_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->_ListParamNames(); + }); + } else if (name == "get_param_by_name") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 1); + *rv = this->_GetParam(args[0]); + }); + } else if (name == "set_opt_level") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 1); + int level = args[0]; + this->_SetOptLevel(level); + }); + } else if (name == "set_fallback_device") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string dev = args[0]; + this->_SetFallBackDev(dev); + }); + } else if (name == "add_pass") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string pass_name = args[0]; + this->_AddPass(pass_name); + }); + } else if (name == "disable_pass") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string pass_name = args[0]; + this->_DisablePass(pass_name); + }); + } else { + return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) {}); + } + } + + /*! + * \brief Get the GraphJSON for runtime + * + * \return const std::string graph_json + */ + const std::string _GetGraphJSON() { + return ret_.graph_json; + } + /*! + * \brief Add extra pass during build + * + * \param pass_name + */ + void _AddPass(const std::string& pass_name) { + cfg_.add_pass.insert(pass_name); + } + + void _DisablePass(const std::string& pass_name) { + cfg_.disabled_pass.insert(pass_name); + } + + void _SetFallBackDev(const std::string& dev) { + cfg_.fall_back_device = dev; + } + /*! + * \brief Get the Module object + * + * \return runtime::Module + */ + runtime::Module _GetModule() { + return ret_.mod; + } + + /*! + * \brief List all paramter names + * + * \return Array + */ + Array _ListParamNames() { + Array ret; + for (const auto& kv : params_) { + ret.push_back(ir::StringImm::make(kv.first)); + } + return ret; + } + + /*! + * \brief Get the Param of name + * + * \param name + * \return runtime::NDArray + */ + runtime::NDArray _GetParam(const std::string& name) { + CHECK_GT(params_.count(name), 0) << "Can not find param with name: " << name; + return params_[name]; + } + + /*! + * \brief Set the parameters + * + * \param name name of parameter + * \param data_in input DLTensor + */ + void _SetParams(const std::string& name, DLTensor* data_in) { + if (!params_.count(name)) { + std::vector shape(data_in->shape, data_in->shape + data_in->ndim); + params_[name] = tvm::runtime::NDArray::Empty(shape, data_in->dtype, {kDLCPU, 0}); + } + params_[name].CopyFrom(data_in); + } + + /*! + * \brief Set the optimization level + * + * \param level + */ + void _SetOptLevel(char level) { + cfg_.opt_level = level; + } + + /*! + * \brief type key + * + * \return const char* + */ + const char* type_key() const final { + return "RelayBuildModule"; + } + + /*! + * \brief Build relay function for graph runtime + * + * \param func Relay Function + * \param target Target device + * \param target_host Host target device + */ + void _Build(Function func, + const std::unordered_map& targets, + const std::string& target_host) { + targets_ = targets; + target_host_ = target_host; + _BuildRelay(func, cfg_, params_); + } + + protected: + /*! + * \brief bind params to function + * + * \param func Relay function + * \param params params dict + * \return relay::Function + */ + relay::Function _bind_params_by_name(relay::Function func, + const std::unordered_map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto &name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(arg); + } else { + name_dict[name] = arg; + } + } + + std::unordered_map bind_dict; + for (auto &kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + auto arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "Multiple args in the function have name " << kv.first; + } + auto e = _CallPacked("relay._make.Constant", kv.second); + bind_dict[arg] = e; + } + return _CallPacked("relay._expr.Bind", func, tvm::Map(bind_dict)); + } + + /*! + * \brief Optimize Relay function + * + * \param func Input function + * \param target target device + * \param cfg Relay build config + * \param params params dict + * \return relay::Function + */ + relay::Function _Optimize(relay::Function func, + const std::unordered_map& targets, + const RelayBuildConfig& cfg, + const std::unordered_map& params) { + if (params.size()) { + func = _bind_params_by_name(func, params); + } + if (cfg.pass_enabled("SimplifyInference")) { + func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + func = _CallPacked("relay._ir_pass.simplify_inference", func); + } + if (cfg.pass_enabled("EliminateCommonSubexpr")) { + auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == HalideIR::Int(32)) { + return true; + } + } + } + return false; + }); + func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + func = _CallPacked("relay._ir_pass.eliminate_common_subexpr", func, fskip); + } + if (cfg.pass_enabled("CombineParallelConv2D")) { + func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + func = _CallPacked("relay._ir_pass.CombineParallelConv2D", func); + } + if (cfg.pass_enabled("FoldConstant")) { + func = _CallPacked("relay._ir_pass.FoldConstant", func); + } + if (cfg.pass_enabled("FoldScaleAxis")) { + func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + func = _CallPacked("relay._ir_pass.backward_fold_scale_axis", func); + func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + func = _CallPacked("relay._ir_pass.forward_fold_scale_axis", func); + func = _CallPacked("relay._ir_pass.FoldConstant", func); + } + if (cfg.pass_enabled("CanonicalizeOps")) { + func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + func = _CallPacked("relay._ir_pass.canonicalize_ops", func); + } + if (cfg.pass_enabled("AlterOpLayout")) { + if (targets.size() == 1) { + func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + func = _CallPacked("relay._ir_pass.AlterOpLayout", func); + } else { + LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous" + << " execution yet."; + } + } + if (cfg.pass_enabled("FoldConstant")) { + func = _CallPacked("relay._ir_pass.FoldConstant", func); + } + return func; + } + + Map _UpdateHeterogeneousInputs( + const std::unordered_map& targets, + const RelayBuildConfig& cfg) { + Map device_target; + std::unordered_map tmp_map; + auto fallback_idx = ContextMap::Str2Mask(cfg.fall_back_device); + + for (const auto& kv : targets) { + tmp_map[ContextMap::Str2Mask(kv.first)] = kv.second; + } + if (tmp_map.count(fallback_idx) == 0) { + tmp_map[fallback_idx] = cfg.fall_back_device; + } + for (const auto& kv : tmp_map) { + device_target.Set( + ir::IntImm::make(HalideIR::Int(64), kv.first), + ir::StringImm::make(kv.second)); + } + return device_target; + } + + Function _RunDeviceAnnotationPass( + Function func, + const RelayBuildConfig& cfg, + Map* targets_map_ptr) { + auto fallback_idx = ContextMap::Str2Mask(cfg.fall_back_device); + func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + func = _CallPacked("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx); + auto device_map = _CallPacked >("relay._ir_pass.CollectDeviceInfo", + func, + nullptr); + if (device_map.size() == 0) { + auto annotation_map = _CallPacked >("_ir_pass.CollectDeviceAnnotationOps", + func, + nullptr); + if (annotation_map.size() == 0) { + targets_map_ptr->Set( + ir::IntImm::make(HalideIR::Int(64), 0), + ir::StringImm::make(cfg.fall_back_device)); + } else { + int64_t dev_type = -1; + for (auto kv : annotation_map) { + dev_type = kv.second->value; + break; + } + for (auto kv : annotation_map) { + CHECK_EQ(kv.second->value, dev_type) + << "Expressions in the function are " + << "annotated with various device types," + << "but not device copy operators " + << "found. Please check the " + << "RewriteAnnotation pass."; + } + targets_map_ptr->Set( + ir::IntImm::make(HalideIR::Int(64), 0), + ir::StringImm::make(ContextMap::Mask2Str(dev_type))); + } + } + return func; + } + /*! + * \brief Build module given lowered functions + * + * \param lowered_funcs + * \param targets + * \param cfg + */ + void _BuildModule(Map > lowered_funcs, + Map targets, + const BuildConfig& cfg) { + auto target_host = Target::create(cfg_.fall_back_device); + for (const auto& kv : lowered_funcs) { + std::unordered_set fname_set; + for (auto f : kv.second) { + if (fname_set.count(f->name)) { + LOG(FATAL) << "Duplicate function name " + << f->name; + } + fname_set.insert(f->name); + } + } + std::unordered_map target_map; + for (const auto& kv : lowered_funcs) { + target_map[kv.first] = Target::create(kv.first); + } + Array fhost_all; + std::vector device_module; + for (const auto& kv : lowered_funcs) { + auto target = target_map[kv.first]; + Array funcs; + for (auto f : kv.second) { + if (f->func_type == kHostFunc) { + fhost_all.push_back(f); + } else if (f->func_type == kDeviceFunc) { + auto fs = _CallPacked >("ir_pass.SplitHostDevice", f, nullptr); + CHECK_GT(fs.size(), 0); + fhost_all.push_back(fs[0]); + for (size_t i = 1; i < fs.size(); ++i) { + funcs.push_back(fs[i]); + } + } else { + funcs.push_back(f); + } + auto mdev = build(funcs, target, target_host, cfg, &fhost_all); + device_module.push_back(mdev); + } + } + + auto mhost = build(fhost_all, + target_host, + target_host, + cfg); + + for (auto mdev : device_module) { + mhost.Import(mdev); + } + ret_.mod = mhost; + } + + /*! + * \brief Build relay function to runtime module + * + * \param func Relay Function + * \param target target device + * \param target_host host device + * \param cfg Relay build config + * \param params params + * \return BuildOutput + */ + void _BuildRelay(relay::Function func, + const RelayBuildConfig& cfg, + const std::unordered_map ¶ms) { + // convert + tvm_cfg_ = build_config(); + auto device_target = _UpdateHeterogeneousInputs(targets_, cfg); + func = _Optimize(func, targets_, cfg, params); + if (targets_.size() > 1) { + func = _RunDeviceAnnotationPass(func, cfg, &device_target); + } + func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + func = _CallPacked("relay._ir_pass.FuseOps", func, cfg.opt_level); + func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + + graph_codegen_ = std::unique_ptr(new GraphCodegen()); + graph_codegen_->Init(nullptr, device_target); + graph_codegen_->Codegen(func); + + ret_.graph_json = graph_codegen_->GetJSON(); + ret_.params = graph_codegen_->GetParams(); + + _BuildModule(graph_codegen_->GetLoweredFunc(), + device_target, + tvm_cfg_); + } + + protected: + std::unique_ptr graph_codegen_; + /*! \brief target device */ + std::unordered_map targets_; + /*! \brief target host device */ + std::string target_host_; + /*! \brief frontend optimization configure */ + RelayBuildConfig cfg_; + /*! \brief parameters */ + std::unordered_map params_; + /*! \brief building output */ + BuildOutput ret_; + /*! \brief tvm building cfg */ + BuildConfig tvm_cfg_; +}; + +runtime::Module RelayBuildCreate() { + std::shared_ptr exec = std::make_shared(); + return runtime::Module(exec); +} + +TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = RelayBuildCreate(); +}); + +} // namespace backend +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py new file mode 100644 index 0000000000000..4728c4f936129 --- /dev/null +++ b/tests/python/relay/test_cpp_build_module.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np + +import tvm +from tvm import relay + +from tvm._ffi.function import _init_api +_init_api("tvm.relay.build_module") + +class BuildModule(object): + def __init__(self): + self.mod = relay.build_module._BuildModule() + self._get_graph_json = self.mod["get_graph_json"] + self._get_module = self.mod["get_module"] + self._build = self.mod["build"] + + def build(self, func, target, target_host): + tgts = [] + for kv in target.items(): + tgts.append(kv[0]) + tgts.append(kv[1]) + self._build(func, tgts, target_host) + + def get_json(self): + return self._get_graph_json() + + def get_module(self): + return self._get_module() + +def test_build(): + m_bld = BuildModule() + # func + a = relay.var("a", dtype="float32", shape=(16, 8)) + b = relay.var("b", dtype="float32", shape=(8, 8)) + c = relay.var("c", dtype="float32", shape=(16, 8)) + x = relay.nn.dense(a, b) + y = relay.nn.relu(x) + z = y + c + func = relay.Function([a, b, c], z) + # build + targets = { + "cpu": "llvm -mcpu=sse3" + } + m_bld.build(func, targets, "llvm -mcpu=sse3") + g_json = m_bld.get_json() + mmod = m_bld.get_module() + + + # test + A = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32")) + B = tvm.nd.array(np.random.uniform(-1, 1, (8, 8)).astype("float32")) + C = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32")) + + rt = tvm.contrib.graph_runtime.create(g_json, mmod, tvm.cpu()) + rt.set_input("a", A) + rt.set_input("b", B) + rt.set_input("c", C) + rt.run() + out = rt.get_output(0) + + np.testing.assert_allclose(out.asnumpy(), + np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5) \ No newline at end of file