Skip to content

Commit

Permalink
[Inference] Enable infer shape cache. (#48312)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo authored Dec 8, 2022
1 parent fe86771 commit f88713e
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 23 deletions.
21 changes: 20 additions & 1 deletion paddle/fluid/framework/ir/runtime_context_cache_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,36 @@ limitations under the License. */

#include "paddle/fluid/framework/ir/runtime_context_cache_pass.h"

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
namespace ir {

void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
static constexpr char kNotAllowInferShapeCahce[] =
"@NOT_ALLOW_INFERSHAPE_CACHE@";
VLOG(3) << "Applies Runtime Context Cache strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()) {
n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
n->Op()->SetAttr(framework::kEnableCacheRuntimeContext, true);
}
}

// if op1 -> var0 and op2 -> var0, then op1 and op2 not support
// InferShapeCache.
std::unordered_map<std::string, std::vector<Node*>> var2ops;
for (auto* op_node : TopologySortOperations(*graph)) {
for (auto* var_node : op_node->outputs) {
var2ops[var_node->Name()].push_back(op_node);
}
}
for (auto& it : var2ops) {
if (it.second.size() > 1) {
for (auto op_node : it.second) {
op_node->Op()->SetAttr(kNotAllowInferShapeCahce, true);
}
}
}
}
Expand Down
80 changes: 74 additions & 6 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */

#include <sstream>
#include <string>
#include <unordered_set>

#include "gflags/gflags.h"
#include "paddle/fluid/framework/convert_utils.h"
Expand All @@ -36,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h"
Expand Down Expand Up @@ -562,6 +564,14 @@ phi::DenseTensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
}
}

