diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc
index ccdd9dc9d50ce..bf114a9ee0f48 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.cc
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc
@@ -322,6 +322,10 @@ bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
   }
 }
 
+Scope* InterpretercoreInferShapeContext::GetScopePtr() const {
+  return nullptr;
+}
+
 // TODO(paddle-dev): Can this be template?
 std::vector<InferShapeVarPtr> InterpretercoreInferShapeContext::GetInputVarPtrs(
     const std::string& name) const {
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h
index 5704fa414bbb2..14e44ba071fd0 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.h
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -88,6 +88,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
 
   bool IsRunMKLDNNKernel() const override;
 
+  Scope* GetScopePtr() const override;
+
   // TODO(paddle-dev): Can this be template?
   std::vector<InferShapeVarPtr> GetInputVarPtrs(
       const std::string& name) const override;
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index f31fefcfade89..464c3726c765b 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -244,6 +244,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
 
   bool IsRunMKLDNNKernel() const override;
 
+  Scope* GetScopePtr() const override;
+
   std::vector<proto::VarType::Type> GetInputsVarType(
       const std::string &name) const override {
     return GetVarTypes(Inputs(name));
@@ -947,6 +949,8 @@ bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
 
 bool CompileTimeInferShapeContext::IsRunMKLDNNKernel() const { return false; }
 
+Scope* CompileTimeInferShapeContext::GetScopePtr() const { return nullptr; }
+
 proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
     const std::string &name) const {
   return block_.FindVarRecursive(name)->GetType();
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index cb0a25b1f5abb..21f4640662c6b 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -949,6 +949,10 @@ bool RuntimeInferShapeContext::IsRunMKLDNNKernel() const {
   }
 }
 
+Scope* RuntimeInferShapeContext::GetScopePtr() const {
+  return const_cast<Scope*>(scope_);
+}
+
 // TODO(paddle-dev): Can this be template?
 std::vector<InferShapeVarPtr> RuntimeInferShapeContext::GetInputVarPtrs(
     const std::string& name) const {
@@ -1224,13 +1228,13 @@ void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
 void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
                                            const platform::Place& place,
                                            const RuntimeContext& ctx) const {
-  RuntimeInferShapeContext infer_shape_ctx(*this, ctx);
+  RuntimeInferShapeContext infer_shape_ctx(*this, ctx, scope);
   this->Info().infer_shape_(&infer_shape_ctx);
 }
 
 void OperatorWithKernel::RuntimeInferShape(const Scope& scope) const {
   RuntimeContext ctx(Inputs(), Outputs(), scope);
-  RuntimeInferShapeContext infer_shape_ctx(*this, ctx);
+  RuntimeInferShapeContext infer_shape_ctx(*this, ctx, scope);
   this->Info().infer_shape_(&infer_shape_ctx);
 }
 
@@ -1442,7 +1446,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     platform::RecordEvent record_event("infer_shape",
                                        platform::TracerEventType::OperatorInner,
                                        1, platform::EventRole::kInnerOp);
-    RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
+    RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx, exec_scope);
     this->Info().infer_shape_(&infer_shape_ctx);
   }
 
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 31a085661c5de..32d414243321e 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -137,7 +137,6 @@ class RuntimeContext {
   RuntimeContext(const VariableValueMap& invars,
                  const VariableValueMap& outvars)
       : inputs(invars), outputs(outvars) {}
-
   VariableValueMap inputs;
   VariableValueMap outputs;
 };
@@ -715,8 +714,8 @@ class OperatorWithKernel : public OperatorBase {
 class RuntimeInferShapeContext : public InferShapeContext {
  public:
-  RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
-      : op_(op), ctx_(ctx) {}
+  RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx, const Scope& scope)
+      : op_(op), ctx_(ctx), scope_(&scope) {}
 
   bool HasInput(const std::string &name) const override;
 
   bool HasOutput(const std::string &name) const override;
@@ -760,6 +759,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
 
   bool IsRunMKLDNNKernel() const override;
 
+  Scope* GetScopePtr() const override;
+
   std::vector<InferShapeVarPtr> GetInputVarPtrs(
       const std::string &name) const override;
   std::vector<InferShapeVarPtr> GetOutputVarPtrs(
@@ -789,6 +790,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
   const std::vector<std::string>& OutputVars(const std::string& name) const;
   const OperatorBase& op_;
   const RuntimeContext& ctx_;
+  const Scope* scope_;
 };
 
 extern bool OpSupportGPU(const std::string& op_type);
diff --git a/paddle/fluid/framework/ps_gpu_worker.cc b/paddle/fluid/framework/ps_gpu_worker.cc
index 165b82ee0dacb..8909159ce4e81 100644
--- a/paddle/fluid/framework/ps_gpu_worker.cc
+++ b/paddle/fluid/framework/ps_gpu_worker.cc
@@ -258,7 +258,7 @@ int PSGPUWorker::OpRunAndShapeCheck(OperatorBase& op,
   auto& after_dims = check_data.after_dims;
   auto& after_lods = check_data.after_lods;
   RuntimeContext ctx(op.Inputs(), op.Outputs(), scope);
-  RuntimeInferShapeContext infer_shape_ctx(op, ctx);
+  RuntimeInferShapeContext infer_shape_ctx(op, ctx, scope);
   auto outnames = op.Outputs();
   for (auto& var_name_item : outnames) {
     pre_dims.push_back(infer_shape_ctx.GetOutputsDim(var_name_item.first));
diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc
index 0463f5788f154..c8435ea84cb6d 100644
--- a/paddle/fluid/framework/scope.cc
+++ b/paddle/fluid/framework/scope.cc
@@ -95,6 +95,25 @@ Variable* Scope::FindVar(const std::string& name) const {
   return FindVarInternal(name);
 }
 
+std::vector<Variable*> Scope::FindVarFromChild(const std::string& name) const {
+  std::vector<Variable*> ret;
+  {
+    SCOPE_VARS_READER_LOCK
+    auto it = vars_.find(name);
+    if (it != vars_.end()) {
+      ret.push_back(it->second.get());
+    }
+  }
+  {
+    SCOPE_KIDS_READER_LOCK
+    for (Scope* s : kids_) {
+      auto child_ret = s->FindVarFromChild(name);
+      ret.insert(ret.end(), child_ret.begin(), child_ret.end());
+    }
+  }
+  return ret;
+}
+
 Variable* Scope::GetVar(const std::string& name) const {
   auto* var = FindVar(name);
   PADDLE_ENFORCE_NOT_NULL(
diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h
index 1669fba1327e5..464d5f2097bbc 100644
--- a/paddle/fluid/framework/scope.h
+++ b/paddle/fluid/framework/scope.h
@@ -107,6 +107,8 @@ class Scope : public ScopeBase {
   /// Caller doesn't own the returned Variable.
   Variable* FindVar(const std::string& name) const;
 
+  std::vector<Variable*> FindVarFromChild(const std::string& name) const;
+
   /// Get a variable in the scope or any of its ancestors. Enforce
   /// the returned Variable is not nullptr
   Variable* GetVar(const std::string& name) const;
diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h
index 31e3929362a04..417d85740ba7e 100644
--- a/paddle/fluid/framework/shape_inference.h
+++ b/paddle/fluid/framework/shape_inference.h
@@ -105,6 +105,8 @@ class InferShapeContext {
 
   virtual bool IsRunMKLDNNKernel() const = 0;
 
+  virtual Scope* GetScopePtr() const = 0;
+
   virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
       const std::string &name) const = 0;
   virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h
index 9fe67e1dcdff3..768567650f3fb 100644
--- a/paddle/fluid/framework/var_type_traits.h
+++ b/paddle/fluid/framework/var_type_traits.h
@@ -21,6 +21,8 @@
 #include
 #include
 
+#include "paddle/fluid/platform/place.h"
+#include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/string_array.h"
@@ -98,6 +100,42 @@ class OrderedMultiDeviceLoDTensorBlockingQueueHolder;
 
 namespace paddle {
 namespace framework {
 
+class GpuPinnedVector {
+ public:
+  GpuPinnedVector() {}
+  void cpu_to_pinedcpu(void* buf, size_t len) {
+    mem_cpu_ = memory::Alloc(phi::GPUPinnedPlace(), len);
+    memcpy(reinterpret_cast<char*>(mem_cpu_->ptr()), buf, len);
+    len_ = len;
+  }
+  void pinedcpu_to_gpu(paddle::gpuStream_t stream, phi::Place place) {
+    mem_gpu_ = memory::Alloc(place, len_);
+    cudaMemcpyAsync(reinterpret_cast<char*>(mem_gpu_->ptr()),
+                    reinterpret_cast<char*>(mem_cpu_->ptr()),
+                    len_, cudaMemcpyHostToDevice, stream);
+  }
+  void cpu_to_gpu(void* buf, size_t len, paddle::gpuStream_t stream, phi::Place place) {
+    mem_cpu_ = memory::Alloc(phi::GPUPinnedPlace(), len);
+    memcpy(reinterpret_cast<char*>(mem_cpu_->ptr()), buf, len);
+    mem_gpu_ = memory::Alloc(place, len);
+    cudaMemcpyAsync(reinterpret_cast<char*>(mem_gpu_->ptr()),
+                    reinterpret_cast<char*>(mem_cpu_->ptr()),
+                    len, cudaMemcpyHostToDevice, stream);
+    len_ = len;
+  }
+  template <typename Type>
+  Type* get_gpu_ptr() {
+    return reinterpret_cast<Type*>(mem_gpu_->ptr());
+  }
+  template <typename Type>
+  Type* get_cpu_ptr() {
+    return reinterpret_cast<Type*>(mem_cpu_->ptr());
+  }
+
+ private:
+  memory::allocation::AllocationPtr mem_cpu_;
+  memory::allocation::AllocationPtr mem_gpu_;
+  size_t len_;
+};
+
 const char *ToTypeName(int var_id);
 const std::type_index &VarTraitIdToTypeIndex(int var_id);
 int TypeIndexToVarTraitId(const std::type_index &type);
@@ -189,7 +227,8 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
 #if defined(PADDLE_WITH_CNCL)
     cnclCliqueId,
 #endif
-    int, float, Vocab>;
+    int, float, Vocab,
+    GpuPinnedVector>;
 template <typename T>
 struct VarTypeTrait {
   static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h
index f871e77fdf6e2..8309ab540b6cd 100644
--- a/paddle/fluid/imperative/infer_shape_context.h
+++ b/paddle/fluid/imperative/infer_shape_context.h
@@ -224,6 +224,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
     return (op_kernel_type_ &&
             (op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
   }
+
+  framework::Scope* GetScopePtr() const override {
+    return nullptr;
+  }
 
   std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
       const std::string& name) const override {
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
index e7abf4689b1fc..dd671a5b2e4b4 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
@@ -62,18 +62,20 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
     if (ctx->IsRuntime()) {
       int batch_size = -1;
       auto inputs_tensor = ctx->GetInputVarPtrs("X");
+      uint64_t tmp_var_key = 0;
       for (size_t i = 0; i < num_inputs; ++i) {
         const auto dims = ins_dims[i];
         int rank = dims.size();
         int cur_batch_size = 0;
         framework::Variable* x_var =
             BOOST_GET(framework::Variable*, inputs_tensor[i]);
-        const auto& x_tensor = x_var->Get<LoDTensor>();
-        const auto& x_lod = x_tensor.lod();
+        const auto x_tensor = x_var->GetMutable<LoDTensor>();
+        tmp_var_key += (uint64_t)(x_tensor);
+        const auto& x_lod = x_tensor->lod();
         if (x_lod.size() > 0) {
           cur_batch_size = x_lod[0].size() - 1;
         } else {
-          cur_batch_size = x_tensor.dims()[0];
+          cur_batch_size = x_tensor->dims()[0];
         }
         if (batch_size == -1) {
           batch_size = cur_batch_size;
@@ -93,6 +95,41 @@ class FusedSeqpoolCVMOp : public framework::OperatorWithKernel {
         }
         outs_dims[i] = phi::make_ddim(out_dim);
       }
+
+      // Prepare the lod data for the GPU here; building it in the Compute
+      // kernel would drag down performance.
+      {
+        auto scope = ctx->GetScopePtr();
+        auto& child_scope = scope->NewScope();
+        std::string var_name = "FusedSeqpoolCVMOp_";
+        var_name.append(std::to_string(tmp_var_key));
+        auto var = child_scope.Var(var_name);
+        paddle::framework::GpuPinnedVector* pin_ptr =
+            var->GetMutable<paddle::framework::GpuPinnedVector>();
+
+        std::vector<size_t> mix_lods;
+        mix_lods.reserve(num_inputs * (batch_size + 1));
+        for (size_t i = 0; i < num_inputs; ++i) {
+          framework::Variable* x_var =
+              BOOST_GET(framework::Variable*, inputs_tensor[i]);
+          const auto& x_tensor = x_var->Get<LoDTensor>();
+          const auto& x_lod = x_tensor.lod();
+          if (x_lod.size() != 0) {
+            PADDLE_ENFORCE_EQ(x_lod.size(), 1,
+                              platform::errors::PreconditionNotMet(
+                                  "The lod size of all input should be 1, "
+                                  "please check"));
+            PADDLE_ENFORCE_EQ(x_lod[0].size(), batch_size + 1,
+                              platform::errors::PreconditionNotMet(
+                                  "The lod[0] size of all input should be "
+                                  "batch_size + 1, please check"));
+            mix_lods.insert(mix_lods.end(), x_lod[0].begin(), x_lod[0].end());
+          } else {
+            mix_lods.push_back(0);
+            for (int i = 0; i < x_tensor.dims()[0]; i++) {
+              mix_lods.push_back(i + 1);
+            }
+          }
+        }
+        pin_ptr->cpu_to_pinedcpu(mix_lods.data(),
+                                 mix_lods.size() * sizeof(size_t));
+      }
     } else {
       for (size_t i = 0; i < num_inputs; ++i) {
         const auto dims = ins_dims[i];
@@ -222,6 +259,60 @@ class FusedSeqpoolCVMGradOp : public framework::OperatorWithKernel {
       ctx->ShareLoD("X", framework::GradVarName("X"), i, i);
ctx->ShareDim("X", framework::GradVarName("X"), i, i); } + + //准备lod的gpu数据,不然放到computer里面会拖垮性能 + if (ctx->IsRuntime()) { + auto inputs_tensor = ctx->GetOutputVarPtrs(framework::GradVarName("X")); + size_t num_inputs = inputs_tensor.size(); + uint64_t tmp_var_key = 0; + framework::Variable* x_var = BOOST_GET(framework::Variable*, inputs_tensor[0]); + const LoDTensor* x_tensor = x_var->GetMutable(); + int batch_size = x_tensor->lod().size() ? x_tensor->lod()[0].size() - 1 : x_tensor->dims()[0]; + + std::vector mix_lods; + mix_lods.reserve(num_inputs * (batch_size + 1)); + for (size_t i = 0; i < num_inputs; i++) { + x_var = BOOST_GET(framework::Variable*, inputs_tensor[i]); + x_tensor = x_var->GetMutable(); + tmp_var_key += (uint64_t)(x_tensor); + const auto& x_lod = x_tensor->lod(); + if (x_lod.size() != 0) { + PADDLE_ENFORCE_EQ(x_lod.size(), 1, + platform::errors::PreconditionNotMet( + "The lod size of all in_grad should be 1, " + "please cheack")); + PADDLE_ENFORCE_EQ(x_lod[0].size(), batch_size + 1, + platform::errors::PreconditionNotMet( + "The lod[0] size of all in_grad should be batch_size + 1, " + "please cheack")); + mix_lods.insert(mix_lods.end(), x_lod[0].begin(), x_lod[0].end()); + } else { + mix_lods.push_back(0); + for (int i = 0; i < x_tensor->dims()[0]; i++) { + mix_lods.push_back(i + 1); + } + } + int cur_batch_size = x_tensor->lod().size() ? x_tensor->lod()[0].size() - 1 : x_tensor->dims()[0]; + PADDLE_ENFORCE_EQ(batch_size, cur_batch_size, + platform::errors::PreconditionNotMet( + "The batch size of all in_grad should be same, " + "please cheack, last batchsize is %d, current " + "batchsize is %d", + batch_size, cur_batch_size)); + } + PADDLE_ENFORCE_EQ(mix_lods.size(), num_inputs * (batch_size + 1), + platform::errors::PreconditionNotMet( + "please cheack")); + + std::string var_name = "FusedSeqpoolCVMGradOp_"; + var_name.append(std::to_string(tmp_var_key)); + auto scope = ctx->GetScopePtr(); + auto& child_scope = scope->NewScope(); + auto var = child_scope.Var(var_name); + paddle::framework::GpuPinnedVector* pin_ptr = var->GetMutable(); + pin_ptr->cpu_to_pinedcpu(mix_lods.data(), mix_lods.size() * sizeof(size_t)); + } + } protected: diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu index 7c5d5b26da343..d792220e092fa 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu @@ -131,13 +131,29 @@ void FusedSeqpoolCVM(const framework::ExecutionContext output_data.size() * sizeof(T *), hipMemcpyHostToDevice, stream); #else + const auto& scope = ctx.scope(); + auto& child_scope = scope.NewScope(); + static uint64_t var_index = 0; + var_index++; + std::string var_name_1 = "FusedSeqpoolCVM_KERNEL_"; + var_name_1.append(std::to_string((uint64_t)(&scope))).append("_").append(std::to_string(var_index)); + auto var_1 = child_scope.Var(var_name_1); + paddle::framework::GpuPinnedVector* pinned_inputs = var_1->GetMutable(); T **gpu_input_values = reinterpret_cast(temp_ptr->ptr()); - platform::GpuMemcpyAsync(gpu_input_values, input_data.data(), + pinned_inputs->cpu_to_pinedcpu((void*)input_data.data(), input_data.size() * sizeof(T *)); + platform::GpuMemcpyAsync(gpu_input_values, pinned_inputs->get_cpu_ptr(), input_data.size() * sizeof(T *), cudaMemcpyHostToDevice, stream); + + var_index++; + std::string var_name_2 = "FusedSeqpoolCVM_KERNEL_"; + var_name_2.append(std::to_string((uint64_t)(&scope))).append("_").append(std::to_string(var_index)); + auto 
+  auto var_2 = child_scope.Var(var_name_2);
+  paddle::framework::GpuPinnedVector* pinned_outputs =
+      var_2->GetMutable<paddle::framework::GpuPinnedVector>();
   T **gpu_output_values =
       reinterpret_cast<T **>(&gpu_input_values[input_data.size()]);
-  platform::GpuMemcpyAsync(gpu_output_values, output_data.data(),
+  pinned_outputs->cpu_to_pinedcpu((void*)output_data.data(), output_data.size() * sizeof(T *));
+  platform::GpuMemcpyAsync(gpu_output_values, pinned_outputs->get_cpu_ptr<T *>(),
                            output_data.size() * sizeof(T *),
                            cudaMemcpyHostToDevice, stream);
 #endif
@@ -266,20 +282,41 @@ void FusedSeqpoolCVMGrad(const framework::ExecutionContext &ctx,
                            cvm_data.size() * sizeof(T *), hipMemcpyHostToDevice,
                            stream);
 #else
+  const auto& scope = ctx.scope();
+  auto& child_scope = scope.NewScope();
+  static uint64_t var_index = 0;
+  var_index++;
+  std::string var_name_1 = "FusedSeqpoolCVMGrad_";
+  var_name_1.append(std::to_string((uint64_t)(&scope))).append("_").append(std::to_string(var_index));
+  auto var_1 = child_scope.Var(var_name_1);
+  paddle::framework::GpuPinnedVector* pinned_tmp_1 =
+      var_1->GetMutable<paddle::framework::GpuPinnedVector>();
   T **gpu_out_grads_values = reinterpret_cast<T **>(temp_ptr->ptr());
-  platform::GpuMemcpyAsync(gpu_out_grads_values, out_grads_data.data(),
+  pinned_tmp_1->cpu_to_pinedcpu((void*)out_grads_data.data(), out_grads_data.size() * sizeof(T *));
+  platform::GpuMemcpyAsync(gpu_out_grads_values, pinned_tmp_1->get_cpu_ptr<T *>(),
                            out_grads_data.size() * sizeof(T *),
                            cudaMemcpyHostToDevice, stream);
+
+  var_index++;
+  std::string var_name_2 = "FusedSeqpoolCVMGrad_";
+  var_name_2.append(std::to_string((uint64_t)(&scope))).append("_").append(std::to_string(var_index));
+  auto var_2 = child_scope.Var(var_name_2);
+  paddle::framework::GpuPinnedVector* pinned_tmp_2 =
+      var_2->GetMutable<paddle::framework::GpuPinnedVector>();
   T **gpu_in_grads_values =
       reinterpret_cast<T **>(&gpu_out_grads_values[out_grads_data.size()]);
-  platform::GpuMemcpyAsync(gpu_in_grads_values, in_grads_data.data(),
+  pinned_tmp_2->cpu_to_pinedcpu((void*)in_grads_data.data(), in_grads_data.size() * sizeof(T *));
+  platform::GpuMemcpyAsync(gpu_in_grads_values, pinned_tmp_2->get_cpu_ptr<T *>(),
                            in_grads_data.size() * sizeof(T *),
                            cudaMemcpyHostToDevice, stream);
+
+  var_index++;
+  std::string var_name_3 = "FusedSeqpoolCVMGrad_";
+  var_name_3.append(std::to_string((uint64_t)(&scope))).append("_").append(std::to_string(var_index));
+  auto var_3 = child_scope.Var(var_name_3);
+  paddle::framework::GpuPinnedVector* pinned_tmp_3 =
+      var_3->GetMutable<paddle::framework::GpuPinnedVector>();
   T **gpu_cvm_values =
       reinterpret_cast<T **>(&gpu_in_grads_values[in_grads_data.size()]);
-  platform::GpuMemcpyAsync(gpu_cvm_values, cvm_data.data(),
+  pinned_tmp_3->cpu_to_pinedcpu((void*)cvm_data.data(), cvm_data.size() * sizeof(T *));
+  platform::GpuMemcpyAsync(gpu_cvm_values, pinned_tmp_3->get_cpu_ptr<T *>(),
                            cvm_data.size() * sizeof(T *),
                            cudaMemcpyHostToDevice, stream);
 #endif
@@ -319,41 +356,26 @@ class FusedSeqpoolCVMCUDAKernel : public framework::OpKernel<T> {
 
     int embedding_size = inputs[0]->numel() / inputs[0]->dims()[0];
     int batch_size = inputs[0]->lod().size()
                          ? inputs[0]->lod()[0].size() - 1
                          : inputs[0]->dims()[0];
-    std::vector<size_t> mix_lods;
-    mix_lods.reserve(slot_size * (batch_size + 1));
+
+    const size_t* mix_lods_data = nullptr;
+
+    // The lod-merging logic has been moved into InferShape; here we only
+    // rebuild the variable key and fetch the prepared pinned buffer.
+    uint64_t tmp_var_key = 0;
     for (size_t i = 0; i < slot_size; ++i) {
       const auto *input = inputs[i];
-      if (input->lod().size() != 0) {
-        auto lod = input->lod();
-        PADDLE_ENFORCE_EQ(lod.size(), 1,
-                          platform::errors::PreconditionNotMet(
-                              "The lod size of all input should be 1, "
-                              "please check"));
-        PADDLE_ENFORCE_EQ(lod[0].size(), batch_size + 1,
-                          platform::errors::PreconditionNotMet(
-                              "The lod[0] size of all input should be "
-                              "batch_size + 1, please check"));
-        mix_lods.insert(mix_lods.end(), lod[0].begin(), lod[0].end());
-      } else {
-        mix_lods.push_back(0);
-        for (int i = 0; i < input->dims()[0]; i++) {
-          mix_lods.push_back(i + 1);
-        }
-      }
-      int cur_batch_size =
-          input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
-      PADDLE_ENFORCE_EQ(batch_size, cur_batch_size,
-                        platform::errors::PreconditionNotMet(
-                            "The batch size of all input should be the same, "
-                            "please check, last batch size is %d, current "
-                            "batch size is %d",
-                            batch_size, cur_batch_size));
+      tmp_var_key += (uint64_t)input;
     }
-    PADDLE_ENFORCE_EQ(mix_lods.size(), slot_size * (batch_size + 1),
+    std::string var_name = "FusedSeqpoolCVMOp_";
+    var_name.append(std::to_string(tmp_var_key));
+    const auto& scope = ctx.scope();
+    auto tmp_var_vec = scope.FindVarFromChild(var_name);
+    PADDLE_ENFORCE_EQ(tmp_var_vec.size(), 1,
                       platform::errors::PreconditionNotMet(
                           "please check"));
-    paddle::framework::MixVector<size_t> mix_lods_v(&mix_lods);
-    auto mix_lods_data = mix_lods_v.CUDAData(ctx.GetPlace());
+    auto pin_ptr = tmp_var_vec[0]->GetMutable<paddle::framework::GpuPinnedVector>();
+    pin_ptr->pinedcpu_to_gpu(
+        ctx.template device_context<platform::CUDADeviceContext>().stream(),
+        ctx.GetPlace());
+    mix_lods_data = pin_ptr->get_gpu_ptr<size_t>();
+
     for (size_t i = 0; i < slot_size; ++i) {
       const auto *input = inputs[i];
       input_data[i] = reinterpret_cast<const T *>(input->data<T>());
@@ -393,40 +415,24 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel<T> {
     int embedding_size = in_grads[0]->numel() / in_grads[0]->dims()[0];
     int batch_size = in_grads[0]->lod().size()
                          ? in_grads[0]->lod()[0].size() - 1
                          : in_grads[0]->dims()[0];
-    std::vector<size_t> mix_lods;
-    mix_lods.reserve(slot_size * (batch_size + 1));
+
+    const size_t* mix_lods_data = nullptr;
+    // The lod-merging logic has been moved into InferShape; here we only
+    // rebuild the variable key and fetch the prepared pinned buffer.
+    uint64_t tmp_var_key = 0;
     for (size_t i = 0; i < slot_size; ++i) {
       auto *in_grad = in_grads[i];
-      if (in_grad->lod().size() != 0) {
-        auto lod = in_grad->lod();
-        PADDLE_ENFORCE_EQ(lod.size(), 1,
-                          platform::errors::PreconditionNotMet(
-                              "The lod size of all in_grad should be 1, "
-                              "please check"));
-        PADDLE_ENFORCE_EQ(lod[0].size(), batch_size + 1,
-                          platform::errors::PreconditionNotMet(
-                              "The lod[0] size of all in_grad should be "
-                              "batch_size + 1, please check"));
-        mix_lods.insert(mix_lods.end(), lod[0].begin(), lod[0].end());
-      } else {
-        mix_lods.push_back(0);
-        for (int i = 0; i < in_grad->dims()[0]; i++) {
-          mix_lods.push_back(i + 1);
-        }
-      }
-      int cur_batch_size = in_grad->lod().size()
-                               ? in_grad->lod()[0].size() - 1
-                               : in_grad->dims()[0];
-      PADDLE_ENFORCE_EQ(batch_size, cur_batch_size,
-                        platform::errors::PreconditionNotMet(
-                            "The batch size of all in_grad should be the "
-                            "same, please check, last batch size is %d, "
-                            "current batch size is %d",
-                            batch_size, cur_batch_size));
+      tmp_var_key += (uint64_t)in_grad;
     }
-    PADDLE_ENFORCE_EQ(mix_lods.size(), slot_size * (batch_size + 1),
+    std::string var_name = "FusedSeqpoolCVMGradOp_";
+    var_name.append(std::to_string(tmp_var_key));
+    const auto& scope = ctx.scope();
+    auto tmp_var_vec = scope.FindVarFromChild(var_name);
+    PADDLE_ENFORCE_EQ(tmp_var_vec.size(), 1,
                       platform::errors::PreconditionNotMet(
                           "please check"));
-    paddle::framework::MixVector<size_t> mix_lods_v(&mix_lods);
-    auto mix_lods_data = mix_lods_v.CUDAData(ctx.GetPlace());
+    auto pin_ptr = tmp_var_vec[0]->GetMutable<paddle::framework::GpuPinnedVector>();
+    pin_ptr->pinedcpu_to_gpu(
+        ctx.template device_context<platform::CUDADeviceContext>().stream(),
+        ctx.GetPlace());
+    mix_lods_data = pin_ptr->get_gpu_ptr<size_t>();
     for (size_t i = 0; i < slot_size; ++i) {
       auto *in_grad = in_grads[i];
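
For reference, a minimal standalone sketch of the staging pattern this patch introduces: the merged lod table is built on the host once, copied into page-locked (pinned) memory, and later pushed to the device with an asynchronous copy on the compute stream. It uses plain CUDA runtime calls rather than Paddle types; the PinnedStager class and the sample lod values are illustrative only, not part of the patch.

// pinned_staging_sketch.cu -- illustrative only, plain CUDA, no Paddle types
#include <cstring>
#include <vector>
#include <cuda_runtime.h>

class PinnedStager {  // hypothetical stand-in for GpuPinnedVector
 public:
  ~PinnedStager() {
    if (host_) cudaFreeHost(host_);
    if (dev_) cudaFree(dev_);
  }
  // Stage a host buffer into pinned memory once (the InferShape step).
  void cpu_to_pinned(const void* buf, size_t len) {
    cudaMallocHost(&host_, len);
    std::memcpy(host_, buf, len);
    len_ = len;
  }
  // Issue the async host-to-device copy on the compute stream (the Compute step).
  void pinned_to_gpu(cudaStream_t stream) {
    cudaMalloc(&dev_, len_);
    cudaMemcpyAsync(dev_, host_, len_, cudaMemcpyHostToDevice, stream);
  }
  const size_t* gpu_ptr() const { return static_cast<const size_t*>(dev_); }

 private:
  void* host_ = nullptr;
  void* dev_ = nullptr;
  size_t len_ = 0;
};

int main() {
  // Merged lod for two slots with batch_size = 3, as InferShape would build it.
  std::vector<size_t> mix_lods = {0, 1, 2, 3, 0, 2, 4, 6};

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  PinnedStager stager;
  stager.cpu_to_pinned(mix_lods.data(), mix_lods.size() * sizeof(size_t));
  stager.pinned_to_gpu(stream);  // overlaps with other work queued on the stream
  // ... kernels that consume stager.gpu_ptr() would be launched here ...
  cudaStreamSynchronize(stream);

  cudaStreamDestroy(stream);
  return 0;
}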