Skip to content

Commit

Permalink
check the half_type for flatten_decomp
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroRains committed Mar 14, 2024
1 parent 64c4a74 commit 8ed614b
Showing 1 changed file with 49 additions and 9 deletions.
58 changes: 49 additions & 9 deletions paddle/fluid/primitive/composite/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,16 @@ std::tuple<Tensor, Tensor> flatten_decomp(const Tensor& x,
PADDLE_THROW(phi::errors::Unimplemented(
"end_axis must be greater than or equal to start_axis."));
}
auto org_dtype = x.dtype();
Tensor x_cast = x;

if (has_dynamic_shape(x.shape())) {
auto x_shape = shape<T>(x);
bool need_cast = is_half_dtype(org_dtype);
if (need_cast) {
x_cast = cast<T>(x, DataType::FLOAT32);
}

if (has_dynamic_shape(x_dim)) {
auto x_shape = shape<T>(x_cast);
Tensor x_shape_tensor = full<T>({1}, 0, x_shape.dtype());
std::vector<Tensor> tmp_shape;
tmp_shape.push_back(x_shape_tensor);
Expand All @@ -748,7 +755,14 @@ std::tuple<Tensor, Tensor> flatten_decomp(const Tensor& x,
x_shape_tensor =
backend::full_with_tensor<T>(x_shape_tensor, 0.0, DataType::FLOAT32);
if (end_axis == start_axis) {
return std::make_tuple(backend::reshape<T>(x, x_shape), x_shape_tensor);
Tensor out = backend::reshape<T>(x_cast, x_shape);
Tensor res;
if (need_cast) {
res = cast<T>(out, org_dtype);
} else {
res = out;
}
return std::make_tuple(res, x_shape_tensor);
}
std::vector<Tensor> out_shape;
for (size_t i = 0; i < x_dim.size();) {
Expand All @@ -767,18 +781,38 @@ std::tuple<Tensor, Tensor> flatten_decomp(const Tensor& x,
}
}
Tensor out_shape_tensor = concat<T>(out_shape);
return std::make_tuple(backend::reshape<T>(x, out_shape_tensor),
x_shape_tensor);
Tensor out = backend::reshape<T>(x_cast, out_shape_tensor);
Tensor res;
if (need_cast) {
res = cast<T>(out, org_dtype);
} else {
res = out;
}
return std::make_tuple(res, x_shape_tensor);
} else {
std::vector<int64_t> tmp_shape(x_dim);
tmp_shape.insert(tmp_shape.begin(), 0);
auto xshape = full<T>(tmp_shape, 0.0, DataType::FLOAT32);
if (x_dim.size() == 0) {
std::vector<int64_t> res_shape(1, 1);
return std::make_tuple(reshape<T>(x, res_shape), xshape);
Tensor out = reshape<T>(x_cast, res_shape);
Tensor res;
if (need_cast) {
res = cast<T>(out, org_dtype);
} else {
res = out;
}
return std::make_tuple(res, xshape);
}
if (end_axis == start_axis) {
return std::make_tuple(reshape<T>(x, x_dim), xshape);
Tensor out = reshape<T>(x_cast, x_dim);
Tensor res;
if (need_cast) {
res = cast<T>(out, org_dtype);
} else {
res = out;
}
return std::make_tuple(res, xshape);
}

int slice_numel = 1;
Expand All @@ -793,8 +827,14 @@ std::tuple<Tensor, Tensor> flatten_decomp(const Tensor& x,
for (size_t i = end_axis + 1; i < x_dim.size(); ++i) {
out_shape.push_back(x_dim[i]);
}

return std::make_tuple(reshape<T>(x, out_shape), xshape);
Tensor out = reshape<T>(x_cast, out_shape);
Tensor res;
if (need_cast) {
res = cast<T>(out, org_dtype);
} else {
res = out;
}
return std::make_tuple(res, xshape);
}
}

Expand Down

0 comments on commit 8ed614b

Please sign in to comment.