Skip to content

Commit

Permalink
Add input output aliasing to CustomCallOp
Browse files Browse the repository at this point in the history
  • Loading branch information
subhankarshah committed Oct 29, 2022
1 parent 885b8e3 commit 85210b9
Show file tree
Hide file tree
Showing 8 changed files with 407 additions and 80 deletions.
44 changes: 44 additions & 0 deletions stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,50 @@ def StableHLO_ConvDimensionNumbers : AttrDef<StableHLO_Dialect, "ConvDimensionNu
let hasCustomAssemblyFormat = 1;
}

// Attribute describing which output (tuple element) of a CustomCallOp aliases
// which operand (tuple element), mirroring XLA's input/output aliasing.
def OutputOperandAlias : AttrDef<StableHLO_Dialect, "OutputOperandAlias"> {
  let cppNamespace = "::mlir::stablehlo";
  let mnemonic = "output_operand_alias";
  let summary =
    "Attribute that models the alias relationship of output and operand of a CustomCall op";
  let description = [{
    This attribute captures the alias relationship of the output to one of the
    operands for a CustomCall op, denoted by `operand_index`. The
    `output_tuple_indices` and `operand_tuple_indices` are used to index into
    output and operand types. These index lists are empty if the corresponding
    types are not tuple types, and can be arbitrarily long in case of
    arbitrarily nested tuple types.

    See https://www.tensorflow.org/xla/aliasing.

    Example when used as an array within stablehlo.custom_call:

    ```mlir
    %0 = "stablehlo.custom_call"(%arg0, %arg1) {
      // other attributes
      output_operand_alias = [
        #stablehlo.output_operand_alias<output_tuple_indices = [0],
                                        operand_index = 0,
                                        operand_tuple_indices = [1]>
      ]
    } : (tuple<tensor<1x1xf32>, tensor<2x3xf32>>, tensor<5x5xf32>) -> tuple<tensor<2x3xf32>>

    The output and the 0th operand are both tuples. The aliasing shows the
    relationship between the 0th element in output tuple with the 1st element in
    the 0th operand. And both of them are of the same type: tensor<2x3xf32>.
    ```
  }];
  let parameters = (ins
    // Path of tuple indices into the (possibly nested) output tuple type.
    StableHLO_Dim:$outputTupleIndices,
    // Index of the aliased operand of the CustomCall op.
    "int64_t":$operandIndex,
    // Path of tuple indices into the (possibly nested) operand tuple type.
    StableHLO_Dim:$operandTupleIndices
  );
  let assemblyFormat = "`<` "
    "`output_tuple_indices` `=` `[` $outputTupleIndices `]` `,`"
    "`operand_index` `=` $operandIndex `,`"
    "`operand_tuple_indices` `=` `[` $operandTupleIndices `]`"
    "`>`";
}

