Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

【PaddlePaddle Hackathon 67】Add arange op #919

Merged
merged 8 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,18 @@ Variable NetBuilder::Clip(const std::vector<Variable>& inputs, const float& max_
return instr.GetOutput(0);
}

Variable NetBuilder::Arange(const float start, const float stop, const float step, const std::string& dtype) {
  // Build an "arange" instruction: it consumes no tensor inputs — the output
  // sequence [start, stop) with stride `step` is fully described by attributes.
  Instruction arange_instr("arange");
  arange_instr.SetInputs({});
  arange_instr.SetAttr("start", start);
  arange_instr.SetAttr("stop", stop);
  arange_instr.SetAttr("step", step);
  arange_instr.SetAttr("dtype", dtype);
  // Resolve the output shape before queuing the instruction on the program.
  InferShape(arange_instr);
  AppendInstruction(arange_instr);
  return arange_instr.GetOutput(0);
}

// conv2d grad, output(grad_x, grad_w)
std::vector<Variable> NetBuilder::Conv2dGrad(const Variable& dy,
const Variable& x,
Expand Down
2 changes: 2 additions & 0 deletions cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ class NetBuilder : public BaseBuilder {

Variable Clip(const std::vector<Variable>& inputs, const float& max_val, const float& min_val);

Variable Arange(const float start, const float stop, const float step, const std::string& dtype);

// conv2d grad, output(grad_x, grad_w)
std::vector<Variable> Conv2dGrad(const Variable& dy,
const Variable& x,
Expand Down
30 changes: 30 additions & 0 deletions cinn/frontend/net_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,5 +293,35 @@ TEST(net_build, program_execute_cast) {
}
}

// End-to-end test: build an arange op, compile and run it on the host, then
// verify both the inferred shape and every produced value.
TEST(net_build, program_execute_arange) {
  const float start       = 1.5F;
  const float stop        = 31.5F;
  const float step        = 2.0F;
  const std::string dtype = "float32";

  NetBuilder builder("net_builder");
  Variable out = builder.Arange(start, stop, step, dtype);
  auto program = builder.Build();

  Target target = common::DefaultHostTarget();

  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
  auto scope = BuildScope(target, graph);
  hlir::framework::GraphCompiler gc(target, scope, graph);
  auto runtime_program = gc.Build();

  scope->Var<hlir::framework::Tensor>(std::string(out->id));

  runtime_program->Execute();

  auto out_tensor = scope->GetTensor(std::string(out->id));
  const std::vector<int>& out_tensor_shape = out_tensor->shape().data();

  // ceil((31.5 - 1.5) / 2.0) == 15 elements, in a rank-1 tensor.
  const int expected_size = 15;
  ASSERT_EQ(out_tensor_shape.size(), 1UL);
  ASSERT_EQ(out_tensor_shape[0], expected_size);

  float* out_data = out_tensor->mutable_data<float>(target);
  for (int i = 0; i < out_tensor_shape[0]; ++i) {
    // Each element must be start + i * step.
    EXPECT_NEAR(out_data[i], start + static_cast<float>(i) * step, 1e-5F);
    VLOG(6) << out_data[i];
  }
}

} // namespace frontend
} // namespace cinn
2 changes: 2 additions & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ core_gather_headers()
gather_srcs(cinnapi_src SRCS
cast.cc
clip.cc
arange.cc
)

cc_test(test_cast SRCS cast_test.cc DEPS cinncore)
cc_test(test_clip SRCS clip_test.cc DEPS cinncore)
cc_test(test_arange SRCS arange_test.cc DEPS cinncore)
157 changes: 157 additions & 0 deletions cinn/hlir/op/contrib/arange.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/arange.h"

#include <gflags/gflags.h>

