diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 14f2e9061b742..0c67578420a72 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -365,6 +365,25 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel { return out_shape; } + + framework::KernelSignature GetExpectedPtKernelArgs( + const framework::ExecutionContext &ctx) const override { + if (ctx.HasOutput("XShape")) { + return std::make_pair( + "flatten_contiguous_range.mid", + std::make_tuple( + paddle::SmallVector({"X"}), + paddle::SmallVector({"start_axis", "stop_axis"}), + paddle::SmallVector({"Out", "XShape"}))); + } else { + return std::make_pair( + "flatten_contiguous_range", + std::make_tuple( + paddle::SmallVector({"X"}), + paddle::SmallVector({"start_axis", "stop_axis"}), + paddle::SmallVector({"Out"}))); + } + } }; class FlattenContiguousRangeOpMaker : public FlattenOpMaker { diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h index efcb0cbe2e2a8..40fd7b05d9a49 100644 --- a/paddle/fluid/operators/flatten_op.h +++ b/paddle/fluid/operators/flatten_op.h @@ -15,10 +15,13 @@ limitations under the License. */ #pragma once #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tcmpt_utils.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/pooling.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/tcmpt/api/include/core.h" +#include "paddle/tcmpt/api/include/manipulation.h" namespace paddle { namespace operators { @@ -122,13 +125,17 @@ class FlattenContiguousRangeKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &context) const override { auto *in = context.Input("X"); auto *out = context.Output("Out"); - auto out_dims = out->dims(); - out->mutable_data(context.GetPlace(), in->type()); - framework::TensorCopy( - *in, context.GetPlace(), - context.template device_context(), out); - out->Resize(out_dims); + auto &start_axis = context.Attr("start_axis"); + auto &stop_axis = context.Attr("stop_axis"); + auto &dev_ctx = context.device_context(); + auto pt_x = framework::MakeTensorImpl(*in, in->place(), + in->type()); + auto pt_out = framework::MakeTensorImpl(*out, out->place(), + out->type()); + + // call new kernel + pt::Flatten(dev_ctx, *pt_x.get(), start_axis, stop_axis, pt_out.get()); } };