def StableHLO_ArgResultAlias : AttrDef<StableHLO_Dialect, "ArgResultAlias"> {
let cppNamespace = "::mlir::stablehlo";
let mnemonic = "result_alias";
Expand Down
219 changes: 140 additions & 79 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,94 +427,155 @@ void ConstantOp::print(::mlir::OpAsmPrinter& p) {
// CustomCallOp
//===----------------------------------------------------------------------===//

// Backward-compatible builder that omits `output_operand_aliases`.
// Forwards to the full builder, passing nullptr for the alias attribute so
// existing call sites keep compiling unchanged.
void CustomCallOp::build(
    ::mlir::OpBuilder& odsBuilder, ::mlir::OperationState& odsState,
    ::mlir::TypeRange resultType, ::mlir::ValueRange operands,
    ::mlir::StringAttr callTargetName, ::mlir::BoolAttr hasSideEffect,
    ::mlir::StringAttr backendConfig,
    ::mlir::stablehlo::CustomCallApiVersionAttr apiVersion,
    ::mlir::ArrayAttr calledComputations, ::mlir::ArrayAttr operandLayouts,
    ::mlir::ArrayAttr resultLayouts) {
  CustomCallOp::build(odsBuilder, odsState, resultType, operands,
                      callTargetName, hasSideEffect, backendConfig, apiVersion,
                      calledComputations, operandLayouts, resultLayouts,
                      /*output_operand_aliases=*/nullptr);
}

// Verifies CustomCallOp invariants.
//
// NOTE(review): the scraped span interleaved pre- and post-commit diff lines
// (two copies of the layout-verification lambda); this is the reconstructed
// post-commit version of the function.
//
// Two independent checks:
//   1. Optional operand/result layout constraints must be mutually present,
//      match the operand/result counts, and each layout must be a permutation
//      of [0, rank) for ranked tensor types.
//   2. Each `output_operand_aliases` entry must reference a valid operand,
//      valid tuple paths on both sides, and identically-typed aliased parts.
LogicalResult CustomCallOp::verify() {
  // Layout checks only apply when at least one layout attribute is present.
  if (getOperandLayouts().has_value() || getResultLayouts().has_value()) {
    // Layout constraints for either both operands & results or none should be
    // specified.
    if (getOperandLayouts().has_value() != getResultLayouts().has_value())
      return emitOpError() << "Layout attributes should be specified for "
                              "either both operands and results or none.";

    // Helper function to verify types and the corresponding layouts.
    auto verifyTypesAndLayouts =
        [this](TypeRange types, mlir::ArrayAttr layouts,
               const std::string& valueName) -> LogicalResult {
      if (types.size() != layouts.size())
        return emitOpError()
               << "Number of " << valueName << "s must match the number of "
               << valueName << " layouts, " << types.size()
               << " != " << layouts.size();

      for (const auto& indexedTypeAndLayout :
           llvm::enumerate(llvm::zip(types, layouts))) {
        // Get index for more descriptive error message.
        auto index = indexedTypeAndLayout.index();

        auto type = std::get<0>(indexedTypeAndLayout.value());
        auto layout = std::get<1>(indexedTypeAndLayout.value())
                          .cast<DenseIntElementsAttr>();

        if (type.isa<TupleType>())
          return emitOpError() << "Tuple types are not fully supported with "
                                  "layout constraints yet";
        auto tensorType = type.dyn_cast<TensorType>();

        // For non-tensor types such as !stablehlo.token, the layout should be
        // empty.
        if (!tensorType) {
          if (layout.empty()) continue;
          return emitOpError()
                 << "Only tensor types can have non-empty layout: " << valueName
                 << " #" << index << " of type " << type << " has layout "
                 << layout;
        }

        // For unranked tensors, we cannot verify the compatibility with layout
        // any further.
        if (!tensorType.hasRank()) continue;

        // Layout must be a permutation of [0, N) where N is the rank of the
        // tensor type.
        std::vector<int64_t> range(tensorType.getRank());
        std::iota(range.begin(), range.end(), 0);
        if (tensorType.getRank() != layout.size() ||
            !std::is_permutation(range.begin(), range.end(), layout.begin()))
          return emitOpError()
                 << "incorrect layout " << layout << " for type " << type
                 << ", layout must be a permutation of [0, "
                 << tensorType.getRank() << ")";
      }
      return success();
    };

    // At this point both `operand_layouts` and `result_layouts` are defined.
    ArrayAttr operandLayouts = this->getOperandLayouts().value();
    ArrayAttr resultLayouts = this->getResultLayouts().value();

    // Full support for layouts for arbitrary nesting of tuples is not
    // supported yet.
    //
    // If result does not have any tuples, then i-th element of
    // `result_layouts` specifies the layout constraints on i-th result.
    //
    // For the common case of a single tuple result packing non-tuple values,
    // the i-th element of `result_layouts` specifies layout for i-th element
    // of the result tuple.
    TypeRange resultTypes;
    if (getNumResults() == 1 && getResult(0).getType().isa<TupleType>())
      resultTypes = getResult(0).getType().cast<TupleType>().getTypes();
    else
      resultTypes = getResultTypes();

    // Verify that operands and operand layouts match.
    if (failed(verifyTypesAndLayouts(getOperandTypes(), operandLayouts,
                                     "operand")))
      return failure();

    // Verify that results and result layouts match.
    if (failed(verifyTypesAndLayouts(resultTypes, resultLayouts, "result")))
      return failure();
  }

  // Check output_operand_aliases: each alias must name an in-range operand,
  // in-bounds tuple paths on both the operand and output side, and the two
  // aliased parts must have identical types.
  auto aliasArrayAttr = getOutputOperandAliases();
  for (auto attr : aliasArrayAttr) {
    auto alias = attr.cast<OutputOperandAliasAttr>();
    auto outputTupleIndices = alias.getOutputTupleIndices();
    auto operandIndex = alias.getOperandIndex();
    auto operandTupleIndices = alias.getOperandTupleIndices();

    if (operandIndex < 0 ||
        operandIndex >= static_cast<int64_t>(getInputs().size()))
      return emitOpError()
             << "expects operandIndex in the output_operand_alias attribute "
                "to be in range [0, "
             << getInputs().size() << "); got: " << operandIndex << ".";

    // Walk the tuple-index path down to the aliased part of the operand.
    Type operandPart = getOperand(operandIndex).getType();
    for (auto i : operandTupleIndices) {
      if (!operandPart.isa<TupleType>() ||
          i >= static_cast<int64_t>(operandPart.cast<TupleType>().size()) ||
          i < 0)
        return emitOpError()
               << "operand_tuple_indices in the output_operand_alias "
                  "attribute out of bounds";
      operandPart = operandPart.cast<TupleType>().getType(i);
    }
    // Multiple results are modeled as one implicit tuple for aliasing.
    // NOTE(review): getResult(0) assumes at least one result whenever an
    // alias entry is present — confirm zero-result custom calls cannot carry
    // output_operand_aliases.
    Type outputPart = getNumResults() > 1
                          ? TupleType::get(getContext(), getResultTypes())
                          : getResult(0).getType();
    for (auto i : outputTupleIndices) {
      if (!outputPart.isa<TupleType>() ||
          i >= static_cast<int64_t>(outputPart.cast<TupleType>().size()) ||
          i < 0)
        return emitOpError()
               << "output_tuple_indices in the output_operand_alias "
                  "attribute out of bounds";
      outputPart = outputPart.cast<TupleType>().getType(i);
    }
    if (operandPart != outputPart)
      return emitOpError()
             << "shapes mismatch in the output_operand_alias attribute: "
             << "operand part has type " << operandPart
             << " and output part has type " << outputPart;
  }
  return success();
}

void CustomCallOp::getEffects(
Expand Down
20 changes: 19 additions & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2012,7 +2012,12 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
$api_version,
DefaultValuedOptionalAttr<StableHLO_FlatSymbolRefArrayAttr, "{}">:$called_computations,
OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$operand_layouts,
OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$result_layouts
OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$result_layouts,
DefaultValuedOptionalAttr<
TypedArrayAttrBase<
OutputOperandAlias,
"Aliasing attribute for outputs and operands of CustomCall">,
"{}">:$output_operand_aliases
);
let results = (outs Variadic<HLO_TensorOrTokenOrTuple>);
let hasVerifier = 1;
Expand All @@ -2021,6 +2026,19 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
$call_target_name `(` $inputs `)` attr-dict
`:` functional-type(operands, results)
}];