#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cinn/common/cas.h"
#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/macros.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/pe/nn.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"
#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace hlir {
namespace op {

std::vector<ir::Tensor> Arange(
const float start, const float stop, const float step, const Type &dtype, const std::string &output_name) {
int num_elem = static_cast<int>(std::ceil((stop - start) / step));
ir::Tensor res = lang::Compute(
{Expr(num_elem)},
[=](const std::vector<ir::Expr> &indices) {
return ir::Cast::Make(dtype, start + step * cinn::common::cast(indices[0], common::Float(32)));
},
common::UniqName(output_name));
return {res};
}

// Infers the output shape of arange from its attributes. The op has no tensor
// inputs, so `inputs_shape` is unused; the result is a single rank-1 shape.
std::vector<std::vector<int>> InferShapeForArange(const std::vector<std::vector<int>> &inputs_shape,
                                                  const framework::AttrMapType &attrs) {
  // Defaults mirror NumPy's arange: start=0, step=1.
  float start = 0.0F;
  float stop  = 0.0F;
  float step  = 1.0F;

  if (attrs.find("start") != attrs.end()) {
    start = absl::get<float>(attrs.at("start"));
  }
  if (attrs.find("stop") != attrs.end()) {
    stop = absl::get<float>(attrs.at("stop"));
  }
  if (attrs.find("step") != attrs.end()) {
    step = absl::get<float>(attrs.at("step"));
  }

  // step == 0 would divide by zero below (ceil(inf) -> undefined int cast).
  CHECK_NE(step, 0.0F) << "The 'step' attribute of arange must be non-zero.";

  int num_elem = static_cast<int>(std::ceil((stop - start) / step));
  CHECK_GT(num_elem, 0) << "Invalid arange attributes.";

  std::vector<std::vector<int>> res = {{num_elem}};
  return res;
}

// Infers the output dtype of arange: the "dtype" attribute if present,
// otherwise fp32. The op has no tensor inputs, so `inputs_type` is unused.
std::vector<Type> InferDtypeForArange(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
  auto it = attrs.find("dtype");
  const std::string dtype = (it == attrs.end()) ? "float32" : absl::get<std::string>(it->second);
  return {common::Str2Type(dtype)};
}

// Builds the compute/schedule strategy for the arange op.
// Reads start/stop/step/dtype from the node's attribute store (defaults mirror
// InferShapeForArange / InferDtypeForArange) and emits a single-output compute.
std::shared_ptr<framework::OpStrategy> StrategyForArange(const framework::NodeAttr &attrs,
                                                         const std::vector<ir::Tensor> &inputs,
                                                         const std::vector<Type> &out_type,
                                                         const std::vector<std::vector<int>> &output_shapes,
                                                         const Target &target) {
  std::string dtype = "float32";
  float start       = 0.0F;
  float stop        = 0.0F;
  float step        = 1.0F;

  for (auto &iter : attrs.attr_store) {
    if (iter.first == "dtype") {
      dtype = absl::get<std::string>(iter.second);
    } else if (iter.first == "start") {
      start = absl::get<float>(iter.second);
    } else if (iter.first == "stop") {
      stop = absl::get<float>(iter.second);
    } else if (iter.first == "step") {
      step = absl::get<float>(iter.second);
    }
  }

  // A negative step is valid (descending range, as in NumPy's arange);
  // only step == 0 is meaningless, so reject exactly that.
  CHECK_NE(step, 0.0F) << "Invalid arange attributes: step must be non-zero.";

  framework::CINNCompute arange_compute([=](lang::Args args, lang::RetValue *ret) {
    CHECK(!args.empty()) << "The input argument of arange compute is empty! Please check.\n";

    std::vector<ir::Tensor> out = Arange(start, stop, step, common::Str2Type(dtype), common::UniqName("T_Arange_out"));

    CHECK(out.size() == 1U) << "The size of Arange's output should be 1";
    std::vector<common::CINNValue> res;
    auto stages = CreateStages({});
    for (auto &t : out) {
      stages->InsertLazily(t);
      res.push_back(common::CINNValue(t));
    }

    // The stage map is appended last, per the CINNValuePack convention.
    res.push_back(common::CINNValue(stages));
    *ret = common::CINNValuePack{res};
  });

  framework::CINNSchedule arange_schedule([=](lang::Args args, lang::RetValue *ret) {
    CHECK(!args.empty()) << "The input argument of arange schedule is empty! Please check.\n";
    common::CINNValuePack arg_pack = args[0];
    Expr out = arg_pack[0];
    CHECK(out.as_tensor());
    // No extra scheduling: the default loop nest is kept as-is.
    *ret = arg_pack;
  });

  auto strategy = std::make_shared<framework::OpStrategy>();
  strategy->AddImpl(arange_compute, arange_schedule, "strategy.arange.x86", 1);
  return strategy;
}

} // namespace op
} // namespace hlir
} // namespace cinn

