Skip to content

Commit

Permalink
Prettyprinting for Einsum and UnaryEinsum (#727)
Browse files Browse the repository at this point in the history
Had this change in an internal repo from awhile ago, but did not submit
since I was unsure if these would move to CHLO. This change can be
submitted independently to that work.

```
%0 = "stablehlo.einsum"(%arg0, %arg1) { einsum_config = "ab,bc->ac" } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32>
%1 = "stablehlo.unary_einsum"(%arg0) { einsum_config = "ab->a" } : (tensor<8x16xf32>) -> tensor<8xf32>

-->

%0 = stablehlo.einsum %arg0, %arg1, config = "ab,bc->ac" : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32>
%1 = stablehlo.unary_einsum %arg0, config = "ab->a" : (tensor<8x16xf32>) -> tensor<8xf32>
```
  • Loading branch information
GleasonK authored Dec 10, 2022
1 parent 570b112 commit 63813b3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
4 changes: 2 additions & 2 deletions docs/status.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ one of the following tracking labels.
| dynamic_reshape | no | revisit | infeasible | yes | no |
| dynamic_slice | yes | revisit | yes | yes | no |
| dynamic_update_slice | yes | yes | yes | yes | no |
| einsum | no | revisit | no | no | no |
| einsum | no | revisit | no | yes | no |
| exponential | yes | yes | yes | yes | no |
| exponential_minus_one | yes | yes | yes | yes | no |
| fft | yes | revisit | yes | yes | no |
Expand Down Expand Up @@ -152,7 +152,7 @@ one of the following tracking labels.
| transpose | yes | yes | yes | yes | yes |
| triangular_solve | yes | revisit | yes | no | no |
| tuple | yes | yes | yes | yes | no |
| unary_einsum | no | revisit | no | no | no |
| unary_einsum | no | revisit | no | yes | no |
| uniform_dequantize | no | yes* | yes* | yes | no |
| uniform_quantize | no | yes* | infeasible | yes | no |
| while | yes | revisit | yes | revisit | no |
Expand Down
8 changes: 8 additions & 0 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2240,6 +2240,10 @@ def StableHLO_EinsumOp: StableHLO_Op<"einsum", [Pure]>, BASE_EinsumOp {

// TODO(hinsu): Canonicalize to lower this client side HLO op to server
// side HLO ops.

let assemblyFormat = [{
$lhs `,` $rhs `,` `config` `=` $einsum_config attr-dict `:` functional-type(operands, results)
}];
}

def StableHLO_UnaryEinsumOp: StableHLO_Op<"unary_einsum", [Pure]>, BASE_EinsumOp {
Expand All @@ -2249,6 +2253,10 @@ def StableHLO_UnaryEinsumOp: StableHLO_Op<"unary_einsum", [Pure]>, BASE_EinsumOp
);

let results = (outs HLO_Tensor);

let assemblyFormat = [{
$operand `,` `config` `=` $einsum_config attr-dict `:` functional-type(operands, results)
}];
}

def StableHLO_FftOp: StableHLO_Op<"fft", [InferTensorType, Pure]> {
Expand Down
9 changes: 9 additions & 0 deletions stablehlo/tests/print_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,15 @@ func.func @dimension_attr(%arg0 : tensor<1x2xf32>, %arg1 : tensor<3xi32>, %arg2
"stablehlo.return"() : () -> ()
}

// CHECK-LABEL: func @op_einsum
func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8xf32> {
// CHECK: %0 = stablehlo.einsum %arg0, %arg1, config = "ab,bc->ac" : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32>
// CHECK-NEXT: %1 = stablehlo.unary_einsum %arg0, config = "ab->a" : (tensor<8x16xf32>) -> tensor<8xf32>
%0 = "stablehlo.einsum"(%arg0, %arg1) { einsum_config = "ab,bc->ac" } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32>
%1 = "stablehlo.unary_einsum"(%arg0) { einsum_config = "ab->a" } : (tensor<8x16xf32>) -> tensor<8xf32>
func.return %1 : tensor<8xf32>
}

// CHECK-LABEL: func @fft_op
func.func @fft_op(%arg0: tensor<16xcomplex<f32>>) -> tensor<16xcomplex<f32>> {
// CHECK: %0 = stablehlo.fft %arg0, type = FFT, length = [16] : (tensor<16xcomplex<f32>>) -> tensor<16xcomplex<f32>>
Expand Down

0 comments on commit 63813b3

Please sign in to comment.