Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fork Attributes and Types for VHLO #849

Merged
merged 44 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
0875f17
Begin forking attributes / types
GleasonK Dec 9, 2022
dec1393
Wrap tensor type in VHLO to control serialization and deserializaiton
GleasonK Dec 12, 2022
b6b5709
Wrap shaped types for serialization.
GleasonK Dec 14, 2022
6715343
Wrap tuple type for serilization
GleasonK Dec 14, 2022
b03d026
Add element type forks / wrappers
GleasonK Dec 15, 2022
dd844fd
Fork quantized type.
GleasonK Dec 19, 2022
57d1584
Wrap IntegertAttr, FloatAttr, StringAttr, FlatSymbolRefAttr.
GleasonK Dec 19, 2022
acd7196
Fork attributes
GleasonK Dec 20, 2022
4705eb5
Fork all attributes.
GleasonK Dec 29, 2022
85b73a2
Fork all attributes and types instead of wrapping
GleasonK Jan 4, 2023
3bcad47
Rebase and reorganize bytecode
GleasonK Jan 6, 2023
ef3e43a
Reorder type conversion alphabetically
GleasonK Jan 6, 2023
17007f8
Sort attribute and type declarations alphabetically
GleasonK Jan 6, 2023
87a01f0
Reorder bytecode encodings alphabetically
GleasonK Jan 6, 2023
143df01
Reorganize attribute conversions alphabetically
GleasonK Jan 6, 2023
5d010a6
Add missing attribute and type validation and tests
GleasonK Jan 6, 2023
f96ed60
Test cleanup
GleasonK Jan 6, 2023
3438406
fix lint error
GleasonK Jan 6, 2023
08deeb0
Rebase on main
GleasonK Jan 19, 2023
5c5418c
Addressed initial feedback
GleasonK Jan 24, 2023
8902560
Address feedback
GleasonK Jan 24, 2023
c693450
Remove VHLO loads of ShapeDialect/QuantDialect, move loading to opt t…
GleasonK Jan 24, 2023
6ed57be
Address feedback
GleasonK Jan 26, 2023
8e76ee8
Remove unused headers
GleasonK Jan 26, 2023
432b4dc
Uncomment verifier
GleasonK Jan 26, 2023
ffcdb2f
Addressed feedback - Fixed dependent dialects and dialect registration.
GleasonK Jan 31, 2023
70ae6f2
Split wrapped IntegerType into several VHLO integer types
GleasonK Jan 31, 2023
dd05a29
Don't wrap integer attr. Use upstream printers/parsers.
GleasonK Jan 31, 2023
1a70a6b
Cleanup float printing
GleasonK Jan 31, 2023
bb46635
assertFromVhlo -> verifyFromVhlo
GleasonK Jan 31, 2023
d450c76
Added verifiers to attributes and types.
GleasonK Jan 31, 2023
440bdda
Remove verify from bytecode now that is is on attrs
GleasonK Jan 31, 2023
723dc73
Addressed feedback - printTensorShape -> printShape, removed explicit…
GleasonK Feb 1, 2023
e635493
Improve validation error messages
GleasonK Feb 1, 2023
5605a76
Fix bazel build
GleasonK Feb 2, 2023
6278df5
Remove commented out code
GleasonK Feb 2, 2023
6793b06
Refactor Builtin<->VHLO type converter into its own class
GleasonK Feb 2, 2023
a50c73d
fix whitespace
GleasonK Feb 2, 2023
0ccd322
Refactor types into separate target
GleasonK Feb 3, 2023
46feb7f
Rename VhloTypeConverterBase
GleasonK Feb 3, 2023
8a8cd40
uncomment verifier
GleasonK Feb 3, 2023
1ee938a
Move VhloDialect back into VhloOps
GleasonK Feb 3, 2023
02bc6eb
Address feedback
GleasonK Feb 3, 2023
cb21681
rename include guard
GleasonK Feb 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,7 @@ cc_library(
":vhlo_enums_inc_gen",
":vhlo_op_interfaces_inc_gen",
":vhlo_ops_inc_gen",
":vhlo_type_interfaces_inc_gen",
":vhlo_types_inc_gen",
":vhlo_types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
Expand Down Expand Up @@ -416,6 +415,27 @@ td_library(
],
)

