Skip to content

Commit

Permalink
Fork Attributes and Types for VHLO (openxla#849)
Browse files Browse the repository at this point in the history
- Change `CustomCallApiVersion` to an enum instead of an integer value.
+ To avoid transforming _all_ integers, the transformation from
StableHLO --> VHLO special cases this conversion
- Attributes
+ Forked Attributes: `IntegerAttr, StringAttr, UnitAttr, ArrayAttr`,
`DenseIntOrFPElementsV1Attr, FlatSymbolRefV1Attr, FloatV1Attr`
- Types
+ Forked Types: `RankedTensorType, UnrankedTensorType, TupleType,
WitnessType`, `BFloat16V1Type, Float16V1Type, Float32V1Type,
Float64V1Type, IndexV1Type, ComplexV1Type, IntegerV1Type,
UniformQuantizedV1Type`
- Bytecode implementations forked from
[BuiltinDialectBytecode.cpp](https://github.com/llvm/llvm-project/blob/c48e0cf03a50bb8a2043ac4bb5e9a83ff135247a/mlir/lib/IR/BuiltinDialectBytecode.cpp)
  • Loading branch information
GleasonK committed Feb 10, 2023
1 parent 7177166 commit 70afa36
Show file tree
Hide file tree
Showing 24 changed files with 2,182 additions and 774 deletions.
25 changes: 23 additions & 2 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,7 @@ cc_library(
":vhlo_enums_inc_gen",
":vhlo_op_interfaces_inc_gen",
":vhlo_ops_inc_gen",
":vhlo_type_interfaces_inc_gen",
":vhlo_types_inc_gen",
":vhlo_types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
Expand Down Expand Up @@ -416,6 +415,27 @@ td_library(
],
)

cc_library(
name = "vhlo_types",
srcs = [
"stablehlo/dialect/VhloTypes.cpp",
],
hdrs = [
"stablehlo/dialect/VhloTypes.h",
],
deps = [
":version",
":vhlo_type_interfaces_inc_gen",
":vhlo_types_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "vhlo_type_interfaces_inc_gen",
tbl_outs = [
Expand Down Expand Up @@ -491,6 +511,7 @@ cc_library(
":stablehlo_type_inference",
":version",
":vhlo_ops",
":vhlo_types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:InferTypeOpInterface",
Expand Down
28 changes: 26 additions & 2 deletions stablehlo/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,39 @@ add_mlir_dialect_library(VhloOps
PARTIAL_SOURCES_INTENDED
VhloBytecode.cpp
VhloOps.cpp
Version.cpp

DEPENDS
VhloOpsIncGen

LINK_LIBS PUBLIC
StablehloAssemblyFormat
StablehloBase
MLIRIR
MLIRQuantDialect
MLIRShapeDialect
MLIRSupport
)

add_mlir_dialect_library(VhloTypes
PARTIAL_SOURCES_INTENDED
VhloTypes.cpp

DEPENDS
VhloOpsIncGen

LINK_LIBS PUBLIC
VhloVersion
MLIRIR
MLIRQuantDialect
MLIRShapeDialect
MLIRSupport
MLIRTransformUtils
)

add_mlir_dialect_library(VhloVersion
PARTIAL_SOURCES_INTENDED
Version.cpp

LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
)
8 changes: 7 additions & 1 deletion stablehlo/dialect/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ limitations under the License.

#include "stablehlo/dialect/Register.h"

#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/VhloOps.h"
Expand All @@ -26,7 +29,10 @@ namespace stablehlo {

void registerAllDialects(mlir::DialectRegistry &registry) {
// clang-format off
registry.insert<mlir::sparse_tensor::SparseTensorDialect>();
registry.insert<mlir::shape::ShapeDialect,
mlir::sparse_tensor::SparseTensorDialect,
mlir::tensor::TensorDialect,
mlir::quant::QuantizationDialect>();
registry.insert<mlir::chlo::ChloDialect,
mlir::stablehlo::StablehloDialect,
mlir::vhlo::VhloDialect>();
Expand Down
6 changes: 6 additions & 0 deletions stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ class Version {
/// from a StringRef of the form `#.#.#`. Returns failure if invalid string.
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current dialect version.
static Version getCurrentVersion() { return Version(0, 4, 0); }

/// Return a Version representing the minimum supported dialect version.
static Version getMinimumVersion() { return Version(0, 3, 0); }

/// Construct Version from major, minor, patch integers.
Version(int64_t major, int64_t minor, int64_t patch)
: majorMinorPatch({major, minor, patch}) {}
Expand Down
155 changes: 107 additions & 48 deletions stablehlo/dialect/VhloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

include "stablehlo/dialect/VhloBase.td"
include "stablehlo/dialect/VhloDialect.td"
include "stablehlo/dialect/VhloTypes.td"

include "mlir/IR/AttrTypeBase.td"

Expand Down Expand Up @@ -49,7 +50,7 @@ class VHLO_AttrDef<string name,
return *version;
}
mlir::vhlo::Version getMaxVersion() {
if (!strcmp("}] # maxVersion # [{", "current")) return VhloDialect::getCurrentVersion();
if (!strcmp("}] # maxVersion # [{", "current")) return Version::getCurrentVersion();
auto version = mlir::vhlo::Version::fromString("}] # maxVersion # [{");
if (failed(version)) llvm_unreachable("invalid version }] # maxVersion # [{ in }] # name # [{");
return *version;
Expand All @@ -58,15 +59,10 @@ class VHLO_AttrDef<string name,
}

//===----------------------------------------------------------------------===//
// Attributes
// VHLO Attributes
//===----------------------------------------------------------------------===//

def VHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> {
let parser = "mlir::hlo::parseDimSizes($_parser)";
let printer = "mlir::hlo::printDimSizes($_printer, $_self)";
}

def VHLO_ScatterDimensionNumbers: VHLO_AttrDef<"ScatterDimensionNumbers"> {
def VHLO_ScatterDimensionNumbersAttrV1 : VHLO_AttrDef<"ScatterDimensionNumbersV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "scatter";
let parameters = (ins
Expand All @@ -78,7 +74,7 @@ def VHLO_ScatterDimensionNumbers: VHLO_AttrDef<"ScatterDimensionNumbers"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_GatherDimensionNumbers : VHLO_AttrDef<"GatherDimensionNumbers"> {
def VHLO_GatherDimensionNumbersAttrV1 : VHLO_AttrDef<"GatherDimensionNumbersV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "gather";
let parameters = (ins
Expand All @@ -90,7 +86,7 @@ def VHLO_GatherDimensionNumbers : VHLO_AttrDef<"GatherDimensionNumbers"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_DotDimensionNumbers : VHLO_AttrDef<"DotDimensionNumbers"> {
def VHLO_DotDimensionNumbersAttrV1 : VHLO_AttrDef<"DotDimensionNumbersV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "dot";
let parameters = (ins
Expand All @@ -102,7 +98,7 @@ def VHLO_DotDimensionNumbers : VHLO_AttrDef<"DotDimensionNumbers"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_OutputOperandAlias : VHLO_AttrDef<"OutputOperandAlias", "0.4.0"> {
def VHLO_OutputOperandAliasAttrV1 : VHLO_AttrDef<"OutputOperandAliasV1", "0.4.0"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "output_operand_alias";
let parameters = (ins
Expand All @@ -113,7 +109,7 @@ def VHLO_OutputOperandAlias : VHLO_AttrDef<"OutputOperandAlias", "0.4.0"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_ArgResultAlias : VHLO_AttrDef<"ArgResultAlias"> {
def VHLO_ArgResultAliasAttrV1 : VHLO_AttrDef<"ArgResultAliasV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "result_alias";
let parameters = (ins
Expand All @@ -125,53 +121,21 @@ def VHLO_ArgResultAlias : VHLO_AttrDef<"ArgResultAlias"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_ChannelHandle : VHLO_AttrDef<"ChannelHandle"> {
def VHLO_ChannelHandleAttrV1 : VHLO_AttrDef<"ChannelHandleV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "channel_handle";
let parameters = (ins "int64_t":$handle, "int64_t":$type);
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_TypeExtensions : VHLO_AttrDef<"TypeExtensions"> {
def VHLO_TypeExtensionsAttrV1 : VHLO_AttrDef<"TypeExtensionsV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "type_extensions";
let parameters = (ins VHLO_Dims:$bounds);
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_LayoutAttr : Attr<
And<[IndexElementsAttr.predicate,
CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
== 1}]>]>,
"A 1D tensor of index type (layout)"> {
let storageType = IndexElementsAttr.storageType;
let returnType = IndexElementsAttr.returnType;
let convertFromStorage = IndexElementsAttr.convertFromStorage;
}

// An array of layout (1D tensor) attributes.
def VHLO_ArrayOfLayoutAttr : TypedArrayAttrBase<VHLO_LayoutAttr,
"Array of layout (1D tensor of index type) attributes">;

// An array of FlatSymbolRef attributes that can be used as a default valued
// attribute.
def VHLO_FlatSymbolRefArrayAttr :
TypedArrayAttrBase<FlatSymbolRefAttr, "flat symbol ref array attribute"> {
let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)";
}

def VHLO_BoolElementsAttr :
ElementsAttrBase<
And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
"constant boolean vector/tensor attribute"> {
let storageType = [{ ::mlir::DenseElementsAttr }];
let returnType = [{ ::mlir::DenseElementsAttr }];

let convertFromStorage = "$_self";
}

def VHLO_ConvDimensionNumbers : VHLO_AttrDef<"ConvDimensionNumbers"> {
def VHLO_ConvDimensionNumbersAttrV1 : VHLO_AttrDef<"ConvDimensionNumbersV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "conv";
let parameters = (ins
Expand All @@ -190,7 +154,7 @@ def VHLO_ConvDimensionNumbers : VHLO_AttrDef<"ConvDimensionNumbers"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_ConvolutionAttributes {
def VHLO_ConvolutionAttributesV1 {
dag attributes = (ins
OptionalAttr<VHLO_AnyAttr>:$window_strides,
OptionalAttr<VHLO_AnyAttr>:$padding,
Expand All @@ -204,4 +168,99 @@ def VHLO_ConvolutionAttributes {
);
}

//===----------------------------------------------------------------------===//
// Forked Attributes
//===----------------------------------------------------------------------===//

def VHLO_ArrayAttrV1 : VHLO_AttrDef<"ArrayV1"> {
let mnemonic = "array";
let parameters = (ins ArrayRefParameter<"mlir::Attribute">:$value);
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult ArrayV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn, ArrayRef<mlir::Attribute> value) {
if (!allFromVhlo(value)) return errFn() << "expected array of VHLO attriutes";
return success();
}
}];
let assemblyFormat = "`<` custom<AttributeArray>($value) `>`";
}

def VHLO_DenseIntOrFPElementsData : AttrParameter<"::llvm::ArrayRef<char>", "Array of int"> {
// Custom allocator to copy dense elements data into MLIR Context
let allocator = "$_dst = $_allocator.copyInto($_self);";
}
def VHLO_DenseIntOrFPElementsAttrV1 : VHLO_AttrDef<"DenseIntOrFPElementsV1"> {
let mnemonic = "dense";
let parameters = (ins "::mlir::Type":$type, VHLO_DenseIntOrFPElementsData:$raw_data);
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult DenseIntOrFPElementsV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn, mlir::Type type, ArrayRef<char>) {
if (!isFromVhlo(type)) errFn() << "expected VHLO type";
return success();
}
}];
let hasCustomAssemblyFormat = 1;
}

def VHLO_FlatSymbolRefAttrV1 : VHLO_AttrDef<"FlatSymbolRefV1"> {
let mnemonic = "sym";
let parameters = (ins "::mlir::Attribute":$root_reference);
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult FlatSymbolRefV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn, mlir::Attribute rootReference) {
if (!isFromVhlo(rootReference)) return errFn() << "expected VHLO attribute";
return success();
}
}];
let assemblyFormat = "`<` $root_reference `>`";
}

def VHLO_FloatAttrV1 : VHLO_AttrDef<"FloatV1"> {
let mnemonic = "float";
let parameters = (ins "mlir::Type":$type, VHLO_APFloatV1:$value);
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult FloatV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn, mlir::Type type, APFloat value) {
if (!isFromVhlo(type)) return errFn() << "expected VHLO type";
return success();
}
}];
let assemblyFormat = "`<` $value `:` $type `>`";
}

def VHLO_IntegerAttrV1 : VHLO_AttrDef<"IntegerV1"> {
let mnemonic = "integer";
let parameters = (ins "mlir::Type":$type, "APInt":$value);
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult IntegerV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn, mlir::Type type, APInt value) {
if (!isFromVhlo(type)) return errFn() << "expected VHLO type";
return success();
}
}];
let hasCustomAssemblyFormat = 1;
}

def VHLO_StringAttrV1 : VHLO_AttrDef<"StringV1"> {
let mnemonic = "string";
let parameters = (ins StringRefParameter<"">:$value);
let assemblyFormat = "`<` $value `>`";
}

def VHLO_UnitAttrV1 : VHLO_AttrDef<"UnitV1"> {
let mnemonic = "unit";
let storageType = [{ ::mlir::UnitAttr }];
let constBuilderCall = "(($0) ? $_builder.getUnitAttr() : nullptr)";
let convertFromStorage = "$_self != nullptr";
let returnType = "bool";
let defaultValue = "false";
let valueType = NoneType;
let isOptional = 1;
}

#endif // STABLEHLO_DIALECT_VHLO_ATTRS
7 changes: 7 additions & 0 deletions stablehlo/dialect/VhloBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#ifndef STABLEHLO_DIALECT_VHLO_BASE
#define STABLEHLO_DIALECT_VHLO_BASE

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
Expand All @@ -28,6 +29,12 @@ def VHLO_AnyType : AnyTypeOf<[AnyType]>;
def VHLO_AnyAttr : AnyAttrOf<[AnyAttr]>;
def VHLO_AnyRegion : Region<CPred<"true">, "any region">;

// Data types used in Attributes / Types
def VHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> {
let parser = "mlir::hlo::parseDimSizes($_parser)";
let printer = "mlir::hlo::printDimSizes($_printer, $_self)";
}

//===----------------------------------------------------------------------===//
// VHLO traits
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 70afa36

Please sign in to comment.