diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index ce0df9532d661..55c011adbe019 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -30,6 +30,7 @@ from . import nn from . import annotation from . import memory +from . import dialect from . import image from . import vision from . import op_attrs diff --git a/python/tvm/relay/op/dialect/__init__.py b/python/tvm/relay/op/dialect/__init__.py new file mode 100644 index 0000000000000..107af6eb59646 --- /dev/null +++ b/python/tvm/relay/op/dialect/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Operators for manipulating low level memory.""" +from __future__ import absolute_import as _abs +from .vm import * diff --git a/python/tvm/relay/op/dialect/_make.py b/python/tvm/relay/op/dialect/_make.py new file mode 100644 index 0000000000000..fc34f4e78ed9f --- /dev/null +++ b/python/tvm/relay/op/dialect/_make.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm._ffi + +tvm._ffi._init_api("relay.op.dialect._make", __name__) diff --git a/python/tvm/relay/op/dialect/vm.py b/python/tvm/relay/op/dialect/vm.py new file mode 100644 index 0000000000000..0107e18b009d0 --- /dev/null +++ b/python/tvm/relay/op/dialect/vm.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks +"""Dialect operators for Relay VM.""" +from . import _make + + +def shape_of(expr): + """Invoke a function to get the shape of a tensor. + + Parameters + ---------- + expr : tvm.relay.Expr + The expr used to evaluate its tensor shape. + + Returns + ------- + result : tvm.relay.Expr + The expression with the evaluated tensor shape. + """ + return _make.shape_of(expr) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 07169ff8ac5eb..e42022915cc04 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -44,7 +44,7 @@ class ManifestAllocPass(ExprMutator): def __init__(self, target_host): self.invoke_tvm = op.memory.invoke_tvm_op self.shape_func = op.memory.shape_func - self.shape_of = op.memory.shape_of + self.shape_of = op.dialect.shape_of self.scopes = [ScopeBuilder()] self.target_host = target_host self.default_context = cpu(0) diff --git a/src/relay/op/dialect/vm.cc b/src/relay/op/dialect/vm.cc new file mode 100644 index 0000000000000..60fea7ed114e8 --- /dev/null +++ b/src/relay/op/dialect/vm.cc @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/dialect/vm.cc + * \brief Dialect operators for Relay VM. + */ + +#include +#include +#include +#include +#include + +#include "../../transforms/infer_layout_util.h" +#include "../op_common.h" +#include "../type_relations.h" + +namespace tvm { +namespace relay { + +// Forward declare the shape_of type relation function. +bool ShapeOfRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter); + +RELAY_REGISTER_OP("vm.shape_of") + .describe(R"code(Get the shape of an input tensor. +)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The input tensor") + .add_type_rel("ShapeOf", ShapeOfRel) + .set_support_level(10) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("TNonComputational", true) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + +TVM_REGISTER_GLOBAL("relay.op.dialect._make.shape_of").set_body_typed([](Expr expr) { + auto attrs = make_object(); + attrs->dtype = DataType::Int(64); + static const Op& op = Op::Get("vm.shape_of"); + return Call(op, {expr}, Attrs(attrs), {}); +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 7094c6cff1551..e5081adbf6a7d 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -36,10 +36,6 @@ namespace tvm { namespace relay { -// Forward declare the shape_of type relation function. -bool ShapeOfRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter); - TVM_REGISTER_NODE_TYPE(AllocStorageAttrs); TVM_REGISTER_NODE_TYPE(AllocTensorAttrs); TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs); @@ -427,24 +423,5 @@ RELAY_REGISTER_OP("memory.shape_func") return {topi::identity(inputs[0])}; }); -RELAY_REGISTER_OP("vm.shape_of") - .describe(R"code(Get the shape of an input tensor. -)code" TVM_ADD_FILELINE) - .set_num_inputs(1) - .add_argument("tensor", "Tensor", "The input tensor") - .add_type_rel("ShapeOf", ShapeOfRel) - .set_support_level(10) - .set_attr("TOpPattern", kOpaque) - .set_attr("TOpIsStateful", false) - .set_attr("TNonComputational", true) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); - -TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_of").set_body_typed([](Expr expr) { - auto attrs = make_object(); - attrs->dtype = DataType::Int(64); - static const Op& op = Op::Get("vm.shape_of"); - return Call(op, {expr}, Attrs(attrs), {}); -}); - } // namespace relay } // namespace tvm