Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cinn_instruction_run_op for launching execution of a cinn instruction #39435

Merged
merged 11 commits into from
Feb 15, 2022
Merged
20 changes: 16 additions & 4 deletions paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
if (cache_by_struct_.count(cur_key_by_struct) != 0) {
exist = true;
cache_by_address_[cur_key_by_address] =
cache_by_struct_.at(cur_key_by_struct).get();
cache_by_struct_.at(cur_key_by_struct);
}
}
}
Expand All @@ -98,12 +98,13 @@ const CinnCompiledObject& CinnCompiler::Compile(
CompileGraph(graph, input_tensors, target, compiled_num, stream);
pten::AutoWRLock w_guard{&rwlock_};
if (!cache_by_struct_.count(cur_key_by_struct)) {
cache_by_address_[cur_key_by_address] = compiled_res.get();
cache_by_struct_[cur_key_by_struct] = std::move(compiled_res);
cache_by_address_[cur_key_by_address] = compiled_num;
cache_by_struct_[cur_key_by_struct] = compiled_num;
index2cache_.emplace(compiled_num, std::move(compiled_res));
}
}
pten::AutoRDLock guard{&rwlock_};
const auto& cached_boj = *cache_by_address_[cur_key_by_address];
const auto& cached_boj = *index2cache_[cache_by_address_[cur_key_by_address]];
return cached_boj;
}

Expand All @@ -115,6 +116,15 @@ const CinnCompiledObject& CinnCompiler::Compile(
return Compile(graph, input_tensors, target, stream);
}

// Returns the compilation result previously stored in the cache under
// `cached_index`. Raises an InvalidArgument error when no result has ever
// been cached with that index.
const CinnCompiledObject& CinnCompiler::GetCompiledObject(
    int64_t cached_index) const {
  auto cache_entry = index2cache_.find(cached_index);
  PADDLE_ENFORCE_NE(cache_entry, index2cache_.end(),
                    platform::errors::InvalidArgument(
                        "Index(%ld) not found in cache", cached_index));
  // The map stores unique_ptr values; dereference to hand out a reference
  // whose lifetime is managed by the cache itself.
  return *cache_entry->second;
}

std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
std::string graph_key;
ProgramDesc program;
Expand Down Expand Up @@ -202,6 +212,7 @@ void CinnCompiler::Clear() {
graphs_.clear();
cache_by_address_.clear();
cache_by_struct_.clear();
index2cache_.clear();
}
real_compiled_num_.store(0);
}
Expand Down Expand Up @@ -240,6 +251,7 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
compiled_obj->launch_context =
std::make_unique<operators::details::CinnLaunchContext>(
compiled_obj->paddle2cinn_varmap, compiled_obj->scope);
compiled_obj->cached_index = compiled_num;
return compiled_obj;
}

Expand Down
11 changes: 7 additions & 4 deletions paddle/fluid/framework/paddle2cinn/cinn_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct CinnCompiledObject {
std::shared_ptr<::cinn::hlir::framework::Scope> scope;
std::unordered_map<std::string, std::string> paddle2cinn_varmap;
std::unique_ptr<operators::details::CinnLaunchContext> launch_context;
std::int64_t cached_index;
};

// Entrance to use CINN.
Expand All @@ -70,6 +71,8 @@ class CinnCompiler {
const std::map<std::string, const LoDTensor*>& input_tensors,
const ::cinn::common::Target& target, void* stream = nullptr);

const CinnCompiledObject& GetCompiledObject(int64_t cached_index) const;

std::string AddGraph(std::unique_ptr<ir::Graph> graph);

const ir::Graph& FindGraph(const std::string& graph_key) const;
Expand All @@ -95,12 +98,12 @@ class CinnCompiler {
void* stream = nullptr) const;

