Skip to content

Commit

Permalink
Reshape multi-dimensional constants to 1d.
Browse files Browse the repository at this point in the history
The LLVM lowering doesn't support arbitrary shapes.

PiperOrigin-RevId: 633203497
  • Loading branch information
jreiffers authored and copybara-github committed May 13, 2024
1 parent 468cdcc commit 9f9d7f6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
7 changes: 7 additions & 0 deletions xla/service/gpu/fusions/mlir/lower_tensors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,13 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value,
mlir::ModuleOp module, bool is_constant,
int addr_space,
mlir::ImplicitLocOpBuilder& b) {
if (auto elements = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(value)) {
// The lowering to LLVM only works for 1d tensors or those with trailing
// unit dimensions.
value = elements.reshape(mlir::RankedTensorType::get(
{elements.getNumElements()}, elements.getElementType()));
}

Type element_type = shaped_ty.getElementType();
// Needed to support complex element type.
mlir::LLVMTypeConverter converter(b.getContext());
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ module {
return %0 : f32
}
}
// CHECK: llvm.mlir.global private constant @global_cst_0(dense<[
// CHECK-SAME: [1.000000e+00], [2.000000e+00]]> : tensor<2x1xf32>) {addr_space = 0 : i32} : !llvm.array<2 x f32>
// CHECK: llvm.mlir.global private constant @global_cst_0(dense<
// CHECK-SAME: [1.000000e+00, 2.000000e+00]> : tensor<2xf32>) {addr_space = 0 : i32} : !llvm.array<2 x f32>
// CHECK: @extract_from_constant
// CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @global_cst_0 : !llvm.ptr
// CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ADDR_OF]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
Expand Down

0 comments on commit 9f9d7f6

Please sign in to comment.