diff --git a/cinn/frontend/net_builder_test.cc b/cinn/frontend/net_builder_test.cc index deb1a4206d..568df7b91b 100644 --- a/cinn/frontend/net_builder_test.cc +++ b/cinn/frontend/net_builder_test.cc @@ -197,7 +197,7 @@ TEST(net_build, program_execute_pool2d_grad) { NetBuilder builder("net_builder"); Placeholder input = builder.CreateInput(Float(32), {B, C, H, W}, "In"); - Variable output = builder.Squeeze(Float(32), {1, 3}); + Variable output = builder.Squeeze(input, {1, 3}); auto program = builder.Build(); Target target = common::DefaultHostTarget(); @@ -232,7 +232,7 @@ TEST(net_build, program_execute_pool2d_grad) { int index = w + W * (h + H * (c + C * b)); float in_data = input_data[index]; float out_data = output_data[index]; - line += (std::to_string(data) + ", "); + line += (std::to_string(out_data) + ", "); EXPECT_EQ(in_data, out_data); } VLOG(6) << line; diff --git a/cinn/hlir/op/contrib/squeeze.h b/cinn/hlir/op/contrib/squeeze.h index a166565314..3d4db8326f 100644 --- a/cinn/hlir/op/contrib/squeeze.h +++ b/cinn/hlir/op/contrib/squeeze.h @@ -25,7 +25,7 @@ namespace cinn { namespace hlir { namespace op { -ir::Tensor Squeeze(const ir::Tensor& A, const std::vector& axis, poly::StageMap stages, const std::string& name); +ir::Tensor Squeeze(const ir::Tensor& A, const std::vector& axis, const std::string& name); } // namespace op } // namespace hlir diff --git a/cinn/hlir/op/contrib/squeeze_test.cc b/cinn/hlir/op/contrib/squeeze_test.cc index ac6440f180..00e4811955 100644 --- a/cinn/hlir/op/contrib/squeeze_test.cc +++ b/cinn/hlir/op/contrib/squeeze_test.cc @@ -32,7 +32,7 @@ namespace cinn { namespace hlir { namespace op { -TEST(Squeeze, SqueezeCase0) { +TEST(GenerateCode_Cpu, Squeeze) { common::Context::Global().ResetNameId(); common::Target target = common::DefaultHostTarget(); @@ -47,7 +47,7 @@ TEST(Squeeze, SqueezeCase0) { poly::StageMap stages = poly::CreateStages({res}); std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_Squeeze", stages, res, {}, {}, nullptr, target, true); + lang::LowerVec("TestGenerateCodeCpu_Squeeze", stages, {res}, {}, {}, nullptr, target, true); VLOG(6) << "Expr before CPU codegen:"; VLOG(6) << funcs[0]->body;