Skip to content

Commit

Permalink
PR #19066: [XLA:CPU][oneDNN] Handle oneDNN scalar
Browse files Browse the repository at this point in the history
Imported from GitHub PR #19066

This PR makes sure the oneDNN integration handles scalar (rank-0) inputs properly.
Copybara import of the project:

--
2fb157a by Mahmoud Abuzaina <[email protected]>:

Handle oneDNN scalar

--
77a39b6 by Mahmoud Abuzaina <[email protected]>:

Addressed review comments

--
32b5aba by Mahmoud Abuzaina <[email protected]>:

Return output instead of having parameter

--
576e244 by Mahmoud Abuzaina <[email protected]>:

Unpack the pair return

Merging this change closes #19066

COPYBARA_INTEGRATE_REVIEW=#19066 from Intel-tensorflow:mabuzain/handle-onednn-scalar 576e244
PiperOrigin-RevId: 714289598
  • Loading branch information
mahmoud-abuzaina authored and Google-ML-Automation committed Jan 11, 2025
1 parent 6d25be3 commit 724919f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
45 changes: 27 additions & 18 deletions xla/service/cpu/onednn_memory_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,34 @@ MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal) {
return CreateMemrefFromShape(shape, buf);
}

std::pair<std::vector<int64_t>, std::vector<int64_t>> GetDimsStrides(
const Shape& shape) {
// oneDNN handles scalar as a vector of size 1.
const bool is_scalar = shape.rank() == 0;
int64_t rank = is_scalar ? 1 : shape.rank();
std::vector<int64_t> strides(rank);
std::vector<int64_t> scalar_shape(1, 1);
absl::Span<const int64_t> dimensions =
is_scalar ? scalar_shape : shape.dimensions();
std::vector<int64_t> dims(dimensions.begin(), dimensions.end());
if (is_scalar) {
strides[0] = 1;
} else {
int64_t stride = 1;
for (int i : shape.layout().minor_to_major()) {
strides.at(i) = stride;
stride *= dims.at(i);
}
}
return std::make_pair(dims, strides);
}

StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilderBase& builder,
const llvm_ir::IrArray& ir_array) {
const Shape& shape = ir_array.GetShape();
int64_t rank = shape.rank();
absl::Span<const int64_t> dims = shape.dimensions();

std::vector<int64_t> strides(rank);
int64_t stride = 1;
for (int i : shape.layout().minor_to_major()) {
strides.at(i) = stride;
stride *= dims.at(i);
}
// oneDNN handles scalar as a vector of size 1.
int64_t rank = shape.rank() == 0 ? 1 : shape.rank();
auto [dims, strides] = GetDimsStrides(shape);

// Type of struct
llvm::Type* i64_type = builder.getInt64Ty();
Expand Down Expand Up @@ -184,17 +200,10 @@ absl::StatusOr<dnnl::memory::desc> TransposeLastTwoDims(
}

dnnl::memory::desc ShapeToMemDesc(const Shape& shape) {
  // Builds a oneDNN memory descriptor for `shape`, delegating the
  // layout-aware dims/strides computation (including the scalar-as-
  // vector-of-size-1 mapping) to GetDimsStrides.
  //
  // NOTE(review): the page scrape had merged the pre- and post-commit diff
  // lines here; this is the de-duplicated post-commit implementation.
  auto [dims, strides] = GetDimsStrides(shape);
  if (dims.empty()) {
    return dnnl::memory::desc{};
  }
  auto dt = ToOneDnnDataType(static_cast<PrimitiveType>(shape.element_type()));
  return dnnl::memory::desc(dims, dt, strides);
}
Expand Down
17 changes: 17 additions & 0 deletions xla/service/cpu/tests/onednn_convolution_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,23 @@ TEST_P(ConvolutionTest, Simple2DTest1) {
RunCompareAndMatchOptimizedHlo(outline, {});
}

// Regression test for the scalar-handling fix in this commit: the kernel
// operand starts as a one-element vector ($dtype[1]) and is reshaped to
// 1x1x1x1 before the convolution, exercising the path where the oneDNN
// bridge must treat a scalar-like operand as a vector of size 1.
TEST_P(ConvolutionTest, SimpleScalarTest) {
const absl::string_view outline = R"(
HloModule convolution.test
ENTRY convolution.test {
arg.0 = $dtype[1,22,22,1] parameter(0)
arg.1 = $dtype[1] parameter(1)
reshape.1 = $dtype[1,1,1,1] reshape(arg.1)
convolution.0 = $dtype[1,14,14,1] convolution(arg.0, reshape.1),
window={size=1x1 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f
tuple.0 = ($dtype[1,14,14,1]) tuple(convolution.0)
ROOT gte.0 = $dtype[1,14,14,1] get-tuple-element(tuple.0), index=0
})";

// Compares the HLO result against the reference and checks that the
// optimized HLO matches the expected oneDNN pattern (fixture helper).
RunCompareAndMatchOptimizedHlo(outline, {});
}

TEST_P(ConvolutionTest, Simple3DTest1) {
const absl::string_view outline = R"(
HloModule convolution.test
Expand Down

0 comments on commit 724919f

Please sign in to comment.