std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_;
std::unordered_map<CinnCacheKeyByAddress, CinnCompiledObject*,
CinnCacheKey::Hash>
std::unordered_map<CinnCacheKeyByAddress, std::int64_t, CinnCacheKey::Hash>
cache_by_address_;
std::unordered_map<CinnCacheKeyByStructure,
std::unique_ptr<CinnCompiledObject>, CinnCacheKey::Hash>
std::unordered_map<CinnCacheKeyByStructure, std::int64_t, CinnCacheKey::Hash>
cache_by_struct_;
std::unordered_map<std::int64_t, std::unique_ptr<CinnCompiledObject>>
index2cache_;
std::atomic_int64_t real_compiled_num_{0};
mutable pten::RWLock rwlock_;

Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,20 @@ TEST(CinnCompilerTest, Compile) {
auto compile_fn = [&](const Target& target) {
const auto& compiled_obj =
cinn_compiler->Compile(compiling_graph, input_tensors, target);
ASSERT_NE(compiled_obj.compiler, nullptr);
ASSERT_NE(compiled_obj.runtime_program, nullptr);
ASSERT_NE(compiled_obj.scope, nullptr);
ASSERT_FALSE(compiled_obj.paddle2cinn_varmap.empty());
ASSERT_NE(compiled_obj.launch_context, nullptr);
const auto& cached_obj =
cinn_compiler->Compile(compilation_key, input_tensors, target);
ASSERT_EQ(reinterpret_cast<std::uint64_t>(&compiled_obj),
reinterpret_cast<std::uint64_t>(&cached_obj));
ASSERT_EQ(cached_obj.cached_index + 1, cinn_compiler->real_compiled_num());
const auto& ret_obj =
cinn_compiler->GetCompiledObject(cached_obj.cached_index);
ASSERT_EQ(reinterpret_cast<std::uint64_t>(&compiled_obj),
reinterpret_cast<std::uint64_t>(&ret_obj));
};

// GPU Compilation
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/operators/cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
include(operators)
register_operators(EXCLUDES cinn_launch_op)

cc_library(cinn_op_helper SRCS cinn_op_helper.cc DEPS operator device_context)
cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope cinn)
op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS string_helper cinn cinn_compiler cinn_launch_context)

SET(CINN_OP_DEPS string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context)
register_operators(DEPS ${CINN_OP_DEPS})
#op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context)
#op_library(cinn_instruction_run_op SRCS cinn_instruction_run_op.cc cinn_instruction_run_op.cu.cc DEPS string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context)
Copy link
Contributor

@wzzju wzzju Feb 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果这两句无用请直接删除,请不要在此注释。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


if (WITH_TESTING)
cc_test(cinn_launch_context_test SRCS cinn_launch_context_test.cc DEPS ddim lod_tensor scope cinn_launch_context)
Expand Down
109 changes: 109 additions & 0 deletions paddle/fluid/operators/cinn/cinn_instruction_run_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/cinn/cinn_instruction_run_op.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle::operators {

// Operator that launches the execution of a single compiled CINN
// instruction; the compilation result is located in CinnCompiler through
// the 'cached_index' attribute.
class CinnInstructionRunOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Infers the shapes of the output variables from the cinn_buffer_t
  // objects recorded in the cached compilation result.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun");
    OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
                   "CinnInstructionRun");
    const CinnCompiledObject& compiled_object =
        CinnCompiler::GetInstance()->GetCompiledObject(
            ctx->Attrs().Get<int64_t>(kCachedIndex));

    details::CinnLaunchContext* launch_context =
        compiled_object.launch_context.get();
    std::vector<std::string> output_args = ctx->Outputs(kOutputs);
    // Pre-size the destination vector: std::transform only writes through
    // the given iterator and never grows the container, so transforming
    // into begin() of an empty vector is undefined behavior.
    std::vector<framework::DDim> output_dims(output_args.size());
    std::transform(output_args.begin(), output_args.end(), output_dims.begin(),
                   [launch_context](const std::string& var_name) {
                     cinn_buffer_t* buffer =
                         launch_context->GetCinnBufferOfVar(var_name);
                     return framework::DDim(buffer->dims, buffer->dimensions);
                   });
    ctx->SetOutputsDim(kOutputs, output_dims);
  }
};

// Declares the proto of cinn_instruction_run: duplicable input/output
// variable lists plus the two integer attributes that locate the compiled
// instruction inside CinnCompiler.
class CinnInstructionRunOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(kX,
             "(vector<LoDTensor>)"
             "which are the input arguments of this cinn instruction")
        .AsDuplicable();
    AddOutput(kOutputs,
              "(vector<LoDTensor>)"
              "which are the output arguments of this cinn instruction")
        .AsDuplicable();
    AddAttr<int64_t>(
        kCachedIndex,
        "(int64_t)"
        "the stored index of the cached compilation result in CinnCompiler,"
        "which is used to fetch the CinnCompiledObject where this cinn "
        "instruction is included");
    AddAttr<int64_t>(
        kInstructionIndex,
        "(int64_t)"
        "the index of this instruction to the cinn runtime program");
    // NOTE: typos in the original user-facing doc string are fixed below
    // ("shapes ot" -> "shapes of", "is fetch" -> "is used to fetch",
    // "complied runtime prograrm" -> "compiled runtime program").
    AddComment(R"DOC(
CinnInstructionRun Operator.

This operator is used to launch a
CINN(https://github.com/PaddlePaddle/CINN/blob/develop/README.md) instruction execution

Both the input and output of this operator are a set of variables
which are the input and output arguments of the bound cinn instruction respectively.
In addition, there is an attribute named 'cached_index' should be
set necessarily to get the CinnCompiledObject where the instruction is included
and 'instruction_index' is used to fetch the instruction object from the compiled runtime program.

It accomplishes the execution of the instruction according to the following steps:
0. Set the shapes of the output variables at InferShape function with
compilation result.
1. Fetch the cinn instruction bound to this operator by 'cached_index'
and 'instruction_index' from CinnCompiler.
2. Prepare the input and output variables of the instruction in Paddle and share
their buffers to CINN by setting 'memory' of according cinn_buffer_t.
3. Launch CINN runtime to execute the instruction.

)DOC");
  }
};

} // namespace paddle::operators

