Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Jun 25, 2024
1 parent 9cf56d1 commit 428e3b5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
18 changes: 18 additions & 0 deletions paddle/phi/kernels/gpu/expand_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ void ExpandGradKernel(const Context& ctx,
}
}

template <typename T, typename Context>
void LegacyExpandGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& shape,
DenseTensor* x_grad) {
ExpandGradKernel<T, Context>(ctx, x, out_grad, shape, x_grad);
}
} // namespace phi

PD_REGISTER_KERNEL(expand_grad,
Expand All @@ -57,3 +65,13 @@ PD_REGISTER_KERNEL(expand_grad,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(legacy_expand_grad,
GPU,
ALL_LAYOUT,
phi::LegacyExpandGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
21 changes: 21 additions & 0 deletions paddle/phi/kernels/gpu/expand_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/phi/kernels/expand_kernel.h"

#include "glog/logging.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
Expand All @@ -31,6 +32,8 @@ void ExpandKernel(const Context& ctx,
auto diff = expand_shape.size() - x.dims().size();
auto out_shape = common::vectorize<int64_t>(x.dims());
out_shape.insert(out_shape.begin(), diff, 1);
VLOG(2) << "expand_shape: " << expand_shape;
VLOG(2) << "out_shape: " << out_shape;
for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_NE(
expand_shape[i],
Expand Down Expand Up @@ -76,6 +79,13 @@ void ExpandKernel(const Context& ctx,
phi::funcs::BroadcastKernel<T>(ctx, ins, &outs, kps::IdentityFunctor<T>());
}

template <typename T, typename Context>
void LegacyExpandKernel(const Context& ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out) {
ExpandKernel<T, Context>(ctx, x, shape, out);
}
} // namespace phi

PD_REGISTER_KERNEL(expand,
Expand All @@ -94,3 +104,14 @@ PD_REGISTER_KERNEL(expand,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(legacy_expand,
GPU,
ALL_LAYOUT,
phi::LegacyExpandKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::float16) {}

0 comments on commit 428e3b5

Please sign in to comment.