Skip to content

Commit

Permalink
Add spec for CustomCallOp (#636)
Browse files Browse the repository at this point in the history
closes #518
  • Loading branch information
subhankarshah authored Dec 13, 2022
1 parent 66d8481 commit 914ccd5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 35 deletions.
42 changes: 42 additions & 0 deletions docs/spec_draft.md
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ syntax.
* [convolution](#stablehloconvolution)
* [cosine](#stablehlocosine)
* [count_leading_zeros](#stablehlocount_leading_zeros)
* [custom_call](#stablehlocustom_call)
* [divide](#stablehlodivide)
* [dot_general](#stablehlodot_general)
* [dynamic_slice](#stablehlodynamic_slice)
Expand Down Expand Up @@ -2066,6 +2067,47 @@ tensor and produces a `result` tensor.

[Back to Ops](#index-of-ops)

## stablehlo.custom_call

### Semantics

Encapsulates an implementation-defined operation `call_target_name` that takes
`inputs` and `called_computations` and produces `results`. `has_side_effect`,
`backend_config` and `api_version` may be used to provide additional
implementation-defined metadata.

### Inputs

| Name | Type |
|-----------------------|-----------------------------------------------------------------|
| `inputs` | variadic number of values of any supported type |
| `call_target_name` | constant of type `string` |
| `has_side_effect` | constant of type `i1` |
| `backend_config` | constant of type `string` |
| `api_version` | enum of `API_VERSION_ORIGINAL`, `API_VERSION_STATUS_RETURNING`, |
| | and `API_VERSION_STATUS_RETURNING_UNIFIED` |
| `called_computations` | variadic number of `function` |

### Outputs

| Name | Type |
|-----------|-------------------------------------------------|
| `results` | variadic number of values of any supported type |

### Examples

```mlir
%results = "stablehlo.custom_call"(%inputs0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = "bar",
api_version = 1 : i32,
called_computations = [@foo]
} : (tensor<f32>) -> tensor<f32>
```

[Back to Ops](#index-of-ops)

## stablehlo.divide

### Semantics
Expand Down
2 changes: 1 addition & 1 deletion docs/status.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ one of the following tracking labels.
| create_token | no | yes* | yes* | yes | no |
| cross-replica-sum | no | revisit | yes* | no | no |
| cstr_reshapable | no | revisit | no | yes | no |
| custom_call | no | revisit | infeasible | yes | no |
| custom_call | yes | yes | infeasible | yes | no |
| divide | yes | yes | yes | yes | no |
| dot | no | revisit | revisit | yes | no |
| dot_general | yes | revisit | revisit | no | no |
Expand Down
45 changes: 11 additions & 34 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2137,44 +2137,21 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "CustomCall operator";
let description = [{
A custom call invokes code external to XLA. The `args` are passed to the
external code, and the external code is expected to produce a result of the
given type. The exact mechanism is backend-specific. For example, in the CPU
backend, a call instruction is emitted which targets a symbol with the name
`call_target_name`.
Encapsulates an implementation-defined operation `call_target_name` that
takes `inputs` and `called_computations` and produces `results`.

`call_target_name` and `backend_config` can be arbitrary strings, but
`call_target_name` should be short as it may be used in labels.
`backend_config` can encode arbitrarily large amounts of information.

`has_side_effect` must be true if the custom call has side-effects.
`api_version` specifies the version of the API used by the custom call
function.

A custom call may apply functions within the scope of the parent module.
They can be referenced using `called_computations` attribute.

A custom call can also have layout constraints on operands and results which
can be specified as optional `operand_layouts` and `result_layouts`
attributes. The layout attribute is an array of rank-1 index tensors and the
i-th layout attribute specifies the layout for i-th operand/result.

The `operand_layouts` & `result_layouts` attributes can be specified under
the following constraints:
1) Either both `operand_layouts` and `result_layouts` are specified or none.
2) None of the operands are of tuple type.
3) None of the results are of tuple type except the common case of single
tuple result packing non-tuple values is allowed. In this case the i-th
`result_layouts` attribute specifies the layout of i-th element in the
result tuple.

See https://www.tensorflow.org/xla/operation_semantics#customcall.
See:
https://github.com/openxla/stablehlo/blob/main/docs/spec_draft.md#stablehlocustom_call

Example:

```mlir
%1 = stablehlo.custom_call @foo(%arg0, %arg1) {backend_config = "bar", has_side_effect = true}
: (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
%results = "stablehlo.custom_call"(%inputs0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = "bar",
api_version = 1 : i32,
called_computations = [@foo]
} : (tensor<f32>) -> tensor<f32>
```
}];

Expand Down

0 comments on commit 914ccd5

Please sign in to comment.