diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index ac205b05504a6..06bcb7207c74b 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -208,6 +208,7 @@ class TypeRelationNode : public TypeConstraintNode { bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const { return + equal(func, other->func) && equal(args, other->args) && equal(num_inputs, other->num_inputs) && equal(attrs, other->attrs); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ec299b59c7365..87b4602095a20 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2142,6 +2142,9 @@ Expr MakeSplit(Expr data, TVM_REGISTER_GLOBAL("relay.op._make.split") .set_body([](const TVMArgs& args, TVMRetValue* rv) { if (args.type_codes[1] == kDLInt) { + // Note: we change it from Int(64) to Int(32) for now as + // combine_parallel_dense will transform the graph with Int(32). + // More invetigation is needs to check which one we should use. *rv = MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast(args[1])), args[2]); diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc deleted file mode 100644 index 0207fca00cf76..0000000000000 --- a/tests/cpp/relay_pass_alpha_equal.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include - -using namespace tvm; - -class TestAlphaEquals { - runtime::PackedFunc *_packed_func; - public: - TestAlphaEquals(const char* func_name) { - _packed_func = new runtime::PackedFunc(); - TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); - } - - void UpdatePackedFunc(const char* func_name) { - TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); - } - - bool operator()(ObjectRef input_1, ObjectRef input_2) { - TVMRetValue rv; - std::vector values(2); - std::vector codes(2); - runtime::TVMArgsSetter setter(values.data(), codes.data()); - setter(0, input_1); - setter(1, input_2); - _packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); - return bool(rv); - }; - -}; - -TEST(Relay, AlphaTestEmptyTypeNodes) { - auto x = TypeVar("x", kTypeData); - auto y = TypeVar(); - EXPECT_FALSE(relay::AlphaEqual(x, y)); - - TestAlphaEquals test_equals("relay._make._alpha_equal"); - EXPECT_FALSE(test_equals(x, y)); -} - -int main(int argc, char ** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index f951a8f386a68..3c416918e4414 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -18,6 +18,7 @@ */ #include +#include #include #include #include @@ -38,7 +39,7 @@ TEST(Relay, SelfReference) { auto type_fx = mod->Lookup("main"); auto expected = relay::FuncType(tvm::Array{ tensor_type }, tensor_type, {}, {}); - CHECK(relay::AlphaEqual(type_fx->checked_type(), expected)); + CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected)); } int main(int argc, char ** argv) { diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index 756468c9b110f..d974f023d74b6 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -102,7 +103,7 @@ TEST(Relay, Sequential) { auto mod1 = IRModule::FromExpr(expected_func); mod1 = relay::transform::InferType()(mod1); auto expected = mod1->Lookup("main"); - CHECK(relay::AlphaEqual(f, expected)); + CHECK(tvm::StructuralEqual()(f, expected)); } int main(int argc, char** argv) {