From 48c046a165dbcd6338031b14d8105d4b0528a449 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Mon, 7 Oct 2019 14:25:50 -0700 Subject: [PATCH] Fix serialization test --- tests/python/relay/test_vm_serialization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index a32ec2768540..7bf91ce355b5 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -28,18 +28,18 @@ from tvm.contrib import util from tvm.relay import testing -def create_vm(f, ctx=tvm.cpu(), target="llvm"): +def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None): if isinstance(f, relay.Expr): mod = relay.Module() mod["main"] = f compiler = relay.vm.VMCompiler() - vm = compiler.compile(mod, target) + vm = compiler.compile(mod, target=target, params=params) vm.init(ctx) return vm else: assert isinstance(f, relay.Module), "expected mod as relay.Module" compiler = relay.vm.VMCompiler() - vm = compiler.compile(f, target) + vm = compiler.compile(f, target=target, params=params) vm.init(ctx) return vm @@ -61,7 +61,7 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32'): return result.asnumpy().astype(dtype) def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): - vm = create_vm(mod, ctx, target) + vm = create_vm(mod, ctx, target, params=params) ser = serializer.Serializer(vm) code, lib = ser.serialize() deser = deserializer.Deserializer(code, lib)