// TODO(b/244367323): Need update all usage by adding the arg
// `output_operand_aliases`, and remove this builder after the bug fix.
let builders = [
OpBuilder<(ins
"::mlir::TypeRange":$result_type, "::mlir::ValueRange":$operands,
"::mlir::StringAttr":$call_target_name,
"::mlir::BoolAttr":$has_side_effect,
"::mlir::StringAttr":$backend_config,
"::mlir::stablehlo::CustomCallApiVersionAttr":$api_version,
"::mlir::ArrayAttr":$called_computations,
"::mlir::ArrayAttr":$operand_layouts,
"::mlir::ArrayAttr":$result_layouts)>];
}

def StableHLO_DotOp: StableHLO_Op<"dot",
Expand Down
54 changes: 54 additions & 0 deletions stablehlo/integrations/c/StablehloAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,60 @@ int64_t stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(
.getOutputSpatialDimensions()[pos];
}

//===----------------------------------------------------------------------===//
// OutputOperandAliasAttr
//===----------------------------------------------------------------------===//

// C-API constructor for OutputOperandAliasAttr: wraps the raw index buffers
// as ArrayRefs and builds the attribute in the given context.
MLIR_CAPI_EXPORTED MlirAttribute stablehloOutputOperandAliasGet(
    MlirContext ctx, intptr_t nOutputTupleIndices,
    const int64_t *outputTupleIndices, int64_t operandIndex,
    intptr_t nOperandTupleIndices, const int64_t *operandTupleIndices) {
  auto outputIndices =
      llvm::makeArrayRef(outputTupleIndices, nOutputTupleIndices);
  auto operandIndices =
      llvm::makeArrayRef(operandTupleIndices, nOperandTupleIndices);
  return wrap(mlir::stablehlo::OutputOperandAliasAttr::get(
      unwrap(ctx), outputIndices, operandIndex, operandIndices));
}

// Returns true iff `attr` is a stablehlo OutputOperandAliasAttr.
bool stablehloAttributeIsAOutputOperandAlias(MlirAttribute attr) {
  mlir::Attribute unwrapped = unwrap(attr);
  return unwrapped.isa<mlir::stablehlo::OutputOperandAliasAttr>();
}

// Number of entries in the alias attribute's output tuple-index path.
intptr_t stablehloOutputOperandAliasGetOutputTupleIndicesSize(
    MlirAttribute attr) {
  auto alias = unwrap(attr).cast<mlir::stablehlo::OutputOperandAliasAttr>();
  return alias.getOutputTupleIndices().size();
}

// Element `pos` of the alias attribute's output tuple-index path.
int64_t stablehloOutputOperandAliasGetOutputTupleIndicesElem(MlirAttribute attr,
                                                             intptr_t pos) {
  auto alias = unwrap(attr).cast<mlir::stablehlo::OutputOperandAliasAttr>();
  return alias.getOutputTupleIndices()[pos];
}

// Index of the operand this alias attribute refers to.
int64_t stablehloOutputOperandAliasGetOperandIndex(MlirAttribute attr) {
  auto alias = unwrap(attr).cast<mlir::stablehlo::OutputOperandAliasAttr>();
  return alias.getOperandIndex();
}

// Number of entries in the alias attribute's operand tuple-index path.
intptr_t stablehloOutputOperandAliasGetOperandTupleIndicesSize(
    MlirAttribute attr) {
  auto alias = unwrap(attr).cast<mlir::stablehlo::OutputOperandAliasAttr>();
  return alias.getOperandTupleIndices().size();
}

// Element `pos` of the alias attribute's operand tuple-index path.
int64_t stablehloOutputOperandAliasGetOperandTupleIndicesElem(MlirAttribute attr,
                                                              intptr_t pos) {
  auto alias = unwrap(attr).cast<mlir::stablehlo::OutputOperandAliasAttr>();
  return alias.getOperandTupleIndices()[pos];
}

//===----------------------------------------------------------------------===//
// ComparisonDirectionAttr
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 85210b9

Please sign in to comment.