cc_library(
name = "vhlo_types",
srcs = [
"stablehlo/dialect/VhloTypes.cpp",
],
hdrs = [
"stablehlo/dialect/VhloTypes.h",
],
deps = [
":version",
":vhlo_type_interfaces_inc_gen",
":vhlo_types_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "vhlo_type_interfaces_inc_gen",
tbl_outs = [
Expand Down Expand Up @@ -491,6 +511,7 @@ cc_library(
":stablehlo_type_inference",
":version",
":vhlo_ops",
":vhlo_types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:InferTypeOpInterface",
Expand Down
28 changes: 26 additions & 2 deletions stablehlo/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,39 @@ add_mlir_dialect_library(VhloOps
PARTIAL_SOURCES_INTENDED
VhloBytecode.cpp
VhloOps.cpp
Version.cpp

DEPENDS
VhloOpsIncGen

LINK_LIBS PUBLIC
StablehloAssemblyFormat
StablehloBase
MLIRIR
GleasonK marked this conversation as resolved.
Show resolved Hide resolved
MLIRQuantDialect
MLIRShapeDialect
MLIRSupport
)

add_mlir_dialect_library(VhloTypes
PARTIAL_SOURCES_INTENDED
VhloTypes.cpp

DEPENDS
VhloOpsIncGen

LINK_LIBS PUBLIC
VhloVersion
MLIRIR
MLIRQuantDialect
MLIRShapeDialect
MLIRSupport
GleasonK marked this conversation as resolved.
Show resolved Hide resolved
MLIRTransformUtils
)

add_mlir_dialect_library(VhloVersion
PARTIAL_SOURCES_INTENDED
Version.cpp
GleasonK marked this conversation as resolved.
Show resolved Hide resolved

LINK_LIBS PUBLIC
MLIRIR
GleasonK marked this conversation as resolved.
Show resolved Hide resolved
GleasonK marked this conversation as resolved.
Show resolved Hide resolved
GleasonK marked this conversation as resolved.
Show resolved Hide resolved
MLIRSupport
)
8 changes: 7 additions & 1 deletion stablehlo/dialect/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ limitations under the License.

#include "stablehlo/dialect/Register.h"

#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/VhloOps.h"
Expand All @@ -26,7 +29,10 @@ namespace stablehlo {

void registerAllDialects(mlir::DialectRegistry &registry) {
// clang-format off
registry.insert<mlir::sparse_tensor::SparseTensorDialect>();
registry.insert<mlir::shape::ShapeDialect,
mlir::sparse_tensor::SparseTensorDialect,
mlir::tensor::TensorDialect,
mlir::quant::QuantizationDialect>();
registry.insert<mlir::chlo::ChloDialect,
mlir::stablehlo::StablehloDialect,
mlir::vhlo::VhloDialect>();
Expand Down
6 changes: 6 additions & 0 deletions stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ class Version {
/// from a StringRef of the form `#.#.#`. Returns failure if invalid string.
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current dialect version.
static Version getCurrentVersion() { return Version(0, 4, 0); }

/// Return a Version representing the minimum supported dialect version.
static Version getMinimumVersion() { return Version(0, 3, 0); }

/// Construct Version from major, minor, patch integers.
Version(int64_t major, int64_t minor, int64_t patch)
: majorMinorPatch({major, minor, patch}) {}
Expand Down
155 changes: 107 additions & 48 deletions stablehlo/dialect/VhloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

include "stablehlo/dialect/VhloBase.td"
include "stablehlo/dialect/VhloDialect.td"
include "stablehlo/dialect/VhloTypes.td"

include "mlir/IR/AttrTypeBase.td"

Expand Down Expand Up @@ -49,7 +50,7 @@ class VHLO_AttrDef<string name,
return *version;
}
mlir::vhlo::Version getMaxVersion() {
if (!strcmp("}] # maxVersion # [{", "current")) return VhloDialect::getCurrentVersion();
if (!strcmp("}] # maxVersion # [{", "current")) return Version::getCurrentVersion();
auto version = mlir::vhlo::Version::fromString("}] # maxVersion # [{");
if (failed(version)) llvm_unreachable("invalid version }] # maxVersion # [{ in }] # name # [{");
return *version;
Expand All @@ -58,15 +59,10 @@ class VHLO_AttrDef<string name,
}

//===----------------------------------------------------------------------===//
// Attributes
// VHLO Attributes
//===----------------------------------------------------------------------===//

def VHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> {
let parser = "mlir::hlo::parseDimSizes($_parser)";
let printer = "mlir::hlo::printDimSizes($_printer, $_self)";
}

def VHLO_ScatterDimensionNumbers: VHLO_AttrDef<"ScatterDimensionNumbers"> {
def VHLO_ScatterDimensionNumbersAttrV1 : VHLO_AttrDef<"ScatterDimensionNumbersV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "scatter";
let parameters = (ins
Expand All @@ -78,7 +74,7 @@ def VHLO_ScatterDimensionNumbers: VHLO_AttrDef<"ScatterDimensionNumbers"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_GatherDimensionNumbers : VHLO_AttrDef<"GatherDimensionNumbers"> {
def VHLO_GatherDimensionNumbersAttrV1 : VHLO_AttrDef<"GatherDimensionNumbersV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "gather";
let parameters = (ins
Expand All @@ -90,7 +86,7 @@ def VHLO_GatherDimensionNumbers : VHLO_AttrDef<"GatherDimensionNumbers"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_DotDimensionNumbers : VHLO_AttrDef<"DotDimensionNumbers"> {
def VHLO_DotDimensionNumbersAttrV1 : VHLO_AttrDef<"DotDimensionNumbersV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "dot";
let parameters = (ins
Expand All @@ -102,7 +98,7 @@ def VHLO_DotDimensionNumbers : VHLO_AttrDef<"DotDimensionNumbers"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_OutputOperandAlias : VHLO_AttrDef<"OutputOperandAlias", "0.4.0"> {
def VHLO_OutputOperandAliasAttrV1 : VHLO_AttrDef<"OutputOperandAliasV1", "0.4.0"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "output_operand_alias";
let parameters = (ins
Expand All @@ -113,7 +109,7 @@ def VHLO_OutputOperandAlias : VHLO_AttrDef<"OutputOperandAlias", "0.4.0"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_ArgResultAlias : VHLO_AttrDef<"ArgResultAlias"> {
def VHLO_ArgResultAliasAttrV1 : VHLO_AttrDef<"ArgResultAliasV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "result_alias";
let parameters = (ins
Expand All @@ -125,53 +121,21 @@ def VHLO_ArgResultAlias : VHLO_AttrDef<"ArgResultAlias"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_ChannelHandle : VHLO_AttrDef<"ChannelHandle"> {
def VHLO_ChannelHandleAttrV1 : VHLO_AttrDef<"ChannelHandleV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "channel_handle";
let parameters = (ins "int64_t":$handle, "int64_t":$type);
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_TypeExtensions : VHLO_AttrDef<"TypeExtensions"> {
def VHLO_TypeExtensionsAttrV1 : VHLO_AttrDef<"TypeExtensionsV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "type_extensions";
let parameters = (ins VHLO_Dims:$bounds);
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_LayoutAttr : Attr<
And<[IndexElementsAttr.predicate,
CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
== 1}]>]>,
"A 1D tensor of index type (layout)"> {
let storageType = IndexElementsAttr.storageType;
let returnType = IndexElementsAttr.returnType;
let convertFromStorage = IndexElementsAttr.convertFromStorage;
}

// An array of layout (1D tensor) attributes.
def VHLO_ArrayOfLayoutAttr : TypedArrayAttrBase<VHLO_LayoutAttr,
"Array of layout (1D tensor of index type) attributes">;

// An array of FlatSymbolRef attributes that can be used as a default valued
// attribute.
def VHLO_FlatSymbolRefArrayAttr :
TypedArrayAttrBase<FlatSymbolRefAttr, "flat symbol ref array attribute"> {
let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)";
}

def VHLO_BoolElementsAttr :
ElementsAttrBase<
And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
"constant boolean vector/tensor attribute"> {
let storageType = [{ ::mlir::DenseElementsAttr }];
let returnType = [{ ::mlir::DenseElementsAttr }];

let convertFromStorage = "$_self";
}

def VHLO_ConvDimensionNumbers : VHLO_AttrDef<"ConvDimensionNumbers"> {
def VHLO_ConvDimensionNumbersAttrV1 : VHLO_AttrDef<"ConvDimensionNumbersV1"> {
let cppNamespace = "::mlir::vhlo";
let mnemonic = "conv";
let parameters = (ins
Expand All @@ -190,7 +154,7 @@ def VHLO_ConvDimensionNumbers : VHLO_AttrDef<"ConvDimensionNumbers"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def VHLO_ConvolutionAttributes {
def VHLO_ConvolutionAttributesV1 {
dag attributes = (ins
OptionalAttr<VHLO_AnyAttr>:$window_strides,
OptionalAttr<VHLO_AnyAttr>:$padding,
Expand All @@ -204,4 +168,99 @@ def VHLO_ConvolutionAttributes {
);
}

//===----------------------------------------------------------------------===//
// Forked Attributes
//===----------------------------------------------------------------------===//

def VHLO_ArrayAttrV1 : VHLO_AttrDef<"ArrayV1"> {
let mnemonic = "array";
let parameters = (ins ArrayRefParameter<"mlir::Attribute">:$value);
GleasonK marked this conversation as resolved.
Show resolved Hide resolved
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult ArrayV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn, ArrayRef<mlir::Attribute> value) {
if (!allFromVhlo(value)) return errFn() << "expected array of VHLO attriutes";
return success();
}
}];
let assemblyFormat = "`<` custom<AttributeArray>($value) `>`";
}

def VHLO_DenseIntOrFPElementsData : AttrParameter<"::llvm::ArrayRef<char>", "Array of int"> {
// Custom allocator to copy dense elements data into MLIR Context
let allocator = "$_dst = $_allocator.copyInto($_self);";
}
def VHLO_DenseIntOrFPElementsAttrV1 : VHLO_AttrDef<"DenseIntOrFPElementsV1"> {
let mnemonic = "dense";
let parameters = (ins "::mlir::Type":$type, VHLO_DenseIntOrFPElementsData:$raw_data);
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult DenseIntOrFPElementsV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn, mlir::Type type, ArrayRef<char>) {
if (!isFromVhlo(type)) errFn() << "expected VHLO type";
return success();
}
}];
let hasCustomAssemblyFormat = 1;
}

def VHLO_FlatSymbolRefAttrV1 : VHLO_AttrDef<"FlatSymbolRefV1"> {
let mnemonic = "sym";
let parameters = (ins "::mlir::Attribute":$root_reference);
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult FlatSymbolRefV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn, mlir::Attribute rootReference) {
if (!isFromVhlo(rootReference)) return errFn() << "expected VHLO attribute";
return success();
}
}];
let assemblyFormat = "`<` $root_reference `>`";
}

def VHLO_FloatAttrV1 : VHLO_AttrDef<"FloatV1"> {
let mnemonic = "float";
let parameters = (ins "mlir::Type":$type, VHLO_APFloatV1:$value);
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult FloatV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn, mlir::Type type, APFloat value) {
if (!isFromVhlo(type)) return errFn() << "expected VHLO type";
return success();
}
}];
let assemblyFormat = "`<` $value `:` $type `>`";
}

def VHLO_IntegerAttrV1 : VHLO_AttrDef<"IntegerV1"> {
let mnemonic = "integer";
let parameters = (ins "mlir::Type":$type, "APInt":$value);
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult IntegerV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> errFn, mlir::Type type, APInt value) {
if (!isFromVhlo(type)) return errFn() << "expected VHLO type";
return success();
}
}];
let hasCustomAssemblyFormat = 1;
}

def VHLO_StringAttrV1 : VHLO_AttrDef<"StringV1"> {
let mnemonic = "string";
let parameters = (ins StringRefParameter<"">:$value);
let assemblyFormat = "`<` $value `>`";
}

def VHLO_UnitAttrV1 : VHLO_AttrDef<"UnitV1"> {
let mnemonic = "unit";
let storageType = [{ ::mlir::UnitAttr }];
let constBuilderCall = "(($0) ? $_builder.getUnitAttr() : nullptr)";
let convertFromStorage = "$_self != nullptr";
let returnType = "bool";
let defaultValue = "false";
let valueType = NoneType;
let isOptional = 1;
}

#endif // STABLEHLO_DIALECT_VHLO_ATTRS
7 changes: 7 additions & 0 deletions stablehlo/dialect/VhloBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#ifndef STABLEHLO_DIALECT_VHLO_BASE
#define STABLEHLO_DIALECT_VHLO_BASE

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
Expand All @@ -28,6 +29,12 @@ def VHLO_AnyType : AnyTypeOf<[AnyType]>;
def VHLO_AnyAttr : AnyAttrOf<[AnyAttr]>;
def VHLO_AnyRegion : Region<CPred<"true">, "any region">;

// Data types used in Attributes / Types
def VHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> {
let parser = "mlir::hlo::parseDimSizes($_parser)";
let printer = "mlir::hlo::printDimSizes($_printer, $_self)";
}

//===----------------------------------------------------------------------===//
// VHLO traits
//===----------------------------------------------------------------------===//
Expand Down
Loading