OperatorWithKernel::OperatorWithKernel(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

OperatorWithKernel::~OperatorWithKernel() = default;

bool ExecutionContext::HasInput(const std::string& name) const {
auto* var = InputVar(name);
return var != nullptr;
Expand Down Expand Up @@ -1204,19 +1214,54 @@ class RuntimeInferShapeContext : public InferShapeContext {
};

struct OperatorWithKernel::CacheImpl {
static const char kNotAllowInferShapeCahce[];
explicit CacheImpl(phi::KernelContext* kernel_ctx,
RuntimeInferShapeContext* infer_shape_ctx)
: kernel_ctx_(kernel_ctx), infer_shape_ctx_(infer_shape_ctx) {}
RuntimeInferShapeContext* infer_shape_ctx,
const std::vector<phi::DenseTensor*>& tensors,
bool not_allow_infer_shape_cache)
: kernel_ctx_(kernel_ctx),
infer_shape_ctx_(infer_shape_ctx),
tensors_(tensors),
not_allow_infer_shape_cache_(not_allow_infer_shape_cache) {}

phi::KernelContext* getKernelContext() { return kernel_ctx_.get(); }
RuntimeInferShapeContext* getRuntimeInferShapeContext() {
return infer_shape_ctx_.get();
}

bool NeedInferShape() {
if (not_allow_infer_shape_cache_) return true;

bool ret{false};
if (last_ddims_.empty() || tensors_.empty()) ret = true;
if (!ret) {
CHECK_EQ(last_ddims_.size(), tensors_.size());
for (size_t i = 0; i < last_ddims_.size(); ++i) {
if (tensors_[i]->dims() != last_ddims_[i]) {
ret = true;
break;
}
}
}
if (ret) {
last_ddims_.resize(tensors_.size());
for (size_t i = 0; i < last_ddims_.size(); ++i) {
last_ddims_[i] = tensors_[i]->dims();
}
}
VLOG(3) << "need infer shape is " << ret;
return ret;
}

private:
std::unique_ptr<phi::KernelContext> kernel_ctx_;
std::unique_ptr<RuntimeInferShapeContext> infer_shape_ctx_;
std::vector<phi::DenseTensor*> tensors_;
bool not_allow_infer_shape_cache_;
std::vector<phi::DDim> last_ddims_;
};
const char OperatorWithKernel::CacheImpl::kNotAllowInferShapeCahce[] =
"@NOT_ALLOW_INFERSHAPE_CACHE@";

static void CheckTensorNANOrInf(const std::string& op_type,
const std::string& name,
Expand Down Expand Up @@ -1524,8 +1569,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
pre_scope_ = cur_scope;
} else if (run_phi_kernel_ && impl_ != nullptr && !need_prepare_data_ &&
!need_prepare_phi_data_) {
if (!all_kernels_must_compute_runtime_shape_)
if (!all_kernels_must_compute_runtime_shape_ && impl_->NeedInferShape()) {
this->Info().infer_shape_(impl_->getRuntimeInferShapeContext());
}
(*phi_kernel_)(impl_->getKernelContext());
} else {
if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
Expand Down Expand Up @@ -1828,9 +1874,31 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
phi::KernelContext phi_kernel_context;
if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
!need_prepare_data_) {
impl_ =
// TODO(inference): Now we only suppor dense_tensor cache, we may be
// support ScalarTensor, SparseTensor in future.
bool all_dense_tensor_input_{true};
for (auto& iter : Inputs()) {
for (auto& name : iter.second) {
all_dense_tensor_input_ &=
scope.FindVar(name)->IsType<phi::DenseTensor>();
}
}

std::vector<phi::DenseTensor*> tensors;
if (all_dense_tensor_input_) {
for (auto& iter : Inputs()) {
for (auto& name : iter.second) {
auto* t = scope.FindVar(name)->GetMutable<phi::DenseTensor>();
tensors.push_back(t);
}
}
}

impl_.reset(
new CacheImpl(new phi::KernelContext(),
new RuntimeInferShapeContext(*this, *runtime_ctx));
new RuntimeInferShapeContext(*this, *runtime_ctx),
tensors,
HasAttr(CacheImpl::kNotAllowInferShapeCahce)));
BuildPhiKernelContext(*runtime_ctx, dev_ctx, impl_->getKernelContext());
(*phi_kernel_)(impl_->getKernelContext());
} else {
Expand Down Expand Up @@ -3246,6 +3314,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
if (phi::OneDNNContext::classof(dev_ctx)) {
phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
one_dnn_ctx->ClearDnnAttr();
if (!RuntimeAttrs().empty()) need_prepare_phi_data_ = true;
}
#endif

Expand All @@ -3267,7 +3336,6 @@ void OperatorWithKernel::BuildPhiKernelContext(
#if defined(PADDLE_WITH_MKLDNN) || defined(PADDLE_WITH_CUDA)
auto& runtime_attrs = RuntimeAttrs();
for (const auto& attr_iter : runtime_attrs) {
need_prepare_phi_data_ = true;
auto& attr_name = attr_iter.first;
auto& attr = attr_iter.second;
auto attr_propertys = paddle::operators::GetExtraAttrProperties(attr_name);
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,9 @@ class OperatorWithKernel : public OperatorBase {
OperatorWithKernel(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
const AttributeMap& attrs);

virtual ~OperatorWithKernel();

static paddle::flat_hash_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() {
Expand Down Expand Up @@ -785,8 +786,9 @@ class OperatorWithKernel : public OperatorBase {
mutable std::unique_ptr<phi::Kernel> phi_kernel_;
mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;

private:
struct CacheImpl;
mutable CacheImpl* impl_{nullptr};
mutable std::unique_ptr<CacheImpl> impl_;
};

extern bool OpSupportGPU(const std::string& op_type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ namespace inference {
namespace analysis {

void IrGraphToProgramPass::RunImpl(Argument *argument) {
auto cache_pass =
framework::ir::PassRegistry::Instance().Get("runtime_context_cache_pass");
auto pass =
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");

Expand All @@ -31,14 +33,12 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
new int(argument->memory_optim_sort_kind()));
}

std::unique_ptr<framework::ir::Graph> graph(argument->main_graph_ptr());

// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
framework::ProgramDesc desc;
desc.CopyFrom(*argument->main_program().Proto());
pass->SetNotOwned("program", &desc);
pass->Apply(graph.release()); // the argument still own the graph.
pass->Apply(cache_pass->Apply(argument->main_graph_ptr()));

argument->SetIrAnalyzedProgram(
new framework::proto::ProgramDesc(*desc.Proto()));
Expand Down
12 changes: 2 additions & 10 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fc_fuse_pass",
"fc_elementwise_layernorm_fuse_pass",
"embedding_eltwise_layernorm_fuse_pass",
"runtime_context_cache_pass",
};

const std::vector<std::string> kTrtLowerPrecisionPasses{
Expand Down Expand Up @@ -254,10 +253,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
#endif //
"transpose_flatten_concat_fuse_pass", //
"constant_folding_pass", //
// following pass should be located in the last, since it will
// work on all fused ops.
"float_to_half_pass", //
"runtime_context_cache_pass"
"float_to_half_pass", //
});

use_gpu_ = true;
Expand Down Expand Up @@ -322,10 +318,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"is_test_pass", //
"constant_folding_pass",
// following pass should be located in the last, since
// it will work on all fused ops.
"runtime_context_cache_pass"});
"constant_folding_pass"});

use_gpu_ = false;
}
Expand Down Expand Up @@ -475,7 +468,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("int8_scale_calculation_mkldnn_pass");
passes_.push_back("params_quantization_mkldnn_pass");
passes_.push_back("mkldnn_inplace_pass");
passes_.push_back("runtime_context_cache_pass");
}
use_mkldnn_int8_ = true;
#else
Expand Down

0 comments on commit f88713e

Please sign in to comment.