namespace ops = paddle::operators;
using CPUDeviceContext = paddle::platform::CPUDeviceContext;
// Register the forward-only operator: both EmptyGradOpMaker registrations
// declare that cinn_instruction_run has no gradient op, for static graph
// (OpDesc) and dynamic graph (OpBase) modes respectively.
REGISTER_OPERATOR(
cinn_instruction_run, ops::CinnInstructionRunOp,
ops::CinnInstructionRunOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
// CPU kernels for every element type the instruction's tensors may hold.
REGISTER_OP_CPU_KERNEL(
cinn_instruction_run,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, bool>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, int>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, int64_t>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, float>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, double>);
26 changes: 26 additions & 0 deletions paddle/fluid/operators/cinn/cinn_instruction_run_op.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/cinn/cinn_instruction_run_op.h"
#include "paddle/fluid/framework/op_registry.h"

namespace ops = paddle::operators;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
// CUDA kernels mirroring the CPU registration in
// cinn_instruction_run_op.cc, one per supported element type.
REGISTER_OP_CUDA_KERNEL(
cinn_instruction_run,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, bool>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, int>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, int64_t>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, float>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, double>);
75 changes: 75 additions & 0 deletions paddle/fluid/operators/cinn/cinn_instruction_run_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <iterator>
#include <memory>
#include <string>
#include <vector>
#include "cinn/hlir/framework/instruction.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/operators/cinn/cinn_op_helper.h"

namespace paddle::operators {

using CinnInstruction = ::cinn::hlir::framework::Instruction;
using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject;
using CinnCompiler = framework::paddle2cinn::CinnCompiler;

// Kernel that executes one compiled CINN instruction: it locates the cached
// CinnCompiledObject via the 'cached_index' attribute, picks the instruction
// selected by 'instruction_index', shares the Paddle tensor buffers with
// CINN, and then runs the instruction.
// NOTE: the interleaved "Copy link"/review text below is scraped GitHub
// review discussion, not part of the source file.
template <typename DeviceContext, typename T>
class CinnInstructionRunOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// step 1: fetch the cinn instruction bound to this operator
auto cached_index = ctx.template Attr<int64_t>(kCachedIndex);
auto ins_index = ctx.template Attr<int64_t>(kInstructionIndex);
const CinnCompiledObject& compiled_object =
CinnCompiler::GetInstance()->GetCompiledObject(cached_index);
const std::vector<std::unique_ptr<CinnInstruction>>& instructions =
compiled_object.runtime_program->GetRunInstructions();
// NOTE(review): this check fires when ins_index >= instructions.size(),
// so the '>' in the message understates the condition — should be '>='.
PADDLE_ENFORCE_LT(ins_index, instructions.size(),
platform::errors::InvalidArgument(
"Index(%ld) > instructions.size(%ld).", ins_index,
instructions.size()));
auto&& instruction = instructions.at(ins_index);

// step 2: prepare the input and output arguments of the instruction
details::CinnLaunchContext* launch_context =
compiled_object.launch_context.get();
// Shares the buffer of one Paddle variable with the corresponding
// cinn_buffer_t by pointing the buffer's 'memory' at the tensor data.
auto share_argument_buffer_fn = [launch_context,
&ctx](const std::string& var_name) {
// GetCinnBufferOfVar raises an error itself when var_name is unknown
// (per the review exchange below), so no extra null check here.
cinn_buffer_t* buffer = launch_context->GetCinnBufferOfVar(var_name);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这儿是否可能返回为空指针或者野指针啊?特别是在多线程环境下?建议加个PADDLE_ENFORCE_NE

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetCinnBufferOfVar函数中有 ENFORCE 判断,若是变量名不存在会报错

// scope().GetVar enforces existence internally and errors out when the
// variable is absent (per the review exchange below).
framework::Variable* var = ctx.scope().GetVar(var_name);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同理,这儿要不先判断下ctx.scope().HasVar(var_name)(如果有这个接口)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scope.GetVar 函数中有 ENFORCE 判断

auto* tensor = var->template GetMutable<framework::LoDTensor>();
// cinn_buffer_t::memory is an untyped byte pointer; per the review reply
// below, CINN reinterprets it according to the actual element type.
buffer->memory =
reinterpret_cast<uint8_t*>(tensor->mutable_data<T>(ctx.GetPlace()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这儿为啥要转为uint8_t*啊?CINN的buffer指针都是uint8_t*类型的么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

裸指针,是。cinn 会根据实际数据类型读取

};
// Bind both the instruction's input and output arguments.
std::vector<std::string> in_args = ctx.InputNames(kX);
std::for_each(in_args.begin(), in_args.end(), share_argument_buffer_fn);
std::vector<std::string> out_args = ctx.OutputNames(kOutputs);
std::for_each(out_args.begin(), out_args.end(), share_argument_buffer_fn);

// step 3: launch CINN runtime to execute the instruction
// TODO(CtfGo): simplify format of arguments package as a vector in CINN
// and update this usage call
instruction->Run(&launch_context->FinalizeArguments(), false,
details::GetStream<DeviceContext>(ctx));
}
};

} // namespace paddle::operators
33 changes: 33 additions & 0 deletions paddle/fluid/operators/cinn/cinn_launch_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,31 @@ CinnLaunchContext::CinnLaunchContext(
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
const std::shared_ptr<CinnScope>& cinn_scope)
: paddle2cinn_varmap_(paddle2cinn_varmap), cinn_scope_(cinn_scope) {
// generate all names of cinn used variables
auto var_names = cinn_scope_->var_names();
cinn_variable_names_.reserve(var_names.size());
std::transform(
var_names.begin(), var_names.end(),
std::inserter(cinn_variable_names_, cinn_variable_names_.end()),
[](const auto& name_view) { return std::string(name_view.data()); });
// build the variable name map of cinn2paddle
for (const auto& x : paddle2cinn_varmap_) {
auto res = cinn2paddle_varmap_.emplace(x.second, x.first);
PADDLE_ENFORCE_EQ(
res.second, true,
platform::errors::InvalidArgument(
"Cinn variable(%s) maps to more than one paddle variable(%s,%s)",
x.second, res.first->second, x.first));
}
// Supplement mappings for the remaining variables that do not appear in
// the map above; they are internal variables, so the names produced by the
// CINN compilation are used on both sides of the mapping.
for (const auto& var_name : cinn_variable_names_) {
if (!cinn2paddle_varmap_.count(var_name)) {
cinn2paddle_varmap_.emplace(var_name, var_name);
paddle2cinn_varmap_.emplace(var_name, var_name);
}
}
}

void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope,
Expand Down Expand Up @@ -189,6 +208,20 @@ CinnLaunchContext::FinalizeArguments() const {
return name2argument_;
}

// Returns the cinn_buffer_t argument bound to the given *Paddle* variable
// name: first translate the Paddle name to the cinn name through
// paddle2cinn_varmap_, then look the cinn name up in name2argument_
// (whose keys are cinn variable names — see the review exchange below).
// Raises InvalidArgument when either lookup fails.
// NOTE: the interleaved "Copy link"/review text below is scraped GitHub
// review discussion, not part of the source file.
cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar(
const std::string& paddle_var_name) {
auto res = paddle2cinn_varmap_.find(paddle_var_name);
PADDLE_ENFORCE_NE(
res, paddle2cinn_varmap_.end(),
platform::errors::InvalidArgument(
"Variable(%s) not found in compilation result", paddle_var_name));
// name2argument_ is keyed by the cinn variable name (res->second).
auto it = name2argument_.find(res->second);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请在注释中注明name2argument_中的key使用的是cinn var的名字。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已在头文件声明处添加注释

// NOTE(review): the message grammar should read "not initialized".
PADDLE_ENFORCE_NE(it, name2argument_.end(),
platform::errors::InvalidArgument(
"Argument(%s) not be initialized", res->second));
return static_cast<cinn_buffer_t*>(it->second);
}

} // namespace details
} // namespace operators
} // namespace paddle
Loading