Integrate StableHLO at openxla/stablehlo@c4a2b74
PiperOrigin-RevId: 562067154
GleasonK authored and TensorFlow MLIR Team committed Sep 1, 2023
1 parent 85801cc commit 6a50444
Showing 3,910 changed files with 11,997 additions and 4,029 deletions.
4 changes: 2 additions & 2 deletions stablehlo/WORKSPACE.bazel
@@ -26,9 +26,9 @@ http_archive(
],
)

LLVM_COMMIT = "3823395c9d715a1ed993eb13fc2bec97372f5655"
LLVM_COMMIT = "f0f395e00e2ec3f1f20ca9021d1554fde73d56c9"

LLVM_SHA256 = "6ee059f385900bb7b0b6c5b6e7d9145cbf216090b03a58c0f96628e6e88b4911"
LLVM_SHA256 = "4f59f9ff83edc7be5d762682b2da960e40cfba1087f63d3934a6975aede77a27"

http_archive(
name = "llvm-raw",
101 changes: 71 additions & 30 deletions stablehlo/docs/spec.md
@@ -198,11 +198,12 @@ have quantized element types, instead of regular element types.
In quantized tensors, quantization can be **per-tensor**, meaning, having
one `scale` and `zero_point` for the entire tensor or can be **per-axis**,
meaning, having multiple `scales` and `zero_points`, one pair per slice of
-a particular dimension `quantized_dimension`. More formally, in a tensor `t` of
-with per-axis quantization, there are `dim(t, quantized_dimension)` slices
-of the `quantized_dimension`: `t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]`, etc.
-All elements in the `i`th slice use `scales[i]` and `zero_points[i]` as their
-quantization parameters. Quantized tensor types have the following constraints:
+a particular dimension `quantization_dimension`. More formally, in a tensor `t`
+with per-axis quantization, there are `dim(t, quantization_dimension)` slices
+of the `quantization_dimension`: `t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]`,
+etc. All elements in the `i`th slice use `scales[i]` and `zero_points[i]` as
+their quantization parameters. Quantized tensor types have the following
+constraints:

* For per-tensor quantization:
* No additional constraints.
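To make the per-axis case concrete, a minimal dequantization sketch (NumPy, made-up parameters; `quantization_dimension = 1`, so each column of `q` has its own `(scale, zero_point)` pair):

```python
import numpy as np

scales = np.array([0.1, 0.5, 2.0], dtype=np.float32)   # one per slice of axis 1
zero_points = np.array([0, 10, -5], dtype=np.int32)

q = np.array([[20, 14, 3],
              [25, 10, 0]], dtype=np.int8)             # stored values, shape (2, 3)

# Slice i along the quantization dimension uses scales[i] and zero_points[i];
# broadcasting over the trailing axis applies exactly that pairing here.
f = (q.astype(np.int32) - zero_points) * scales
# f == [[ 2.0, 2.0, 16.0],
#       [ 2.5, 0.0, 10.0]]
```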
@@ -1364,26 +1365,37 @@ in the `operand` tensor and produces a `result` tensor. More formally,

#### Inputs

-| Label | Name                   | Type                                         | Constraints   |
-|-------|------------------------|----------------------------------------------|---------------|
-| (I1)  | `operand`              | tensor or per-tensor quantized tensor        | (C1-C2), (C5) |
-| (I2)  | `broadcast_dimensions` | 1-dimensional tensor constant of type `si64` | (C2-C5)       |
+| Label | Name                   | Type                                         | Constraints      |
+|-------|------------------------|----------------------------------------------|------------------|
+| (I1)  | `operand`              | tensor or quantized tensor                   | (C1-C2), (C5-C6) |
+| (I2)  | `broadcast_dimensions` | 1-dimensional tensor constant of type `si64` | (C2-C6)          |

#### Outputs

-| Name     | Type                                  | Constraints      |
-|----------|---------------------------------------|------------------|
-| `result` | tensor or per-tensor quantized tensor | (C1), (C3), (C5) |
+| Name     | Type                       | Constraints         |
+|----------|----------------------------|---------------------|
+| `result` | tensor or quantized tensor | (C1), (C3), (C5-C6) |

#### Constraints

-* (C1) `element_type(operand) = element_type(result)`.
+* (C1) `element_type(result)` is given by:
+  * `element_type(operand)`, if `!is_per_axis_quantized(operand)`.
+  * `element_type(operand)` except that `quantization_dimension(operand)`,
+    `scales(operand)`, and `zero_points(operand)` may differ from
+    `quantization_dimension(result)`, `scales(result)`, and `zero_points(result)`
+    resp., otherwise.
* (C2) `size(broadcast_dimensions) = rank(operand)`.
* (C3) `0 <= broadcast_dimensions < rank(result)`.
* (C4) `is_unique(broadcast_dimensions)`.
* (C5) For all `d` in `axes(operand)`:
* `dim(operand, d) = 1` or
* `dim(operand, d) = dim(result, broadcast_dimensions[d])`.
+* (C6) If `is_per_axis_quantized(result)`:
+  * `quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]`.
+  * If `dim(operand, quantization_dimension(operand)) = 1`, then
+    `scales(result)[i] = scales(operand)[0] and zero_points(result)[i] =
+    zero_points(operand)[0] for i in
+    range(dim(result, quantization_dimension(result)))`.
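
A sketch of what (C6) implies for the result's parameters (plain Python, hypothetical helper name, not the StableHLO implementation):

```python
def infer_result_quant_params(broadcast_dimensions, operand_qdim,
                              operand_scales, operand_zero_points,
                              result_dim_size):
    # quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
    result_qdim = broadcast_dimensions[operand_qdim]
    if len(operand_scales) == 1:
        # dim(operand, qdim) = 1: the single pair is replicated across the result.
        return (result_qdim,
                operand_scales * result_dim_size,
                operand_zero_points * result_dim_size)
    return result_qdim, operand_scales, operand_zero_points

# A size-1 quantized dimension 0 broadcast into result dimension 2 of size 3:
print(infer_result_quant_params([2], 0, [0.5], [7], 3))
# (2, [0.5, 0.5, 0.5], [7, 7, 7])
```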

#### Examples

@@ -3001,6 +3013,8 @@ Extracts element at `index` position of the `operand` tuple and produces a
// %result: [1.0, 2.0]
```

+&nbsp;[More Examples](../stablehlo/tests/interpret_tuple_and_get_tuple_element.mlir)
+
### if

#### Semantics
@@ -4359,20 +4373,36 @@ ordering of `index_space(result)` and `index_space(operand)`.

#### Inputs

-| Label | Name      | Type                                  | Constraints |
-|-------|-----------|---------------------------------------|-------------|
-| (I1)  | `operand` | tensor or per-tensor quantized tensor | (C1-C2)     |
+| Label | Name      | Type                       | Constraints |
+|-------|-----------|----------------------------|-------------|
+| (I1)  | `operand` | tensor or quantized tensor | (C1-C3)     |

#### Outputs

-| Name     | Type                                  | Constraints |
-|----------|---------------------------------------|-------------|
-| `result` | tensor or per-tensor quantized tensor | (C1-C2)     |
+| Name     | Type                       | Constraints |
+|----------|----------------------------|-------------|
+| `result` | tensor or quantized tensor | (C1-C3)     |

#### Constraints

-* (C1) `element_type(operand) = element_type(result)`.
+* (C1) `element_type(result)` is given by:
+  * `element_type(operand)`, if `!is_per_axis_quantized(operand)`.
+  * `element_type(operand)` except that `quantization_dimension(operand)` and
+    `quantization_dimension(result)` may differ, otherwise.
* (C2) `size(operand) = size(result)`.
+* (C3) If `is_per_axis_quantized(operand)`:
+  * `reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]),
+    init_values=1, dimensions=[0], body=lambda x, y: x * y) =
+    reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]),
+    init_values=1, dimensions=[0], body=lambda x, y: x * y)`.
+  * `dim(operand, quantization_dimension(operand)) =
+    dim(result, quantization_dimension(result))`.
+  * `reduce(dims(operand,
+    [quantization_dimension(operand) + 1, ..., rank(operand) - 1]),
+    init_values=1, dimensions=[0], body=lambda x, y: x * y) =
+    reduce(dims(result,
+    [quantization_dimension(result) + 1, ..., rank(result) - 1]),
+    init_values=1, dimensions=[0], body=lambda x, y: x * y)`.
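
The three bullets of (C3) say a reshape may move the quantization dimension, but the element counts before it, at it, and after it must each be preserved. A hedged checker sketch (plain Python, hypothetical function name):

```python
from math import prod

def reshape_preserves_per_axis(operand_shape, operand_qdim,
                               result_shape, result_qdim):
    return (prod(operand_shape[:operand_qdim]) == prod(result_shape[:result_qdim])
            and operand_shape[operand_qdim] == result_shape[result_qdim]
            and prod(operand_shape[operand_qdim + 1:])
                == prod(result_shape[result_qdim + 1:]))

# (2, 3, 4) with qdim 1 -> (2, 3, 2, 2) with qdim 1: 2 == 2, 3 == 3, 4 == 2 * 2.
assert reshape_preserves_per_axis((2, 3, 4), 1, (2, 3, 2, 2), 1)
# Folding the quantized dimension into a neighbor is rejected:
assert not reshape_preserves_per_axis((2, 3, 4), 1, (6, 4), 0)
```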

#### Examples

@@ -5471,20 +5501,26 @@ where `result_index[d] = operand_index[permutation[d]]`.

| Label | Name          | Type                                         | Constraints |
|-------|---------------|----------------------------------------------|-------------|
-| (I1)  | `operand`     | tensor or per-tensor quantized tensor        | (C1-C3)     |
-| (I2)  | `permutation` | 1-dimensional tensor constant of type `si64` | (C2), (C3)  |
+| (I1)  | `operand`     | tensor or quantized tensor                   | (C1-C4)     |
+| (I2)  | `permutation` | 1-dimensional tensor constant of type `si64` | (C2-C4)     |

#### Outputs

-| Name     | Type                                  | Constraints |
-|----------|---------------------------------------|-------------|
-| `result` | tensor or per-tensor quantized tensor | (C1), (C3)  |
+| Name     | Type                       | Constraints   |
+|----------|----------------------------|---------------|
+| `result` | tensor or quantized tensor | (C1), (C3-C4) |

#### Constraints

-* (C1) `element_type(operand) = element_type(result)`.
+* (C1) `element_type(result)` is given by:
+  * `element_type(operand)`, if `!is_per_axis_quantized(operand)`.
+  * `element_type(operand)` except that `quantization_dimension(operand)` and
+    `quantization_dimension(result)` may differ, otherwise.
* (C2) `permutation` is a permutation of `range(rank(operand))`.
* (C3) `shape(result) = dim(operand, permutation...)`.
+* (C4) If `is_per_axis_quantized(result)`, then
+  `quantization_dimension(operand) =
+  permutation(quantization_dimension(result))`.
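
In other words, the result's quantization dimension is wherever `permutation` maps the operand's. A quick sanity check (plain Python sketch, hypothetical values):

```python
permutation = [2, 0, 1]    # result_index[d] = operand_index[permutation[d]]
operand_qdim = 0           # quantization_dimension(operand)
result_qdim = permutation.index(operand_qdim)   # = 1
# (C4): quantization_dimension(operand) = permutation[quantization_dimension(result)]
assert permutation[result_qdim] == operand_qdim
```
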
#### Examples

@@ -5615,6 +5651,8 @@ Produces a `result` tuple from values `val`.
// %result: ([1.0, 2.0], (3))
```

+&nbsp;[More Examples](../stablehlo/tests/interpret_tuple_and_get_tuple_element.mlir)
+
### uniform_dequantize

#### Semantics
@@ -6191,10 +6229,10 @@ def element_type(x: Value | Placeholder | Type):
```

* `is_per_axis_quantized(x: Value | Placeholder | Type) -> Value` is a shortcut
-  for `is_quantized(x) and quantized_dimension(x) is not None`.
+  for `is_quantized(x) and quantization_dimension(x) is not None`.

* `is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value` is a
-  shortcut for `is_quantized(x) and quantized_dimension(x) is None`.
+  shortcut for `is_quantized(x) and quantization_dimension(x) is None`.

* `is_quantized(x: Value | Placeholder | Type) -> Value` is a shortcut for
`is_quantized_tensor_element_type(x)`.
@@ -6269,6 +6307,9 @@ If `x` is not a tensor or `dim(x, axis) % num_results != 0`, returns `None`.
* `dim(x: Value | Placeholder | Type, axis: Value) -> Value` is a shortcut for
`shape(x)[axis]`.

+* `dims(x: Value | Placeholder | Type, axes: List) -> List` is a shortcut for
+  `list(map(lambda axis: dim(x, axis), axes))`.

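For example (throwaway sketch with `shape(x) = (2, 3, 4)`):

```python
shape = (2, 3, 4)                            # stand-in for shape(x)
dims = lambda axes: [shape[a] for a in axes]
assert dims([0, 2]) == [2, 4]                # dims(x, [0, 2])
```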
* `index_space(x: Value | Placeholder | Type) -> Value` is defined on tensors
and returns `size(x)` indices for the corresponding `TensorType` sorted in
ascending lexicographical order, i.e. `[0, ..., 0]`, `[0, ..., 1]`, ...,
@@ -6326,7 +6367,7 @@ def compute_zero_points(quantized_type, result_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
-      d = quantized_dimension(quantized_type)
+      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

@@ -6336,7 +6377,7 @@ def compute_scales(quantized_type, result_type):
                            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
-      d = quantized_dimension(quantized_type)
+      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

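For a per-axis type these helpers materialize one parameter per result element; a NumPy sketch of what `compute_scales` conceptually produces (hypothetical parameters, `quantization_dimension = 0`):

```python
import numpy as np

per_axis_scales = np.array([0.1, 0.5], dtype=np.float32)  # scales(quantized_type)
d = 0                                 # quantization_dimension(quantized_type)
result_shape = (2, 3)

scales = np.empty(result_shape, dtype=np.float32)
for i in np.ndindex(result_shape):    # plays the role of index_space(result_type)
    scales[i] = per_axis_scales[i[d]]
# scales == [[0.1, 0.1, 0.1],
#            [0.5, 0.5, 0.5]]
```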
4 changes: 2 additions & 2 deletions stablehlo/docs/status.md
@@ -93,7 +93,7 @@ one of the following tracking labels.
| floor | yes | yes | yes | yes | yes |
| gather | yes | yes | yes | no | yes |
| get_dimension_size | yes | yes | yes | yes | yes |
-| get_tuple_element | yes | yes | yes | yes | no |
+| get_tuple_element | yes | yes | yes | yes | yes |
| if | yes | revisit | yes | no | yes |
| imag | yes | yes | yes | yes | yes |
| infeed | yes | revisit | infeasible | no | no |
@@ -151,7 +151,7 @@ one of the following tracking labels.
| trace | no | revisit | no | yes | revisit |
| transpose | yes | yes | yes | yes | yes |
| triangular_solve | yes | revisit | yes | no | revisit |
-| tuple | yes | yes | yes | yes | no |
+| tuple | yes | yes | yes | yes | yes |
| unary_einsum | no | revisit | no | yes | revisit |
| uniform_dequantize | yes | yes | yes | yes | no |
| uniform_quantize | yes | revisit | infeasible | yes | no |
14 changes: 7 additions & 7 deletions stablehlo/stablehlo/dialect/StablehloOps.td
@@ -1495,7 +1495,7 @@ def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
// StableHLO tuple op definitions.
//===----------------------------------------------------------------------===//
def StableHLO_GetTupleElementOp: StableHLO_Op<"get_tuple_element", [Pure,
-    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    DeclareOpInterfaceMethods<InferTypeOpInterface> /*get_tuple_element_c2*/]> {
  let summary = "GetTupleElement operation";
  let description = [{
    Extracts element at `index` position of the `operand` tuple and produces a
@@ -1506,12 +1506,12 @@ def StableHLO_GetTupleElementOp: StableHLO_Op<"get_tuple_element", [Pure,

    Example:
    ```mlir
-    %result = stablehlo.get_tuple_element %operand[0] : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
+    %result = stablehlo.get_tuple_element %operand[0] : (tuple<tensor<2xf64>, tuple<tensor<i64>>>) -> tensor<2xf64>
    ```
  }];
  let arguments = (ins
-    HLO_Tuple:$operand,
-    I32Attr:$index
+    HLO_Tuple:$operand, /*get_tuple_element_i1*/
+    I32Attr:$index /*get_tuple_element_i2*/
  );

  let results = (outs HLO_TensorOrTokenOrTuple);
@@ -1522,7 +1522,7 @@ def StableHLO_GetTupleElementOp: StableHLO_Op<"get_tuple_element", [Pure,
}

def StableHLO_TupleOp : StableHLO_Op<"tuple", [Pure,
-    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    DeclareOpInterfaceMethods<InferTypeOpInterface> /*tuple_c1*/]> {
  let summary = "Tuple operation";
  let description = [{
    Produces a `result` tuple from values `val`.
@@ -1532,11 +1532,11 @@ def StableHLO_TupleOp : StableHLO_Op<"tuple", [Pure,

    Example:
    ```mlir
-    %result = stablehlo.tuple %val0, %val1 : tuple<tensor<2xf32>, tuple<tensor<i32>>>
+    %result = stablehlo.tuple %val0, %val1 : tuple<tensor<2xf64>, tuple<tensor<i64>>>
    ```
  }];

-  let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val);
+  let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val /*tuple_i1*/);
  let results = (outs HLO_Tuple:$result);

  let assemblyFormat = [{
5 changes: 4 additions & 1 deletion stablehlo/stablehlo/dialect/TypeInference.cpp
@@ -2320,11 +2320,13 @@ LogicalResult inferGetTupleElementOp(
                               SmallVectorImpl<Type>& inferredReturnTypes) {
  auto operandType = operand.getType().dyn_cast<TupleType>();
  if (!operandType) return failure();
+  // get_tuple_element_c1
  if (index < 0 || index >= static_cast<int64_t>(operandType.size()))
    return emitOptionalError(location, "index ", index,
                             " is out of bounds of operand with size ",
                             operandType.size());

+  // get_tuple_element_c2
  inferredReturnTypes.push_back(operandType.getType(index));
  return success();
}
@@ -2515,7 +2517,7 @@ LogicalResult inferPadOp(std::optional<Location> location, Type operandType,
    int64_t operandSizeOrBound = isStaticDim ? inputShape[i] : inputBounds[i];
    int64_t resultSizeOrBound =
        operandSizeOrBound + paddingLowVal + paddingHighVal +
-        std::max<int64_t>(operandSizeOrBound - 1, 0LL) * paddingInteriorVal;
+        std::max<int64_t>(operandSizeOrBound - 1, 0ll) * paddingInteriorVal;

    // pad_c4
    if (resultSizeOrBound < 0) {
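
For intuition, the size formula above adds the low and high edge padding plus one run of `paddingInteriorVal` between each adjacent pair of operand elements; a worked instance with made-up values (Python):

```python
# result = operand + low + high + max(operand - 1, 0) * interior
operand_size, low, high, interior = 3, 1, 2, 1
result_size = operand_size + low + high + max(operand_size - 1, 0) * interior
assert result_size == 8   # L a I a I a H H: 1 low + 3 data + 2 interior + 2 high
```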
@@ -3006,6 +3008,7 @@ LogicalResult inferTriangularSolveOp(
LogicalResult inferTupleOp(MLIRContext* context, std::optional<Location>,
                           ValueRange val,
                           SmallVectorImpl<Type>& inferredReturnTypes) {
+  // tuple_c1
  inferredReturnTypes.push_back(TupleType::get(context, val.getTypes()));
  return success();
}
2 changes: 1 addition & 1 deletion stablehlo/stablehlo/dialect/Version.h
@@ -38,7 +38,7 @@ class Version {
  static FailureOr<Version> fromString(llvm::StringRef versionRef);

  /// Return a Version representing the current VHLO dialect version.
-  static Version getCurrentVersion() { return Version(0, 14, 16); }
+  static Version getCurrentVersion() { return Version(0, 14, 17); }

  /// Return a Version representing the minimum supported VHLO dialect version.
  static Version getMinimumVersion() { return Version(0, 9, 0); }
10 changes: 5 additions & 5 deletions stablehlo/stablehlo/reference/Element.cpp
@@ -953,7 +953,7 @@ Element reducePrecision(const Element &el, int32_t exponentBits,
  int32_t srcMantissaBits = type.getFPMantissaWidth() - 1;
  auto destMantissaBits = mantissaBits;
  if (destMantissaBits < srcMantissaBits) {
-    auto lastMantissaBitMask = 1UL << (srcMantissaBits - destMantissaBits);
+    auto lastMantissaBitMask = 1ull << (srcMantissaBits - destMantissaBits);

    // Compute rounding bias for round-to-nearest with ties to even.
    auto baseRoundingBias = (lastMantissaBitMask >> 1) - 1;
@@ -971,15 +971,15 @@
  auto srcExponentBits = bitWidth - srcMantissaBits - 1;
  auto destExponentBits = exponentBits;
  if (destExponentBits < srcExponentBits) {
-    auto signBitMask = 1UL << (bitWidth - 1);
-    auto expBitsMask = ((1UL << srcExponentBits) - 1) << srcMantissaBits;
+    auto signBitMask = 1ull << (bitWidth - 1);
+    auto expBitsMask = ((1ull << srcExponentBits) - 1) << srcMantissaBits;

    // An exponent of 2^(n-1)-1 (i.e. 0b0111...) with 0 being the most
    // significant bit is equal to 1.0f for all exponent sizes. Adding 2^(n-1)-1
    // to this results in highest non-infinite exponent, and subtracting
    // 2^(n-1)-1 results in lowest exponent (i.e. 0.0f) for a bit size of n.
-    auto exponentBias = (1UL << (srcExponentBits - 1)) - 1;
-    auto reducedExponentBias = (1UL << (destExponentBits - 1)) - 1;
+    auto exponentBias = (1ull << (srcExponentBits - 1)) - 1;
+    auto reducedExponentBias = (1ull << (destExponentBits - 1)) - 1;
    auto reducedMaxExponent = exponentBias + reducedExponentBias;
    auto reducedMinExponent = exponentBias - reducedExponentBias;

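The same rounding-bias trick, sketched in Python for float32 inputs (assuming the standard IEEE 754 single-precision layout; NaN/Inf handling, which the surrounding code deals with separately, is omitted):

```python
import struct

def reduce_mantissa(x: float, dest_mantissa_bits: int) -> float:
    """Drop mantissa bits of a float32, rounding to nearest with ties to even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    shift = 23 - dest_mantissa_bits        # srcMantissaBits - destMantissaBits
    last_kept_bit = 1 << shift             # lastMantissaBitMask
    base_bias = (last_kept_bit >> 1) - 1   # baseRoundingBias: just under half
    tie_to_even = (bits >> shift) & 1      # round up only if the kept bit is odd
    bits = (bits + base_bias + tie_to_even) & ~(last_kept_bit - 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

assert reduce_mantissa(1.0078125, 7) == 1.0078125   # 1 + 2**-7 is representable
assert reduce_mantissa(1.00390625, 7) == 1.0        # exact tie: rounds to even
```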