diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h
index b9ccbf99ef02..0cce533afe62 100644
--- a/include/tvm/runtime/vm.h
+++ b/include/tvm/runtime/vm.h
@@ -114,6 +114,7 @@ enum class Opcode {
   LoadConsti = 14U,
   Fatal = 15U,
   AllocStorage = 16U,
+  ShapeOf = 17U,
 };
 
 /*! \brief A single virtual machine instruction.
@@ -245,6 +246,9 @@ struct Instruction {
       /*! \brief The hint of the dtype. */
       DLDataType dtype_hint;
     } alloc_storage;
+    struct /* ShapeOf Operands */ {
+      RegName tensor;
+    } shape_of;
   };
 
   /*!
@@ -389,6 +393,14 @@ struct Instruction {
   static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
                                   RegName dst);
 
+  /*!
+   * \brief Get the shape of an input tensor.
+   * \param tensor The input tensor.
+   * \param dst The destination to store the shape of the given tensor.
+   * \return The shape_of instruction.
+   */
+  static Instruction ShapeOf(RegName tensor, RegName dst);
+
   Instruction();
   Instruction(const Instruction& instr);
   Instruction& operator=(const Instruction& instr);
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index ce0df9532d66..a45d466a2623 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -27,6 +27,7 @@
 from .tensor import *
 from .transform import *
 from .algorithm import *
+from .vm import *
 from . import nn
 from . import annotation
 from . import memory
diff --git a/python/tvm/relay/op/vm/__init__.py b/python/tvm/relay/op/vm/__init__.py
new file mode 100644
index 000000000000..2ac1e5743cb1
--- /dev/null
+++ b/python/tvm/relay/op/vm/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""Dialect operators for Relay VM."""
+from __future__ import absolute_import as _abs
+from . import vm
diff --git a/python/tvm/relay/op/vm/_ffi_api.py b/python/tvm/relay/op/vm/_ffi_api.py
new file mode 100644
index 000000000000..3eeeeb811859
--- /dev/null
+++ b/python/tvm/relay/op/vm/_ffi_api.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for relay.op.vm"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("relay.op.vm", __name__)
diff --git a/python/tvm/relay/op/vm/vm.py b/python/tvm/relay/op/vm/vm.py
new file mode 100644
index 000000000000..680729df88eb
--- /dev/null
+++ b/python/tvm/relay/op/vm/vm.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
+"""Dialect operators for Relay VM."""
+from . import _ffi_api
+
+
+def shape_of(expr):
+    """Get the shape of an input tensor expression.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The expression whose tensor shape is evaluated.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The expression that evaluates to the tensor shape.
+    """
+    return _ffi_api.shape_of(expr)
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py
index 6c081cbac0de..a7ba2a8a5678 100644
--- a/python/tvm/relay/transform/memory_alloc.py
+++ b/python/tvm/relay/transform/memory_alloc.py
@@ -44,6 +44,7 @@ class ManifestAllocPass(ExprMutator):
     def __init__(self, target_host):
         self.invoke_tvm = op.memory.invoke_tvm_op
         self.shape_func = op.memory.shape_func
+        self.shape_of = op.vm.shape_of
         self.scopes = [ScopeBuilder()]
         self.target_host = target_host
         self.default_context = cpu(0)
@@ -53,9 +54,6 @@ def __init__(self, target_host):
     def current_scope(self):
         return self.scopes[-1]
 
-    def shape_of(self, e):
-        return op.shape_of(e, self.compute_dtype)
-
     def visit_tuple(self, tup):
         scope = self.current_scope()
         new_fields = []
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 0b839a2a8f52..2151acfb216f 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -283,6 +283,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       case Opcode::Invoke:
       case Opcode::AllocClosure:
       case Opcode::AllocStorage:
+      case Opcode::ShapeOf:
       case Opcode::Move:
       case Opcode::InvokeClosure:
         last_register_ = instr.dst;
@@ -588,6 +589,18 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
           auto outputs = Downcast<Tuple>(args[2]);
           EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
         })
+        .Match("vm.shape_of",
+               [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+                 CHECK_EQ(args.size(), 1U);
+                 // Get the attributes.
+                 const auto* shape_of_attrs = attrs.as<ShapeOfAttrs>();
+                 CHECK(shape_of_attrs) << "Must be the shape_of attrs";
+                 CHECK_EQ(shape_of_attrs->dtype.bits(), 64)
+                     << "The dtype of shape_of must be int64, but got "
+                     << DLDataType2String(shape_of_attrs->dtype);
+                 this->VisitExpr(args[0]);
+                 Emit(Instruction::ShapeOf(last_register_, NewRegister()));
+               })
         .Match("memory.kill",
                [](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
                  LOG(FATAL) << "memory.kill is not yet supported";
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index 6b72670babaf..99e6c026f8d1 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -396,20 +396,6 @@ RELAY_REGISTER_UNARY_OP("bitwise_not")
 // shape_of
 TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);
 
-bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                const TypeReporter& reporter) {
-  CHECK_EQ(num_inputs, 1);
-  auto tt = types[0].as<TensorTypeNode>();
-  if (tt == nullptr) {
-    return false;
-  }
-  const auto* param = attrs.as<ShapeOfAttrs>();
-  CHECK(param != nullptr);
-  auto rank_shape = RankShape(tt->shape);
-  reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
-  return true;
-}
-
 Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
   CHECK_EQ(inputs.size(), 1);
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 46143d16c96d..0647ec9780f3 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -25,6 +25,7 @@
 #include "./type_relations.h"
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include <tvm/tir/op.h>
@@ -146,5 +147,19 @@ Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
   }
 }
 
+bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 1);
+  auto tt = types[0].as<TensorTypeNode>();
+  if (tt == nullptr) {
+    return false;
+  }
+  const auto* param = attrs.as<ShapeOfAttrs>();
+  CHECK(param != nullptr);
+  auto rank_shape = RankShape(tt->shape);
+  reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
+  return true;
+}
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
index acd4b2dae1be..5ab8b121ae9d 100644
--- a/src/relay/op/type_relations.h
+++ b/src/relay/op/type_relations.h
@@ -79,6 +79,18 @@ bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attr
 
 Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);
 
+/*!
+ * \brief The shape_of type relation.
+ *
+ * \param types The input and output types to the relation.
+ * \param num_inputs The number of input arguments.
+ * \param attrs The attributes.
+ * \param reporter The reporter.
+ * \return True if the relation has been resolved.
+ */
+bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter);
+
 }  // namespace relay
 }  // namespace tvm
 
diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc
new file mode 100644
index 000000000000..af33100add31
--- /dev/null
+++ b/src/relay/op/vm/vm.cc
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/vm/vm.cc
+ * \brief Dialect operators for Relay VM.
+ */
+
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/data_type.h>
+
+#include "../../transforms/infer_layout_util.h"
+#include "../op_common.h"
+#include "../type_relations.h"
+
+namespace tvm {
+namespace relay {
+
+RELAY_REGISTER_OP("vm.shape_of")
+    .describe(R"code(Get the shape of an input tensor.
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .add_argument("tensor", "Tensor", "The input tensor")
+    .add_type_rel("ShapeOf", ShapeOfRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed([](Expr expr) {
+  auto attrs = make_object<ShapeOfAttrs>();
+  attrs->dtype = DataType::Int(64);
+  static const Op& op = Op::Get("vm.shape_of");
+  return Call(op, {expr}, Attrs(attrs), {});
+});
+
+}  // namespace relay
+}  // namespace tvm
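`ShapeOfRel` assigns the output a rank-1 tensor type whose single dimension equals the input's rank (via `RankShape`), with the dtype taken from `ShapeOfAttrs` (always int64 for `vm.shape_of`). A quick type-inference sketch, again assuming a patched build:

```python
# Sketch: vm.shape_of on a rank-3 input is typed as Tensor[(3,), int64],
# even when the individual dimensions are unknown (relay.Any()).
import tvm
from tvm import relay

x = relay.var("x", shape=(relay.Any(), relay.Any(), 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.op.vm.shape_of(x)))
mod = relay.transform.InferType()(mod)
print(mod["main"].body.checked_type)  # expected: Tensor[(3,), int64]
```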
const Op& shape_of_op_; + const Op& vm_shape_of_op_; const Op& invoke_tvm_op_; const Op& shape_func_op_; const Op& alloc_tensor_op_; diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 65b1a2f29707..f5204044ac78 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -417,6 +417,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.push_back(instr.pc_offset); break; } + case Opcode::ShapeOf: { + // Number of fields = 2 + fields.assign({instr.shape_of.tensor, instr.dst}); + break; + } default: LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); break; @@ -683,6 +688,11 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { DCHECK_EQ(instr.fields.size(), 1U); return Instruction::Goto(instr.fields[0]); } + case Opcode::ShapeOf: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::ShapeOf(instr.fields[0], instr.fields[1]); + } default: LOG(FATAL) << "Invalid opcode" << instr.opcode; return Instruction(); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 0c0ca350f444..6b10a89d969a 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -145,6 +145,9 @@ Instruction::Instruction(const Instruction& instr) { case Opcode::AllocStorage: this->alloc_storage = instr.alloc_storage; return; + case Opcode::ShapeOf: + this->shape_of.tensor = instr.shape_of.tensor; + return; default: std::ostringstream out; out << "Invalid instruction " << static_cast(instr.op); @@ -239,6 +242,9 @@ Instruction& Instruction::operator=(const Instruction& instr) { case Opcode::AllocStorage: this->alloc_storage = instr.alloc_storage; return *this; + case Opcode::ShapeOf: + this->shape_of.tensor = instr.shape_of.tensor; + return *this; default: std::ostringstream out; out << "Invalid instruction " << static_cast(instr.op); @@ -258,6 +264,7 @@ Instruction::~Instruction() { case Opcode::Goto: case Opcode::LoadConsti: case Opcode::AllocStorage: + case Opcode::ShapeOf: case Opcode::Fatal: return; case Opcode::AllocTensor: @@ -351,6 +358,14 @@ Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType return instr; } +Instruction Instruction::ShapeOf(RegName tensor, Index dst) { + Instruction instr; + instr.op = Opcode::ShapeOf; + instr.dst = dst; + instr.shape_of.tensor = tensor; + return instr; +} + Instruction Instruction::AllocADT(Index tag, Index num_fields, const std::vector& datatype_fields, Index dst) { Instruction instr; @@ -585,6 +600,10 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { << DLDataType2String(instr.alloc_storage.dtype_hint); break; } + case Opcode::ShapeOf: { + os << "shape_of $" << instr.dst << " $" << instr.shape_of.tensor; + break; + } default: LOG(FATAL) << "should never hit this case" << static_cast(instr.op); break; @@ -1057,6 +1076,18 @@ void VirtualMachine::RunLoop() { pc_++; goto main_loop; } + case Opcode::ShapeOf: { + auto input = ReadRegister(instr.shape_of.tensor); + NDArray input_array = Downcast(input); + int ndim = input_array->ndim; + auto out_tensor = NDArray::Empty({ndim}, {kDLInt, 64, 1}, {kDLCPU, 0}); + for (int i = 0; i < ndim; ++i) { + reinterpret_cast(out_tensor->data)[i] = input_array->shape[i]; + } + WriteRegister(instr.dst, out_tensor); + pc_++; + goto main_loop; + } case Opcode::Ret: { // If we have hit the point from which we started // running, we should return to the caller breaking diff --git a/tests/python/relay/test_vm_serialization.py 
diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py
index 5d20651a8126..95e6c6f3c89e 100644
--- a/tests/python/relay/test_vm_serialization.py
+++ b/tests/python/relay/test_vm_serialization.py
@@ -19,7 +19,6 @@
 import numpy as np
 
 import tvm
-from tvm import te
 from tvm.runtime import vm as _vm
 from tvm.relay import vm as rly_vm
 from tvm import relay
@@ -41,11 +40,15 @@ def create_exec(f, target="llvm", params=None):
     return executable
 
 
-def veval(vm, *args, ctx=tvm.cpu()):
-    assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine"
-    ret = vm.run(*args)
-    return ret
-
+def get_serialized_output(mod, *data, params=None, target="llvm",
+                          ctx=tvm.cpu()):
+    exe = create_exec(mod, target, params=params)
+    code, lib = exe.save()
+    des_exec = _vm.Executable.load_exec(code, lib)
+    des_vm = _vm.VirtualMachine(des_exec)
+    des_vm.init(ctx)
+    result = des_vm.run(*data)
+    return result
 
 def run_network(mod,
                 params,
@@ -56,24 +59,16 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32'):
         result = ex.evaluate()(data, **params)
         return result.asnumpy().astype(dtype)
 
-    def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
-        exe = create_exec(mod, target, params=params)
-        code, lib = exe.save()
-        des_exec = _vm.Executable.load_exec(code, lib)
-        des_vm = _vm.VirtualMachine(des_exec)
-        des_vm.init(ctx)
-        result = des_vm.run(data)
-        return result.asnumpy().astype(dtype)
-
     data = np.random.uniform(size=data_shape).astype(dtype)
     target = "llvm"
     ctx = tvm.cpu(0)
 
     tvm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params,
                             target, ctx, dtype)
-    vm_out = get_serialized_output(mod, tvm.nd.array(data.astype(dtype)), params,
-                                   target, ctx, dtype)
-    tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
+    vm_out = get_serialized_output(mod, tvm.nd.array(data.astype(dtype)),
+                                   params=params, target=target, ctx=ctx)
+    tvm.testing.assert_allclose(vm_out.asnumpy().astype(dtype), tvm_out,
+                                rtol=1e-5, atol=1e-5)
 
 
 def test_serializer():
@@ -143,7 +138,7 @@ def test_save_load():
     des_vm = _vm.VirtualMachine(des_exec)
     des_vm.init(tvm.cpu())
 
-    res = veval(des_vm, x_data)
+    res = des_vm.run(x_data)
     tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
 
 
@@ -151,14 +146,8 @@ def test_const():
     c = relay.const(1.0, "float32")
     x = relay.var('x', shape=(10, 10), dtype='float32')
     f = relay.Function([x], x + c)
-    exe = create_exec(f)
-    code, lib = exe.save()
-    assert isinstance(code, bytearray)
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
     x_data = np.random.rand(10, 10).astype('float32')
-    res = veval(des_vm, x_data)
+    res = get_serialized_output(f, x_data)
     tvm.testing.assert_allclose(res.asnumpy(), x_data + 1)
 
 
@@ -172,18 +161,12 @@ def test_if():
     x_data = np.random.rand(10, 10).astype('float32')
     y_data = np.random.rand(10, 10).astype('float32')
 
-    exe = create_exec(f)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
     # same
-    res = veval(des_vm, x_data, x_data)
+    res = get_serialized_output(f, x_data, x_data)
     tvm.testing.assert_allclose(res.asnumpy(), x_data)
 
     # diff
-    res = veval(des_vm, x_data, y_data)
+    res = get_serialized_output(f, x_data, y_data)
     tvm.testing.assert_allclose(res.asnumpy(), y_data)
 
 
@@ -208,13 +191,7 @@ def test_loop():
     aarg = relay.var('accum', shape=[], dtype='int32')
     mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
 
-    exe = create_exec(mod)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
-    result = veval(des_vm, i_data, accum_data)
+    result = get_serialized_output(mod, i_data, accum_data)
     tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
 
 
@@ -225,13 +202,7 @@ def test_tuple():
     i_data = np.random.rand(41).astype('float32')
     j_data = np.random.rand(10).astype('float32')
 
-    exe = create_exec(f)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
-    result = veval(des_vm, (i_data, j_data))
+    result = get_serialized_output(f, (i_data, j_data))
     tvm.testing.assert_allclose(result.asnumpy(), j_data)
 
 
@@ -246,13 +217,7 @@ def test_adt_list():
     f = relay.Function([], l321)
     mod["main"] = f
 
-    exe = create_exec(mod)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
-    result = veval(des_vm)
+    result = get_serialized_output(mod)
     assert len(result) == 2
     assert len(result[1]) == 2
     assert len(result[1][1]) == 2
@@ -292,15 +257,8 @@ def test_adt_compose():
     f = relay.Function([y], add_two_body)
     mod["main"] = f
 
-    exe = create_exec(mod)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
     x_data = np.array(np.random.rand()).astype('float32')
-    result = veval(des_vm, x_data)
-
+    result = get_serialized_output(mod, x_data)
     tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
 
 
@@ -312,13 +270,7 @@ def test_closure():
     clo = ff(relay.const(1.0))
     main = clo(relay.const(2.0))
 
-    exe = create_exec(main)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
-    res = veval(des_vm)
+    res = get_serialized_output(main)
     tvm.testing.assert_allclose(res.asnumpy(), 3.0)
 
 
@@ -332,6 +284,20 @@ def test_mobilenet():
     run_network(mod, params)
 
 
+def test_vm_shape_of():
+    x = relay.var('x', shape=(relay.Any(), relay.Any(), relay.Any()), dtype="float32")
+    relu_x = relay.nn.relu(x)
+    data = np.random.uniform(size=(2, 3, 4)).astype('float32')
+    args = [data]
+
+    newshape_var = relay.var('newshape', shape=(2,), dtype='int64')
+    args.append(np.array((1, -1), dtype='int64'))
+    main = relay.reshape(relu_x, newshape=newshape_var)
+
+    res = get_serialized_output(main, *args).asnumpy()
+    tvm.testing.assert_allclose(res.flatten(), data.flatten())
+
+
 if __name__ == "__main__":
     test_serializer()
     test_save_load()
@@ -344,3 +310,4 @@
     test_closure()
     test_resnet()
     test_mobilenet()
+    test_vm_shape_of()
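Since `fold_constant.cc` now treats `vm.shape_of` exactly like `shape_of`, the call should fold to a constant whenever the input type is fully static; a sketch under the same assumptions as above:

```python
# Sketch: FoldConstant replaces vm.shape_of with an int64 constant
# when the input shape is static.
import tvm
from tvm import relay

x = relay.var("x", shape=(4, 5), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.op.vm.shape_of(x)))
mod = relay.transform.FoldConstant()(mod)
print(mod["main"].body)  # expected: a constant holding [4, 5]
```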