CINN_REGISTER_HELPER(arange_ops) {
CINN_REGISTER_OP(arange)
.describe("Returns evenly spaced values within a given interval.")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForArange)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForArange))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForArange))
.set_support_level(4);

return true;
}
33 changes: 33 additions & 0 deletions cinn/hlir/op/contrib/arange.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"

namespace cinn {
namespace hlir {
namespace op {

std::vector<ir::Tensor> Arange(
const float start, const float stop, const float step, const Type& dtype, const std::string& output_name);

} // namespace op
} // namespace hlir
} // namespace cinn
89 changes: 89 additions & 0 deletions cinn/hlir/op/contrib/arange_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/arange.h"

#include <glog/logging.h>
#include <gtest/gtest.h>

#include <string>
#include <vector>

#include "cinn/backends/codegen_c.h"
#include "cinn/backends/codegen_c_x86.h"
#include "cinn/backends/codegen_cuda_dev.h"
#include "cinn/common/context.h"
#include "cinn/lang/lower.h"
#include "cinn/lang/placeholder.h"
#include "cinn/poly/stage.h"

namespace cinn {
namespace hlir {
namespace op {

// Lowers an arange compute for the host target and emits C code; the test
// passes if lowering and codegen complete without error (output is VLOG'd).
TEST(GenerateCode_Cpu, Arange) {
  common::Context::Global().ResetNameId();

  common::Target target = common::DefaultHostTarget();
  float start           = 1.5F;
  float stop            = 31.5F;
  float step            = 2.0F;

  // Expected element count: ceil((31.5 - 1.5) / 2.0) == 15.
  std::vector<ir::Tensor> res = Arange(start, stop, step, common::Float(32), "test_arange");

  poly::StageMap stages = poly::CreateStages({res});
  std::vector<ir::LoweredFunc> funcs =
      lang::LowerVec("TestGenerateCodeCpu_Arange", stages, res, {}, {}, nullptr, target, true);

  VLOG(6) << "Expr before CPU codegen:";
  VLOG(6) << funcs[0]->body;

  ir::Module::Builder builder("Arange_Module", target);
  for (auto &f : funcs) {
    builder.AddFunction(f);
  }

  backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512);
  codegen.SetInlineBuiltinCodes(false);
  std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl);
  VLOG(6) << "Cpu Codegen result:";
  VLOG(6) << code << std::endl;
}

// Lowers an arange compute for the NVGPU target and emits CUDA code; the test
// passes if lowering and codegen complete without error (output is VLOG'd).
TEST(GenerateCode_Cuda, Arange) {
  common::Context::Global().ResetNameId();

  common::Target target = common::DefaultNVGPUTarget();
  float start           = 1.5F;
  float stop            = 31.5F;
  float step            = 2.0F;

  std::vector<ir::Tensor> res = Arange(start, stop, step, common::Float(32), "test_arange");

  poly::StageMap stages = poly::CreateStages({res});
  // Fixed: the lowered function was previously misnamed "...Cpu_Arange".
  std::vector<ir::LoweredFunc> funcs =
      lang::LowerVec("TestGenerateCodeCuda_Arange", stages, res, {}, {}, nullptr, target, true);

  VLOG(6) << "Expr before CUDA codegen:";
  VLOG(6) << funcs[0]->body;

  ir::Module::Builder builder("Arange_Module", target);
  for (auto &f : funcs) {
    builder.AddFunction(f);
  }

  // Actually run the CUDA codegen so the test exercises the full pipeline,
  // mirroring the CPU variant above.
  backends::CodeGenCUDA_Dev codegen(target);
  std::string code = codegen.Compile(builder.Build());
  VLOG(6) << "Cuda Codegen result:";
  VLOG(6) << code << std::endl;
}

} // namespace op
} // namespace hlir
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/hlir/op/use_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ CINN_USE_REGISTER(transform_ops)
CINN_USE_REGISTER(cast_ops)
CINN_USE_REGISTER(reduce_ops)
CINN_USE_REGISTER(clip_ops)
CINN_USE_REGISTER(arange_ops)