Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 28, 2020
1 parent 8c250ba commit de405dd
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 69 deletions.
1 change: 1 addition & 0 deletions include/tvm/ir/type_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class TypeRelationNode : public TypeConstraintNode {

bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const {
return
equal(func, other->func) &&
equal(args, other->args) &&
equal(num_inputs, other->num_inputs) &&
equal(attrs, other->attrs);
Expand Down
3 changes: 3 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2142,6 +2142,9 @@ Expr MakeSplit(Expr data,
TVM_REGISTER_GLOBAL("relay.op._make.split")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
if (args.type_codes[1] == kDLInt) {
// Note: we change it from Int(64) to Int(32) for now as
// combine_parallel_dense will transform the graph with Int(32).
// More invetigation is needs to check which one we should use.
*rv = MakeSplit(args[0],
tir::make_const(DataType::Int(32), static_cast<int>(args[1])),
args[2]);
Expand Down
67 changes: 0 additions & 67 deletions tests/cpp/relay_pass_alpha_equal.cc

This file was deleted.

3 changes: 2 additions & 1 deletion tests/cpp/relay_pass_type_infer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/

#include <gtest/gtest.h>
#include <tvm/node/structural_equal.h>
#include <tvm/te/operation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
Expand All @@ -38,7 +39,7 @@ TEST(Relay, SelfReference) {
auto type_fx = mod->Lookup("main");

auto expected = relay::FuncType(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
CHECK(relay::AlphaEqual(type_fx->checked_type(), expected));
CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected));
}

int main(int argc, char ** argv) {
Expand Down
3 changes: 2 additions & 1 deletion tests/cpp/relay_transform_sequential.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <gtest/gtest.h>
#include <topi/generic/injective.h>
#include <tvm/node/structural_equal.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/module.h>
Expand Down Expand Up @@ -102,7 +103,7 @@ TEST(Relay, Sequential) {
auto mod1 = IRModule::FromExpr(expected_func);
mod1 = relay::transform::InferType()(mod1);
auto expected = mod1->Lookup("main");
CHECK(relay::AlphaEqual(f, expected));
CHECK(tvm::StructuralEqual()(f, expected));
}

int main(int argc, char** argv) {
Expand Down

0 comments on commit de405dd

Please sign in to comment.