-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add cinn_instruction_run_op for launching execution of a cinn instruction #39435
Changes from 7 commits
f7e9811
81c7ebe
cc4e106
a645c3e
f64f8c2
b6fff8d
3220d21
79c0bb7
370cfb0
3937d3c
2462113
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/operators/cinn/cinn_instruction_run_op.h" | ||
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" | ||
#include "paddle/fluid/operators/cinn/cinn_launch_context.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
|
||
namespace paddle::operators { | ||
|
||
class CinnInstructionRunOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun"); | ||
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs, | ||
"CinnInstructionRun"); | ||
const CinnCompiledObject& compiled_object = | ||
CinnCompiler::GetInstance()->GetCompiledObject( | ||
ctx->Attrs().Get<int64_t>(kCachedIndex)); | ||
|
||
details::CinnLaunchContext* launch_context = | ||
compiled_object.launch_context.get(); | ||
std::vector<std::string> output_args = ctx->Outputs(kOutputs); | ||
std::vector<framework::DDim> output_dims; | ||
std::transform(output_args.begin(), output_args.end(), output_dims.begin(), | ||
[launch_context](const std::string& var_name) { | ||
cinn_buffer_t* buffer = | ||
launch_context->GetCinnBufferOfVar(var_name); | ||
return framework::DDim(buffer->dims, buffer->dimensions); | ||
}); | ||
ctx->SetOutputsDim(kOutputs, output_dims); | ||
} | ||
}; | ||
|
||
class CinnInstructionRunOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput(kX, | ||
"(vector<LoDTensor>)" | ||
"which are the input arguments of this cinn instruction") | ||
.AsDuplicable(); | ||
AddOutput(kOutputs, | ||
"(vector<LoDTensor>)" | ||
"which are the output arguments of this cinn instruction") | ||
.AsDuplicable(); | ||
AddAttr<int64_t>( | ||
kCachedIndex, | ||
"(int64_t)" | ||
"the stored index of the cached compilation result in CinnCompiler," | ||
"which is used to fetch the CinnCompiledObject where this cinn " | ||
"instruction is included"); | ||
AddAttr<int64_t>( | ||
kInstructionIndex, | ||
"(int64_t)" | ||
"the index of this instruction to the cinn runtime program"); | ||
AddComment(R"DOC( | ||
CinnInstructionRun Operator. | ||
|
||
This operator is used to launch a | ||
CINN(https://github.com/PaddlePaddle/CINN/blob/develop/README.md) instruction execution | ||
|
||
Both the input and output of this operator are a set of variables | ||
which are the input and output arguments of the bound cinn instruction respectively. | ||
In addition, there is an attribute named 'cached_index' should be | ||
set necessarily to get the CinnCompiledObject where the instruction is included | ||
and 'instruction_index' is fetch the instruction object from complied runtime prograrm. | ||
|
||
It accomplishes the execution of the instruction according to the following steps: | ||
0. Set the shapes ot the output variables at InferShape function with | ||
compilation result. | ||
1. Fetch the cinn instruction bound to this operator by 'cached_index' | ||
and 'instruction_index' from CinnCompiler. | ||
2. Prepare the input and output variables of the instruction in Paddle and share | ||
their buffers to CINN by setting 'memory' of according cinn_buffer_t. | ||
3. Launch CINN runtime to execute the instruction. | ||
|
||
)DOC"); | ||
} | ||
}; | ||
|
||
} // namespace paddle::operators | ||
|
||
namespace ops = paddle::operators; | ||
using CPUDeviceContext = paddle::platform::CPUDeviceContext; | ||
REGISTER_OPERATOR( | ||
cinn_instruction_run, ops::CinnInstructionRunOp, | ||
ops::CinnInstructionRunOpMaker, | ||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); | ||
REGISTER_OP_CPU_KERNEL( | ||
cinn_instruction_run, | ||
ops::CinnInstructionRunOpKernel<CPUDeviceContext, bool>, | ||
ops::CinnInstructionRunOpKernel<CPUDeviceContext, int>, | ||
ops::CinnInstructionRunOpKernel<CPUDeviceContext, int64_t>, | ||
ops::CinnInstructionRunOpKernel<CPUDeviceContext, float>, | ||
ops::CinnInstructionRunOpKernel<CPUDeviceContext, double>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/cinn/cinn_instruction_run_op.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
|
||
namespace ops = paddle::operators; | ||
using CUDADeviceContext = paddle::platform::CUDADeviceContext; | ||
REGISTER_OP_CUDA_KERNEL( | ||
cinn_instruction_run, | ||
ops::CinnInstructionRunOpKernel<CUDADeviceContext, bool>, | ||
ops::CinnInstructionRunOpKernel<CUDADeviceContext, int>, | ||
ops::CinnInstructionRunOpKernel<CUDADeviceContext, int64_t>, | ||
ops::CinnInstructionRunOpKernel<CUDADeviceContext, float>, | ||
ops::CinnInstructionRunOpKernel<CUDADeviceContext, double>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <iterator> | ||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
#include "cinn/hlir/framework/instruction.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/operator.h" | ||
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" | ||
#include "paddle/fluid/operators/cinn/cinn_launch_context.h" | ||
#include "paddle/fluid/operators/cinn/cinn_op_helper.h" | ||
|
||
namespace paddle::operators { | ||
|
||
using CinnInstruction = ::cinn::hlir::framework::Instruction; | ||
using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject; | ||
using CinnCompiler = framework::paddle2cinn::CinnCompiler; | ||
|
||
template <typename DeviceContext, typename T> | ||
class CinnInstructionRunOpKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
// step 1: fetch the cinn instruction bound to this operator | ||
auto cached_index = ctx.template Attr<int64_t>(kCachedIndex); | ||
auto ins_index = ctx.template Attr<int64_t>(kInstructionIndex); | ||
const CinnCompiledObject& compiled_object = | ||
CinnCompiler::GetInstance()->GetCompiledObject(cached_index); | ||
const std::vector<std::unique_ptr<CinnInstruction>>& instructions = | ||
compiled_object.runtime_program->GetRunInstructions(); | ||
PADDLE_ENFORCE_LT(ins_index, instructions.size(), | ||
platform::errors::InvalidArgument( | ||
"Index(%ld) > instructions.size(%ld).", ins_index, | ||
instructions.size())); | ||
auto&& instruction = instructions.at(ins_index); | ||
|
||
// step 2: prepare the input and output arguments of the instruction | ||
details::CinnLaunchContext* launch_context = | ||
compiled_object.launch_context.get(); | ||
auto share_argument_buffer_fn = [launch_context, | ||
&ctx](const std::string& var_name) { | ||
cinn_buffer_t* buffer = launch_context->GetCinnBufferOfVar(var_name); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这儿是否可能返回为空指针或者野指针啊?特别是在多线程环境下?建议加个 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. GetCinnBufferOfVar函数中有 ENFORCE 判断,若是变量名不存在会报错 |
||
framework::Variable* var = ctx.scope().GetVar(var_name); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同理,这儿要不先判断下 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. scope.GetVar 函数中有 ENFORCE 判断 |
||
auto* tensor = var->template GetMutable<framework::LoDTensor>(); | ||
buffer->memory = | ||
reinterpret_cast<uint8_t*>(tensor->mutable_data<T>(ctx.GetPlace())); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这儿为啥要转为 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 裸指针,是。cinn 会根据实际数据类型读取 |
||
}; | ||
std::vector<std::string> in_args = ctx.InputNames(kX); | ||
std::for_each(in_args.begin(), in_args.end(), share_argument_buffer_fn); | ||
std::vector<std::string> out_args = ctx.OutputNames(kOutputs); | ||
std::for_each(out_args.begin(), out_args.end(), share_argument_buffer_fn); | ||
|
||
// step 3: launch CINN runtime to execute the instruction | ||
// TODO(CtfGo): simplify format of arguments package as a vector in CINN | ||
// and update this usage call | ||
instruction->Run(&launch_context->FinalizeArguments(), false, | ||
details::GetStream<DeviceContext>(ctx)); | ||
} | ||
}; | ||
|
||
} // namespace paddle::operators |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,12 +24,31 @@ CinnLaunchContext::CinnLaunchContext( | |
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap, | ||
const std::shared_ptr<CinnScope>& cinn_scope) | ||
: paddle2cinn_varmap_(paddle2cinn_varmap), cinn_scope_(cinn_scope) { | ||
// generate all names of cinn used variables | ||
auto var_names = cinn_scope_->var_names(); | ||
cinn_variable_names_.reserve(var_names.size()); | ||
std::transform( | ||
var_names.begin(), var_names.end(), | ||
std::inserter(cinn_variable_names_, cinn_variable_names_.end()), | ||
[](const auto& name_view) { return std::string(name_view.data()); }); | ||
// build the variable name map of cinn2paddle | ||
for (const auto& x : paddle2cinn_varmap_) { | ||
auto res = cinn2paddle_varmap_.emplace(x.second, x.first); | ||
PADDLE_ENFORCE_EQ( | ||
res.second, true, | ||
platform::errors::InvalidArgument( | ||
"Cinn variable(%s) maps to more than one paddle variable(%s,%s)", | ||
x.second, res.first->second, x.first)); | ||
} | ||
// supplement the relations of the remain variables not appearing in above | ||
// map, | ||
// they are internal variables and here we use the name from cinn compiled. | ||
for (const auto& var_name : cinn_variable_names_) { | ||
if (!cinn2paddle_varmap_.count(var_name)) { | ||
cinn2paddle_varmap_.emplace(var_name, var_name); | ||
paddle2cinn_varmap_.emplace(var_name, var_name); | ||
} | ||
} | ||
} | ||
|
||
void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope, | ||
|
@@ -189,6 +208,20 @@ CinnLaunchContext::FinalizeArguments() const { | |
return name2argument_; | ||
} | ||
|
||
cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar( | ||
const std::string& paddle_var_name) { | ||
auto res = paddle2cinn_varmap_.find(paddle_var_name); | ||
PADDLE_ENFORCE_NE( | ||
res, paddle2cinn_varmap_.end(), | ||
platform::errors::InvalidArgument( | ||
"Variable(%s) not found in compilation result", paddle_var_name)); | ||
auto it = name2argument_.find(res->second); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 请在注释中注明 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已在头文件声明处添加注释 |
||
PADDLE_ENFORCE_NE(it, name2argument_.end(), | ||
platform::errors::InvalidArgument( | ||
"Argument(%s) not be initialized", res->second)); | ||
return static_cast<cinn_buffer_t*>(it->second); | ||
} | ||
|
||
} // namespace details | ||
} // namespace operators | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果这两句无用请直接删除,请不要在此注释。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done