Skip to content

Commit

Permalink
Add spec for ConvolutionOp (#441)
Browse files Browse the repository at this point in the history
fixes #359 

The PR addresses the followings:
1. Spec of ConvolutionOp
2. Clarify the semantics of `precision_config` : The precision_config
parameter is a array of enums without any constraint on its size. Need
to resolve this.
- update: Added constraints on the parameter. With that the verifier is
in sync with this spec. Also added
#445 for further exploration.
4. Fix
#360 (comment)
5. Avoid disabling clang formatting in StablehloOps.cpp.
6. Address #399

Only missing peice:

The constraint between output feature size and input batch size. Working
on getting a better understanding on this: Done

Type inference should be "revisit" as well because of #600.
  • Loading branch information
sdasgup3 authored Dec 13, 2022
1 parent a0b804e commit 66d8481
Show file tree
Hide file tree
Showing 10 changed files with 387 additions and 178 deletions.
1 change: 1 addition & 0 deletions docs/images/spec_draft/convolution.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
185 changes: 183 additions & 2 deletions docs/spec_draft.md
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ syntax.
* [complex](#stablehlocomplex)
* [concatenate](#stablehloconcatenate)
* [constant](#stablehloconstant)
* [convolution](#stablehloconvolution)
* [cosine](#stablehlocosine)
* [count_leading_zeros](#stablehlocount_leading_zeros)
* [divide](#stablehlodivide)
Expand Down Expand Up @@ -1814,6 +1815,185 @@ Produces an `output` tensor from a constant `value`.

[Back to Ops](#index-of-ops)

## stablehlo.convolution

### Semantics

Computes dot products between windows of `lhs` and slices of `rhs` and produces
`result`. The following diagram shows how elements in `result` are computed from
`lhs` and `rhs` using a concrete example.

![](images/spec_draft/convolution.svg)

More formally, we start with reframing the inputs to the operation in terms
of `lhs` in order to be able to express windows of `lhs`:

* `lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))`.
* `lhs_window_strides = lhs_shape(1, window_strides, 1)`.
* `lhs_padding = lhs_shape([0, 0], padding, [0, 0])`.
* `lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)`.
* `lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)`.

This reframing uses the following helper functions:

* `lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])`.
* `result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])`.

If `feature_group_count = 1` and `batch_group_count = 1`, then for all
`output_spatial_index` in the index space of `dim(result, output_spatial_dimensions)`,
`result[result_shape(:, output_spatial_index, :)] = dot_product` where:

* `padded_lhs = pad(lhs, 0, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations)`.
* `lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides`.
* `lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)`.
* `dot_product = dot_general(lhs_window, rhs,
lhs_batching_dimensions=[],
lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension],
rhs_batching_dimensions=[],
rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])`.

If `feature_group_count > 1`:

* `lhses = split(lhs, feature_group_count, input_feature_dimension)`.
* `rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)`.
* `results[:] = convolution(lhses[:], rhses[:], ..., feature_group_count=1, ...)`.
* `result = concatenate(results, output_feature_dimension)`.

If `batch_group_count > 1`:

* `lhses = split(lhs, batch_group_count, input_batch_dimension)`.
* `rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)`.
* `results[:] = convolution(lhses[:], rhses[:], ..., batch_group_count=1, ...)`.
* `result = concatenate(results, output_feature_dimension)`.

### Inputs

| Name | Type | Constraints |
|-----------------------------------|-------------------------------------------------------------|----------------------------------------|
| `lhs` | tensor of any supported type | (C1), (C2), (C11), (C12), (C26), (C27) |
| `rhs` | tensor of any supported type | (C1), (C2), (C15), (C16), (C17), (C26) |
| `window_strides` | 1-dimensional tensor constant of type `si64` | (C3), (C4), (C26) |
| `padding` | 2-dimensional tensor constant of type `si64` | (C5), (C26) |
| `lhs_dilation` | 1-dimensional tensor constant of type `si64` | (C6), (C7), (C26) |
| `rhs_dilation` | 1-dimensional tensor constant of type `si64` | (C8), (C9), (C26) |
| `window_reversal` | 1-dimensional tensor constant of type `boolean` | (C10) |
| `input_batch_dimension` | constant of type `si64` | (C11), (C14), (C26) |
| `input_feature_dimension` | constant of type `si64` | (C12), (C14) |
| `input_spatial_dimensions` | 1-dimensional tensor constant of type `si64` | (C13), (C14), (C26) |
| `kernel_input_feature_dimension` | constant of type `si64` | (C15), (C19) |
| `kernel_output_feature_dimension` | constant of type `si64` | (C16), (C17), (C19), (C26) |
| `kernel_spatial_dimensions` | 1-dimensional tensor constant of type `si64` | (C18), (C19), (C26) |
| `output_batch_dimension` | constant of type `si64` | (C21), (C26) |
| `output_feature_dimension` | constant of type `si64` | (C21), (C26) |
| `output_spatial_dimensions` | 1-dimensional tensor constant of type `si64` | (C20), (C21), (C26) |
| `feature_group_count` | constant of type `si64` | (C12), (C15), (C17), (C22), (C24) |
| `batch_group_count` | constant of type `si64` | (C11), (C16), (C23), (C24), (C26) |
| `precision_config` | variadic number of enum of `DEFAULT`, `HIGH`, and `HIGHEST` | (C25) |


### Outputs

| Name | Type | Constraints |
|----------|------------------------------|---------------------|
| `result` | tensor of any supported type | (C26), (C27), (C28) |

### Constraints

* (C1) $N =$ rank(`lhs`) $=$ rank(`rhs`).
* (C2) element_type(`lhs`) $=$ element_type(`rhs`).
* (C3) size(`window_strides`) $= N - 2$ .
* (C4) `window_strides[i]` $\gt 0$ for all i $\in$ [0, size(`window_strides`)).
* (C5) dim(`padding`, 0) $= N - 2$ and dim(`padding`, 1) = 2.
* (C6) size(`lhs_dilation`) $= N - 2$.
* (C7) `lhs_dilation[i]` $\gt 0$ for all i $\in$ [0, size(`lhs_dilation`)).
* (C8) size(`rhs_dilation`) $= N - 2$.
* (C9) `rhs_dilation[i]` $\gt 0$ for all i $\in$ [0, size(`rhs_dilation`)).
* (C10) size(`window_reversal`) $= N - 2$.
* (C11) `dim(lhs, input_batch_dimension) % batch_group_count = 0`.
* (C12) `dim(lhs, input_feature_dimension) % feature_group_count = 0.
* (C13) size(`input_spatial_dimensions`) $= N - 2$.
* (C14) Given `input_dimensions = [input_batch_dimension] +
input_spatial_dimensions + [input_feature_dimension]`.
* All dimensions in `input_dimensions` are unique.
* For any i $\in$ `input_dimensions`, 0 $\le$ i $\lt$ N.
* (C15) `dim(rhs, kernel_input_feature_dimension = dim(lhs, input_feature_dimension) / feature_group_count`.
* (C16) `dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0`.
* (C17) `dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0`.
* (C18) size(`kernel_spatial_dimensions`) $= N - 2$.
* (C19) Given `kernel_dimensions = kernel_spatial_dimensions +
[kernel_input_feature_dimension] + [kernel_output_feature_dimension]`.
* All dimensions in `kernel_dimensions` are unique.
* For any i $\in$ `kernel_dimensions`, 0 $\le$ i $\lt$ N.
* (C20) size(`output_spatial_dimensions`) $= N - 2$.
* (C21) Given `output_dimensions = [output_batch_dimension] +
output_spatial_dimensions + [output_feature_dimension]`.
* All dimensions in `output_dimensions` are unique.
* For any i $\in$ `output_dimensions`, 0 $\le$ i $\lt$ N.
* (C22) `feature_group_count > 0`.
* (C23) `batch_group_count > 0`.
* (C24) `feature_group_count` $= 1$ OR `batch_group_count` $= 1$.
* (C25) size(`precision_config`) $=$ 2.
* (C26) For result_dim $\in$ [0, N), `dim(result, result_dim)` is given by
* `dim(lhs, input_batch_dimension) / batch_group_count`, if `result_dim = output_batch_dimension`.
* `dim(rhs, kernel_output_feature_dimension)`, if `result_dim = output_feature_dimension`.
* `num_windows` otherwise, where:
* `output_spatial_dimensions[spatial_dim] = result_dim`.
* `lhs_dim = input_spatial_dimensions[spatial_dim]`.
* `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
* `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) == 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
* `padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]`.
* `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) == 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
* `num_windows = (padded_input_shape[lhs_dim] == 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]) ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
* (C27) element_type(`result`) $=$ element_type(`lhs`).
* (C28) rank(`result`) $= N$.

### Examples

```mlir
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs : [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = dense<4> : tensor<2xi64>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = dense<2> : tensor<2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_reversal = dense<false> : tensor<2xi1>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimenion, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
```

[Back to Ops](#index-of-ops)

## stablehlo.cosine

### Semantics
Expand Down Expand Up @@ -3857,10 +4037,11 @@ More formally, `results[:][result_index] = reduce(windows, init_values, axes(inp
* (C13) `body` has type `(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)`
where `Ek = element_type(inputs[0])`.
* (C14) All `results` have the same shape.
* (C15) `shape(results[0]) = (padded_input_shape == 0 || window_shape > padded_input_shape) ? 0 : floor((padded_input_shape - window_shape) / window_strides) + 1:`
* (C15) `shape(results[0]) = num_windows`
* `dilated_input_shape = shape(inputs[0]) == 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1`.
* `padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]`.
* `window_shape = window_dimensions == 0 ? 0 : (window_dimensions - 1) * window_dilations + 1`.
* `dilated_window_shape = window_dimensions == 0 ? 0 : (window_dimensions - 1) * window_dilations + 1`.
* `num_windows = (padded_input_shape == 0 || dilated_window_shape > padded_input_shape) ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1`.
* (C16) `element_type(results[k]) = element_type(init_values[k])` for any k
$\in$ [0, N).

Expand Down
2 changes: 1 addition & 1 deletion docs/status.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ one of the following tracking labels.
| concatenate | yes | yes | yes | yes | no |
| constant | yes | yes | yes | yes | yes |
| convert | no | yes* | infeasible | yes | no |
| convolution | no | yes* | yes* | revisit | no |
| convolution | revisit | yes | revisit | revisit | no |
| cosine | yes | yes | yes | yes | yes |
| count_leading_zeros | yes | yes | yes | yes | no |
| create_token | no | yes* | yes* | yes | no |
Expand Down
Loading

0 comments on commit 66d8481

Please sign in to comment.