Skip to content

Commit

Permalink
exchange assign and assign_raw kernel name (PaddlePaddle#41625)
Browse files Browse the repository at this point in the history
* exchange assign and assign_raw kernel name

* fix register error
  • Loading branch information
MingMingShangTian committed Apr 13, 2022
1 parent f72ba37 commit dbd15d3
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 24 deletions.
28 changes: 14 additions & 14 deletions paddle/phi/kernels/assign_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,22 @@
namespace phi {

template <typename Context>
void AssignRawKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
void AssignKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
Copy<Context>(dev_ctx, x, x.place(), false, out);
}

template <typename Context>
void AssignKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x,
DenseTensor* out) {
void AssignRawKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x,
DenseTensor* out) {
if (x) {
if (!x->IsInitialized()) {
return;
}
auto& x_tensor = *x.get_ptr();
AssignRawKernel<Context>(dev_ctx, x_tensor, out);
AssignKernel<Context>(dev_ctx, x_tensor, out);
}
}

Expand Down Expand Up @@ -111,14 +111,14 @@ void AssignValueKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_GENERAL_KERNEL(
assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {}

PD_REGISTER_GENERAL_KERNEL(assign_raw,
CPU,
ALL_LAYOUT,
phi::AssignRawKernel<phi::CPUContext>,
ALL_DTYPE) {}

PD_REGISTER_GENERAL_KERNEL(
assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {
ALL_DTYPE) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_GENERAL_KERNEL(assign_array,
Expand All @@ -136,13 +136,13 @@ PD_REGISTER_KERNEL(assign_value,
int64_t) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(
assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(assign_raw,
GPU,
ALL_LAYOUT,
phi::AssignRawKernel<phi::GPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {
ALL_DTYPE) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_GENERAL_KERNEL(assign_array,
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/kernels/assign_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@
namespace phi {

template <typename Context>
void AssignRawKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
void AssignKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);

// In order to be compatible with the `AsDispensable` input in the original
// assign op maker, the input parameter here needs to be dispensable, but
// this looks weird
template <typename Context>
void AssignKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x,
DenseTensor* out);
void AssignRawKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x,
DenseTensor* out);

template <typename Context>
void AssignArrayKernel(const Context& dev_ctx,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/ops/compat/assign_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ KernelSignature AssignOpArgumentMapping(const ArgumentMappingContext& ctx) {
} else if (ctx.IsSelectedRowsInput("X")) {
return KernelSignature("assign_sr", {"X"}, {}, {"Out"});
} else {
return KernelSignature("assign", {"X"}, {}, {"Out"});
return KernelSignature("assign_raw", {"X"}, {}, {"Out"});
}
} else {
return KernelSignature("assign", {"X"}, {}, {"Out"});
return KernelSignature("assign_raw", {"X"}, {}, {"Out"});
}
}

Expand Down
2 changes: 1 addition & 1 deletion python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
infer_meta :
func : UnchangedInferMeta
kernel :
func : assign_raw
func : assign
backward : assign_grad

# atan
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/utils/code_gen/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : assign_raw
func : assign

- backward_api : atan2_grad
forward : atan2 (Tensor x, Tensor y) -> Tensor(out)
Expand Down

0 comments on commit dbd15d3

Please sign in to comment.