Skip to content

Commit

Permalink
Addressed Eugene's feedback. Added better type pretty print for ComplexOp.
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK committed Aug 30, 2022
1 parent 8d33f62 commit a7838fb
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 39 deletions.
138 changes: 104 additions & 34 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include <array>
#include <cstdint>
#include <functional>
#include <iostream> // FIXME
#include <numeric>
#include <set>
#include <unordered_map>
Expand Down Expand Up @@ -1855,8 +1856,7 @@ LogicalResult verifyCollectivePermuteSourceTargetPairs(
}

LogicalResult CollectivePermuteOp::verify() {
return verifyCollectivePermuteSourceTargetPairs(*this,
source_target_pairs());
return verifyCollectivePermuteSourceTargetPairs(*this, source_target_pairs());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2861,6 +2861,98 @@ LogicalResult ClampOp::reifyReturnTypeShapes(
// ComplexOp
//===----------------------------------------------------------------------===//

namespace {
// Utility function, used by printSelectOpType and
// printSameOperandsAndResultType. Given a FunctionType, assign the types
// to operands and results, erroring if any mismatch in number of operands
// or results occurs.
// Shared parsing helper, used by printSelectOpType and
// printSameOperandsAndResultType. Distributes a parsed FunctionType's input
// types onto the supplied operand type slots and its single output type onto
// `result`, emitting a parser error anchored at `loc` on any arity mismatch.
ParseResult assignFromFunctionType(OpAsmParser& parser, llvm::SMLoc loc,
                                   ArrayRef<Type*> operands, Type& result,
                                   FunctionType& fnType) {
  // Caller guarantees a FunctionType was actually parsed.
  assert(fnType);
  auto inputTypes = fnType.getInputs();
  if (inputTypes.size() != operands.size()) {
    return parser.emitError(loc)
           << operands.size() << " operands present, but expected "
           << inputTypes.size();
  }

  // Copy each function input type into its corresponding operand slot.
  for (size_t i = 0, e = operands.size(); i != e; ++i) {
    *operands[i] = inputTypes[i];
  }

  // Exactly one output is required; it becomes the op's result type.
  auto resultTypes = fnType.getResults();
  if (resultTypes.size() != 1) {
    return parser.emitError(loc, "expected single output");
  }
  result = resultTypes[0];

  return success();
}

// getInferredComplexType takes a complex tensor type and returns a
// type that maintains the shape, but removes the complex type for the
// underlying data type
// Ex: tensor<4xcomplex<f32>> --> tensor<4xf32>
// Given a tensor-of-complex type, return the tensor type with the same shape
// whose element type is the complex type's underlying element type.
// Ex: tensor<4xcomplex<f32>> --> tensor<4xf32>
Type getInferredComplexType(Type result) {
  auto tensorTy = result.dyn_cast<TensorType>();
  // Precondition: `result` must be a tensor of complex values.
  assert(tensorTy && tensorTy.getElementType().isa<ComplexType>());
  auto complexTy = tensorTy.getElementType().cast<ComplexType>();
  return hlo::getSameShapeTensorType(tensorTy, complexTy.getElementType());
}

// ComplexOpType - only print result type if the inferred complex type
// matches all operand types.
//
// Inferring operand types for complex ops:
// %0 = mhlo.complex %1, %2 : tensor<4xcomplex<f32>>
// %0 : tensor<4xcomplex<f32>>
// %1 : tensor<4xf32>
// %2 : tensor<4xf32>
// ComplexOpType - print only the result type when both operand types equal
// the type inferred from the result; otherwise fall back to the full
// functional type.
//
// Inferring operand types for complex ops:
//  %0 = mhlo.complex %1, %2 : tensor<4xcomplex<f32>>
//    %0 : tensor<4xcomplex<f32>>
//    %1 : tensor<4xf32>
//    %2 : tensor<4xf32>
void printComplexOpType(OpAsmPrinter& p, Operation* op, Type lhs, Type rhs,
                        Type result) {
  Type expectedOperandTy = getInferredComplexType(result);
  bool operandsMatch = (lhs == expectedOperandTy) && (rhs == expectedOperandTy);
  if (operandsMatch) {
    p.printType(result);
  } else {
    p.printFunctionalType(op);
  }
}

// Parses the custom type syntax for ComplexOp. Accepts either a single
// tensor-of-complex type (operand types are inferred from it) or a full
// functional type (when operand types do not match the inferred type).
// Emits a parser error at the type's location on malformed input.
ParseResult parseComplexOpType(OpAsmParser& parser, Type& lhs, Type& rhs,
                               Type& result) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  Type type;
  if (failed(parser.parseType(type))) {
    return failure();
  }

  // A FunctionType means the operand types did not all match the inferred
  // type, so every type was spelled out explicitly.
  if (auto fnType = type.dyn_cast<FunctionType>()) {
    return assignFromFunctionType(parser, loc, {&lhs, &rhs}, result, fnType);
  }

  // Otherwise the single printed type is the result type and must be a
  // tensor of complex. Query the TensorType once instead of the original
  // isa<> + dyn_cast<> double lookup.
  auto tensorTy = type.dyn_cast<TensorType>();
  if (!tensorTy || !tensorTy.getElementType().isa<ComplexType>()) {
    return parser.emitError(loc, "expected tensor with complex element type");
  }

  // Both operands share the de-complexed type inferred from the result.
  lhs = rhs = getInferredComplexType(type);
  result = type;
  return success();
}

} // namespace

LogicalResult ComplexOp::inferReturnTypes(
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
Expand Down Expand Up @@ -4165,30 +4257,6 @@ LogicalResult RngOp::reifyReturnTypeShapes(
//===----------------------------------------------------------------------===//

namespace {
// Utility function, used by printSelectOpType and
// printSameOperandsAndResultType. Given a parsed FunctionType, assigns its
// input types to the supplied operand type slots and its single output type
// to `result`, erroring (at `loc`) on any mismatch in the number of operands
// or results.
ParseResult assignFromFunctionType(OpAsmParser& parser, llvm::SMLoc loc,
ArrayRef<Type*> operands, Type& result,
FunctionType& fnType) {
// Caller must have already verified that a FunctionType was parsed.
assert(fnType);
if (fnType.getInputs().size() != operands.size()) {
return parser.emitError(loc)
<< operands.size() << " operands present, but expected "
<< fnType.getInputs().size();
}
// Exactly one result is required.
if (fnType.getResults().size() != 1) {
return parser.emitError(loc, "expected single output");
}

// Set operand types to function input types
for (auto [operand, input] : llvm::zip(operands, fnType.getInputs())) {
*operand = input;
}
// Set result type to the function's single output type.
result = fnType.getResults().front();
return success();
}
} // namespace

void printSelectOpType(OpAsmPrinter& p, Operation* op, Type pred, Type onTrue,
Type onFalse, Type result) {
// Print functional type if true/false branches don't match return type.
Expand All @@ -4212,23 +4280,28 @@ ParseResult parseSelectOpType(OpAsmParser& parser, Type& pred, Type& onTrue,

// Error handling for invalid types
// Fail if not two types, or single functional type
if (types.size() != 2 &&
(types.size() != 1 || !types.front().dyn_cast<FunctionType>())) {
bool isValidType = (types.size() == 2 ||
(types.size() == 1 && types[0].isa<FunctionType>()));
if (!isValidType) {
return parser.emitError(loc,
"expected functional type or list of two types");
}

// stablehlo.select %0, %1 : <pred_type>, <op_and_result_type>
if (types.size() == 2) {
pred = types.front();
onTrue = onFalse = result = types.back();
pred = types[0];
onTrue = onFalse = result = types[1];
return success();
}

FunctionType fnType = types.front().dyn_cast<FunctionType>();
// stablehlo.select %0, %1 : (<op_types> ...) -> <result_type>
auto fnType = types[0].dyn_cast<FunctionType>();
return assignFromFunctionType(parser, loc, {&pred, &onTrue, &onFalse}, result,
fnType);
}

} // namespace

LogicalResult SelectOp::verify() {
// The operands 'on_true' and 'on_false' should have compatible types, i.e.,
// (a) have the same element type, and
Expand Down Expand Up @@ -5660,7 +5733,6 @@ void printSameOperandsAndResultTypeImpl(OpAsmPrinter& p, Operation* op,
TypeRange operands, Type result) {
// Handle zero operand types `() -> a` prints `a`
if (operands.empty()) {
// TODO(gleasonk): Unit test these lines once after_all is converted, with a
// call that has no operands and single output.
p.printType(result);
return;
Expand Down Expand Up @@ -5721,8 +5793,6 @@ ParseResult parseSameOperandsAndResultType(OpAsmParser& parser,
*typesRef.back());
}

// The following implementation is for SameOperandsAndResultType with variadic
// input.
void printVariadicSameOperandsAndResultType(OpAsmPrinter& p, Operation* op,
OperandRange operands,
TypeRange opTypes, Type result) {
Expand Down
7 changes: 5 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -682,13 +682,16 @@ def StableHLO_ComplexOp: StableHLO_BinaryElementwiseOp<"complex", [NoSideEffect,
Example:

```mlir
%0 = stablehlo.complex %arg0, %arg0 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = stablehlo.complex %arg0, %arg0 : tensor<4xcomplex<f32>>
```
}];
let arguments = (ins HLO_Fp32Or64Tensor:$lhs, HLO_Fp32Or64Tensor:$rhs);
let results = (outs HLO_ComplexTensor:$result);

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
let assemblyFormat = [{
operands attr-dict
`:` custom<ComplexOpType>(type($lhs), type($rhs), type($result))
}];
}

def StableHLO_DivOp : StableHLO_BinaryElementwiseOp<"divide",
Expand Down
4 changes: 3 additions & 1 deletion stablehlo/tests/print_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func.func @type_convert_ops(%arg0 : tensor<2xf32>) -> () {
func.func @no_attr_ops(%arg0 : tensor<4xf32>, %arg1 : !stablehlo.token,
%arg2 : tensor<4xi32>, %arg3 : index) -> !stablehlo.token {
// CHECK-NEXT: %0 = stablehlo.clamp %arg0, %arg0, %arg0 : tensor<4xf32>
// CHECK-NEXT: %1 = stablehlo.complex %arg0, %arg0 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK-NEXT: %1 = stablehlo.complex %arg0, %arg0 : tensor<4xcomplex<f32>>
// CHECK-NEXT: %2 = stablehlo.compute_reshape_shape %arg3, %arg2 : (index, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %3 = stablehlo.uniform_quantize %arg0 : (tensor<4xf32>) -> tensor<4x!quant.uniform<u8:f32, 3.400000e+01:16>>
// CHECK-NEXT: %4 = stablehlo.uniform_dequantize %3 : (tensor<4x!quant.uniform<u8:f32, 3.400000e+01:16>>) -> tensor<4xf32>
Expand Down Expand Up @@ -217,11 +217,13 @@ func.func @encodings(%arg0: tensor<10x20xf32, #CSR>,
// CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
// CHECK-NEXT: %2 = stablehlo.abs %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<10x20xf32>
// CHECK-NEXT: %3 = stablehlo.abs %arg0 : tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>
// CHECK-NEXT: %4 = stablehlo.complex %arg0, %arg0 : (tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>, tensor<10x20xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<10x20xcomplex<f32>>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<10x20xf32, #CSR>,
tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32>
%1 = "stablehlo.add"(%arg1, %arg1) : (tensor<10x20xf32, #DCSR>,
tensor<10x20xf32, #DCSR>) -> tensor<10x20xf32, #DCSR>
%2 = "stablehlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32>
%3 = "stablehlo.abs"(%arg0) : (tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, #CSR>
%4 = "stablehlo.complex"(%arg0, %arg0) : (tensor<10x20xf32, #CSR>, tensor<10x20xf32, #CSR>) -> tensor<10x20xcomplex<f32>>
func.return %0 : tensor<10x20xf32>
}
28 changes: 26 additions & 2 deletions stablehlo/tests/print_types_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ func.func @binary_eltwise_multiple_out(%arg0: tensor<?x?xf64>,

// -----

func.func @complex_type_not_type(%arg0: tensor<1xf64>) -> tensor<1xf64> {
// expected-error @+1 {{expected non-function type}}
%0 = stablehlo.complex %arg0, %arg0 : %arg0
func.return %0 : tensor<1xf64>
}

// -----

func.func @complex_type_not_tensor(%arg0: tensor<1xf64>) -> () {
// expected-error @+1 {{custom op 'stablehlo.complex' expected tensor with complex element type}}
%0 = stablehlo.complex %arg0, %arg0 : complex<f64>
func.return
}

// -----

func.func @complex_type_not_complex(%arg0: tensor<1xf64>) -> () {
// expected-error @+1 {{custom op 'stablehlo.complex' expected tensor with complex element type}}
%0 = stablehlo.complex %arg0, %arg0 : tensor<1xf64>
func.return
}

// -----

func.func @select_type_wrong_type(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> () {
// expected-error @+1 {{custom op 'stablehlo.select' expected functional type or list of two types}}
%0 = stablehlo.select %arg0, %arg1, %arg1 : tensor<2x3xi1>
Expand Down Expand Up @@ -88,7 +112,7 @@ func.func @pairwise_count_mismatch(%arg0: tensor<1xf64>) -> tensor<1xf64> {
// -----

func.func @pairwise_type_not_list(%arg0: tensor<1xf64>) -> tensor<1xf64> {
// expected-error @+2 {{xpected non-function type}}
// expected-error @+2 {{expected non-function type}}
// expected-error @+1 {{custom op 'stablehlo.optimization_barrier' expected type list}}
%0 = stablehlo.optimization_barrier %arg0, %arg0 : %arg0
func.return %0 : tensor<1xf64>
Expand All @@ -97,7 +121,7 @@ func.func @pairwise_type_not_list(%arg0: tensor<1xf64>) -> tensor<1xf64> {
// -----

func.func @one_result_type(%arg0: tensor<1xf64>) -> tensor<1xf64> {
// expected-error @+1 {{xpected non-function type}}
// expected-error @+1 {{expected non-function type}}
%0 = stablehlo.abs %arg0 : %arg0
func.return %0 : tensor<1xf64>
}

0 comments on commit a7838fb

Please sign in to comment.