[Inference] Enable infer shape cache. #48312

Merged
merged 14 commits on Dec 8, 2022
21 changes: 20 additions & 1 deletion paddle/fluid/framework/ir/runtime_context_cache_pass.cc
@@ -14,17 +14,36 @@ limitations under the License. */

#include "paddle/fluid/framework/ir/runtime_context_cache_pass.h"

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
namespace ir {

void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
static constexpr char kNotAllowInferShapeCahce[] =
"@NOT_ALLOW_INFERSHAPE_CACHE@";
VLOG(3) << "Applies Runtime Context Cache strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()) {
n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
n->Op()->SetAttr(framework::kEnableCacheRuntimeContext, true);
}
}

// if op1 -> var0 and op2 -> var0, then op1 and op2 not support
// InferShapeCache.
std::unordered_map<std::string, std::vector<Node*>> var2ops;
for (auto* op_node : TopologySortOperations(*graph)) {
for (auto* var_node : op_node->outputs) {
var2ops[var_node->Name()].push_back(op_node);
}
}
for (auto& it : var2ops) {
if (it.second.size() > 1) {
for (auto op_node : it.second) {
op_node->Op()->SetAttr(kNotAllowInferShapeCahce, true);
}
}
}
}
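For context on the pass change above: when two different ops produce the same variable (for example in-place or multi-writer patterns), a cached shape for that variable cannot be trusted, so every producer is tagged with the `@NOT_ALLOW_INFERSHAPE_CACHE@` attribute and keeps running InferShape. Below is a minimal standalone sketch of that rule, using plain STL containers and a hypothetical `FakeOp` struct instead of Paddle's `Graph`/`Node` classes:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a graph op node.
struct FakeOp {
  std::string name;
  std::vector<std::string> outputs;
  bool allow_infer_shape_cache = true;
};

int main() {
  std::vector<FakeOp> ops = {
      {"op1", {"var0"}}, {"op2", {"var0"}}, {"op3", {"var1"}}};

  // var name -> ops that produce it, mirroring var2ops in the pass.
  std::unordered_map<std::string, std::vector<FakeOp*>> var2ops;
  for (auto& op : ops) {
    for (auto& out : op.outputs) var2ops[out].push_back(&op);
  }

  // A variable written by more than one op disables the infer-shape
  // cache for all of its producers.
  for (auto& kv : var2ops) {
    if (kv.second.size() > 1) {
      for (auto* op : kv.second) op->allow_infer_shape_cache = false;
    }
  }

  for (auto& op : ops) {
    std::cout << op.name << " cache allowed: " << op.allow_infer_shape_cache
              << "\n";  // op1: 0, op2: 0, op3: 1
  }
}
```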
80 changes: 74 additions & 6 deletions paddle/fluid/framework/operator.cc
@@ -15,6 +15,7 @@ limitations under the License. */

 #include <sstream>
 #include <string>
+#include <unordered_set>

 #include "gflags/gflags.h"
 #include "paddle/fluid/framework/convert_utils.h"
@@ -36,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h"
@@ -562,6 +564,14 @@ phi::DenseTensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
   }
 }

+OperatorWithKernel::OperatorWithKernel(const std::string& type,
+                                       const VariableNameMap& inputs,
+                                       const VariableNameMap& outputs,
+                                       const AttributeMap& attrs)
+    : OperatorBase(type, inputs, outputs, attrs) {}
+
+OperatorWithKernel::~OperatorWithKernel() = default;
+
 bool ExecutionContext::HasInput(const std::string& name) const {
   auto* var = InputVar(name);
   return var != nullptr;
@@ -1204,19 +1214,54 @@ class RuntimeInferShapeContext : public InferShapeContext {
 };

 struct OperatorWithKernel::CacheImpl {
+  static const char kNotAllowInferShapeCahce[];
   explicit CacheImpl(phi::KernelContext* kernel_ctx,
-                     RuntimeInferShapeContext* infer_shape_ctx)
-      : kernel_ctx_(kernel_ctx), infer_shape_ctx_(infer_shape_ctx) {}
+                     RuntimeInferShapeContext* infer_shape_ctx,
+                     const std::vector<phi::DenseTensor*>& tensors,
+                     bool not_allow_infer_shape_cache)
+      : kernel_ctx_(kernel_ctx),
+        infer_shape_ctx_(infer_shape_ctx),
+        tensors_(tensors),
+        not_allow_infer_shape_cache_(not_allow_infer_shape_cache) {}

   phi::KernelContext* getKernelContext() { return kernel_ctx_.get(); }
   RuntimeInferShapeContext* getRuntimeInferShapeContext() {
     return infer_shape_ctx_.get();
   }

+  bool NeedInferShape() {
+    if (not_allow_infer_shape_cache_) return true;
+
+    bool ret{false};
+    if (last_ddims_.empty() || tensors_.empty()) ret = true;
+    if (!ret) {
+      CHECK_EQ(last_ddims_.size(), tensors_.size());
+      for (size_t i = 0; i < last_ddims_.size(); ++i) {
+        if (tensors_[i]->dims() != last_ddims_[i]) {
+          ret = true;
+          break;
+        }
+      }
+    }
+    if (ret) {
+      last_ddims_.resize(tensors_.size());
+      for (size_t i = 0; i < last_ddims_.size(); ++i) {
+        last_ddims_[i] = tensors_[i]->dims();
+      }
+    }
+    VLOG(3) << "need infer shape is " << ret;
+    return ret;
+  }
+
  private:
   std::unique_ptr<phi::KernelContext> kernel_ctx_;
   std::unique_ptr<RuntimeInferShapeContext> infer_shape_ctx_;
+  std::vector<phi::DenseTensor*> tensors_;
+  bool not_allow_infer_shape_cache_;
+  std::vector<phi::DDim> last_ddims_;
 };
+const char OperatorWithKernel::CacheImpl::kNotAllowInferShapeCahce[] =
+    "@NOT_ALLOW_INFERSHAPE_CACHE@";

 static void CheckTensorNANOrInf(const std::string& op_type,
                                 const std::string& name,
@@ -1524,8 +1569,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     pre_scope_ = cur_scope;
   } else if (run_phi_kernel_ && impl_ != nullptr && !need_prepare_data_ &&
              !need_prepare_phi_data_) {
-    if (!all_kernels_must_compute_runtime_shape_)
+    if (!all_kernels_must_compute_runtime_shape_ && impl_->NeedInferShape()) {
       this->Info().infer_shape_(impl_->getRuntimeInferShapeContext());
+    }
     (*phi_kernel_)(impl_->getKernelContext());
   } else {
     if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
@@ -1828,9 +1874,31 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   phi::KernelContext phi_kernel_context;
   if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
       !need_prepare_data_) {
-    impl_ =
+    // TODO(inference): Currently only the dense_tensor cache is supported;
+    // ScalarTensor and SparseTensor may be supported in the future.
+    bool all_dense_tensor_input_{true};
+    for (auto& iter : Inputs()) {
+      for (auto& name : iter.second) {
+        all_dense_tensor_input_ &=
+            scope.FindVar(name)->IsType<phi::DenseTensor>();
+      }
+    }
+
+    std::vector<phi::DenseTensor*> tensors;
+    if (all_dense_tensor_input_) {
+      for (auto& iter : Inputs()) {
+        for (auto& name : iter.second) {
+          auto* t = scope.FindVar(name)->GetMutable<phi::DenseTensor>();
+          tensors.push_back(t);
+        }
+      }
+    }
+
+    impl_.reset(
         new CacheImpl(new phi::KernelContext(),
-                      new RuntimeInferShapeContext(*this, *runtime_ctx));
+                      new RuntimeInferShapeContext(*this, *runtime_ctx),
+                      tensors,
+                      HasAttr(CacheImpl::kNotAllowInferShapeCahce)));
     BuildPhiKernelContext(*runtime_ctx, dev_ctx, impl_->getKernelContext());
     (*phi_kernel_)(impl_->getKernelContext());
   } else {
@@ -3246,6 +3314,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
   if (phi::OneDNNContext::classof(dev_ctx)) {
     phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
     one_dnn_ctx->ClearDnnAttr();
+    if (!RuntimeAttrs().empty()) need_prepare_phi_data_ = true;
   }
 #endif

@@ -3267,7 +3336,6 @@
 #if defined(PADDLE_WITH_MKLDNN) || defined(PADDLE_WITH_CUDA)
   auto& runtime_attrs = RuntimeAttrs();
   for (const auto& attr_iter : runtime_attrs) {
-    need_prepare_phi_data_ = true;
     auto& attr_name = attr_iter.first;
     auto& attr = attr_iter.second;
     auto attr_propertys = paddle::operators::GetExtraAttrProperties(attr_name);
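The heart of the operator.cc change is `CacheImpl::NeedInferShape()` shown above: once the runtime context is cached, InferShape is re-run only when some input tensor's dims differ from the dims recorded on the previous run (or when the op was tagged `@NOT_ALLOW_INFERSHAPE_CACHE@`). Here is a rough standalone sketch of that dims-comparison cache, with `std::vector<int64_t>` standing in for `phi::DDim` and a hypothetical `ShapeCache` class:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the shape-caching part of CacheImpl.
class ShapeCache {
 public:
  // Returns true when InferShape must run again: on the first call,
  // when there are no inputs to compare, or when any shape changed.
  // The recorded shapes are refreshed whenever it returns true.
  bool NeedInferShape(const std::vector<std::vector<int64_t>>& dims) {
    bool need = last_dims_.empty() || dims.empty() || dims != last_dims_;
    if (need) last_dims_ = dims;
    return need;
  }

 private:
  std::vector<std::vector<int64_t>> last_dims_;  // dims from the last run
};

int main() {
  ShapeCache cache;
  std::cout << cache.NeedInferShape({{2, 3}}) << "\n";  // 1: first run
  std::cout << cache.NeedInferShape({{2, 3}}) << "\n";  // 0: same shape, skip
  std::cout << cache.NeedInferShape({{4, 3}}) << "\n";  // 1: shape changed
}
```

In RunImpl this check sits in front of `this->Info().infer_shape_(...)`, so for fixed-shape inference requests each operator effectively runs InferShape only once.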
8 changes: 5 additions & 3 deletions paddle/fluid/framework/operator.h
@@ -612,8 +612,9 @@ class OperatorWithKernel : public OperatorBase {
   OperatorWithKernel(const std::string& type,
                      const VariableNameMap& inputs,
                      const VariableNameMap& outputs,
-                     const AttributeMap& attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
+                     const AttributeMap& attrs);
+
+  virtual ~OperatorWithKernel();

   static paddle::flat_hash_map<std::string /* op_type */, OpKernelMap>&
   AllOpKernels() {
@@ -785,8 +786,9 @@ class OperatorWithKernel : public OperatorBase {
   mutable std::unique_ptr<phi::Kernel> phi_kernel_;
   mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;

+ private:
   struct CacheImpl;
-  mutable CacheImpl* impl_{nullptr};
+  mutable std::unique_ptr<CacheImpl> impl_;
 };

 extern bool OpSupportGPU(const std::string& op_type);
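The operator.h change is coupled to the new out-of-line constructor and destructor added in operator.cc: `impl_` is now a `std::unique_ptr` to the forward-declared `CacheImpl`, and `std::unique_ptr`'s deleter may only be instantiated where the pointee is a complete type. A minimal sketch of this pimpl pattern, with hypothetical `Widget`/`Impl` names rather than Paddle's classes (header and source shown together so the snippet compiles as one file):

```cpp
#include <memory>

// --- widget.h --------------------------------------------------------------
class Widget {
 public:
  Widget();
  ~Widget();  // declared only; defining it here would need a complete Impl

 private:
  struct Impl;                  // forward declaration
  std::unique_ptr<Impl> impl_;  // fine: nothing is deleted in the header
};

// --- widget.cc -------------------------------------------------------------
struct Widget::Impl {
  int cached_value = 0;
};

Widget::Widget() : impl_(new Impl()) {}
Widget::~Widget() = default;  // Impl is complete here, so deletion compiles

int main() { Widget w; }
```

This is why the inline `: OperatorBase(...) {}` constructor body in the header is replaced by declarations that are now defined in operator.cc.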
@@ -23,6 +23,8 @@ namespace inference {
 namespace analysis {

 void IrGraphToProgramPass::RunImpl(Argument *argument) {
+  auto cache_pass =
+      framework::ir::PassRegistry::Instance().Get("runtime_context_cache_pass");
   auto pass =
       framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");

@@ -31,14 +33,12 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
         new int(argument->memory_optim_sort_kind()));
   }

-  std::unique_ptr<framework::ir::Graph> graph(argument->main_graph_ptr());
-
   // Direct using ProgramDesc desc(argument->main_program()) may cause
   // incomplete copies of information.
   framework::ProgramDesc desc;
   desc.CopyFrom(*argument->main_program().Proto());
   pass->SetNotOwned("program", &desc);
-  pass->Apply(graph.release());  // the argument still own the graph.
+  pass->Apply(cache_pass->Apply(argument->main_graph_ptr()));

   argument->SetIrAnalyzedProgram(
       new framework::proto::ProgramDesc(*desc.Proto()));
12 changes: 2 additions & 10 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -188,7 +188,6 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fc_fuse_pass",
"fc_elementwise_layernorm_fuse_pass",
"embedding_eltwise_layernorm_fuse_pass",
"runtime_context_cache_pass",
};

const std::vector<std::string> kTrtLowerPrecisionPasses{
@@ -247,10 +246,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
 #endif  //
         "transpose_flatten_concat_fuse_pass",  //
         "constant_folding_pass",               //
-        // following pass should be located in the last, since it will
-        // work on all fused ops.
-        "float_to_half_pass",  //
-        "runtime_context_cache_pass"
+        "float_to_half_pass",  //
   });

   use_gpu_ = true;
@@ -315,10 +311,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"is_test_pass", //
"constant_folding_pass",
// following pass should be located in the last, since
// it will work on all fused ops.
"runtime_context_cache_pass"});
"constant_folding_pass"});

use_gpu_ = false;
}
@@ -468,7 +461,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("int8_scale_calculation_mkldnn_pass");
passes_.push_back("params_quantization_mkldnn_pass");
passes_.push_back("mkldnn_inplace_pass");
passes_.push_back("runtime_context_cache_pass");
}
use_mkldnn_int8_ = true;
#else