Addressed review comments
mahmoud-abuzaina committed Jan 7, 2025
1 parent 10dbb32 commit 77a39b6
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions xla/service/cpu/onednn_memory_util.cc
@@ -73,17 +73,14 @@ MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal) {
   return CreateMemrefFromShape(shape, buf);
 }
 
-StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilderBase& builder,
-                                       const llvm_ir::IrArray& ir_array) {
-  const Shape& shape = ir_array.GetShape();
-  // oneDNN handles scalar as a vector of size 1.
-  int64_t rank = shape.rank() == 0 ? 1 : shape.rank();
+void SetDimsStrides(const Shape& shape, std::vector<int64_t>& dims,
+                    std::vector<int64_t>& strides) {
+  XLA_LIGHTWEIGHT_CHECK(strides.size() > 0);
   std::vector<int64_t> scalar_shape(1, 1);
-  absl::Span<const int64_t> dims =
+  absl::Span<const int64_t> dimensions =
       shape.dimensions().size() == 0 ? scalar_shape : shape.dimensions();
-
-  std::vector<int64_t> strides(rank);
-  if (shape.dimensions().size() == 0) {
+  dims.assign(dimensions.begin(), dimensions.end());
+  if (shape.rank() == 0) {
     // Scalar case.
     strides[0] = 1;
   } else {
@@ -93,6 +90,16 @@ StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilderBase& builder,
       stride *= dims.at(i);
     }
   }
+}
+
+StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilderBase& builder,
+                                       const llvm_ir::IrArray& ir_array) {
+  const Shape& shape = ir_array.GetShape();
+  // oneDNN handles scalar as a vector of size 1.
+  int64_t rank = shape.rank() == 0 ? 1 : shape.rank();
+  std::vector<int64_t> dims;
+  std::vector<int64_t> strides(rank);
+  SetDimsStrides(shape, dims, strides);
 
   // Type of struct
   llvm::Type* i64_type = builder.getInt64Ty();
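
The extracted SetDimsStrides helper keeps the stride computation in one place: it walks the layout's minor-to-major order, giving the most minor dimension stride 1 and each more major dimension the product of all more minor extents. A minimal standalone sketch of that loop, using plain std::vector in place of XLA's Shape/Layout types (the name StridesFromMinorToMajor is hypothetical, for illustration only):

#include <cstdint>
#include <iostream>
#include <vector>

// Illustration: computes strides the same way the diff's shared loop does.
// The most minor dimension gets stride 1; each more major dimension's
// stride is the product of all more minor dimension sizes.
std::vector<int64_t> StridesFromMinorToMajor(
    const std::vector<int64_t>& dims,
    const std::vector<int64_t>& minor_to_major) {
  std::vector<int64_t> strides(dims.size());
  int64_t stride = 1;
  for (int64_t i : minor_to_major) {
    strides.at(i) = stride;
    stride *= dims.at(i);
  }
  return strides;
}

int main() {
  // A 2x3x4 tensor with row-major layout (minor_to_major = {2, 1, 0})
  // yields strides {12, 4, 1}.
  for (int64_t s : StridesFromMinorToMajor({2, 3, 4}, {2, 1, 0})) {
    std::cout << s << " ";
  }
  std::cout << "\n";
  return 0;
}
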
@@ -192,21 +199,14 @@ absl::StatusOr<dnnl::memory::desc> TransposeLastTwoDims(
 }
 
 dnnl::memory::desc ShapeToMemDesc(const Shape& shape) {
-  std::vector<int64_t> scalar_shape(1, 1);
-  auto dimensions = shape.rank() == 0 ? scalar_shape : shape.dimensions();
-  if (dimensions.empty()) {
+  // oneDNN handles scalar as a vector of size 1.
+  int64_t rank = shape.rank() == 0 ? 1 : shape.rank();
+  std::vector<int64_t> dims;
+  std::vector<int64_t> strides(rank);
+  SetDimsStrides(shape, dims, strides);
+  if (dims.empty()) {
     return dnnl::memory::desc{};
   }
-  auto dims = dnnl::memory::dims(dimensions.begin(), dimensions.end());
-  dnnl::memory::dims strides(dims.size());
-  dnnl::memory::dim stride = 1;
-  if (shape.rank() == 0) {
-    strides[0] = 1;
-  }
-  for (auto i : shape.layout().minor_to_major()) {
-    strides.at(i) = stride;
-    stride *= dims.at(i);
-  }
   auto dt = ToOneDnnDataType(static_cast<PrimitiveType>(shape.element_type()));
   return dnnl::memory::desc(dims, dt, strides);
 }
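
With the helper in place, ShapeToMemDesc just forwards the shared dims and strides to the dnnl::memory::desc(dims, data_type, strides) constructor from the public oneDNN API. A minimal usage sketch, assuming oneDNN is installed; the 2x3 row-major values mirror what SetDimsStrides would produce for such a shape:

#include <iostream>
#include "dnnl.hpp"

int main() {
  dnnl::memory::dims dims = {2, 3};
  dnnl::memory::dims strides = {3, 1};  // row-major: minor dim has stride 1
  // f32 descriptor over a strided 2x3 buffer.
  dnnl::memory::desc md(dims, dnnl::memory::data_type::f32, strides);
  std::cout << md.get_size() << "\n";  // prints 24 (2 * 3 * sizeof(float))
  return 0;
}
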