Port type inference for 6 ops from StableHLO to MHLO
Ops:
  1) AfterAllOp: openxla/stablehlo#708.
  2) CreateTokenOp: openxla/stablehlo#711.
  3) DynamicUpdateSliceOp: openxla/stablehlo#686 and openxla/stablehlo#757.
  4) OptimizationBarrierOp: openxla/stablehlo#575.
  5) OutfeedOp: openxla/stablehlo#713.
  6) SendOp: openxla/stablehlo#580.

This PR prepares for the migration from producing MHLO to producing StableHLO
by aligning type inference between the two dialects, so that switching from one
to the other doesn't require changes to calls to the Python builders.
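
To illustrate the pattern applied throughout this change: with type inference
ported to MHLO, the Python builders can infer result types, so callers no
longer pass them explicitly. A minimal sketch, distilled from the hunks below
(the helper name make_token is invented for illustration; mlir_api_version 40
is the bindings version gate used in this change):

  from jax._src.lib import xla_client as xc
  from jax._src.lib.mlir.dialects import mhlo

  def make_token():
    if xc.mlir_api_version < 40:
      # Older bindings: the result type must be supplied explicitly.
      return mhlo.CreateTokenOp(mhlo.TokenType.get()).result
    # Newer bindings: the token result type is inferred.
    return mhlo.CreateTokenOp().result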

PiperOrigin-RevId: 495336100
Eugene Burmako authored and jax authors committed Dec 14, 2022
1 parent 5832dfd commit f37b67f
Showing 7 changed files with 72 additions and 24 deletions.
jax/_src/ad_checkpoint.py (5 additions & 1 deletion)
@@ -33,6 +33,7 @@
 from jax._src.api_util import flatten_fun, shaped_abstractify
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import convolution as lax_convolution
+from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir.dialects import mhlo
 from jax._src.traceback_util import api_boundary
 from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
@@ -621,7 +622,10 @@ def _optimization_barrier_lowering_rule(ctx, *args):
   barrier_types = map(mlir.aval_to_ir_types, ctx.avals_in)
   flat_barrier_types = util.flatten(barrier_types)
   flat_args = mlir.flatten_lowering_ir_args(args)
-  barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
+  if xc.mlir_api_version < 40:
+    barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
+  else:
+    barrier_op = mhlo.OptimizationBarrierOp(flat_args)
   return util.unflatten(barrier_op.results, map(len, barrier_types))

 def _optimization_barrier(arg):
jax/_src/lax/control_flow/loops.py (5 additions & 1 deletion)
@@ -41,6 +41,7 @@
 from jax._src.lax import lax
 from jax._src.lax import slicing
 from jax._src.lax import windowed_reductions
+from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import mhlo
 from jax._src.numpy.ufuncs import logaddexp
@@ -1571,7 +1572,10 @@ def _pred_bcast_select_mhlo(ctx,
   if x_y_aval is core.abstract_token:
     x, = xs
     y, = ys
-    return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
+    if xc.mlir_api_version < 40:
+      return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
+    else:
+      return [mhlo.AfterAllOp([x, y]).result]
   else:
     assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
     x, = xs
jax/_src/lax/lax.py (19 additions & 7 deletions)
@@ -4133,7 +4133,10 @@ def create_token(_=None):

 def _create_token_lowering(ctx, *operands):
   aval_out, = ctx.avals_out
-  return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results
+  if xc.mlir_api_version < 40:
+    return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results
+  else:
+    return mhlo.CreateTokenOp().results

 mlir.register_lowering(create_token_p, _create_token_lowering)

@@ -4156,7 +4159,10 @@ def _after_all_abstract_eval(*operands):

 def _after_all_lowering(ctx, *operands):
   aval_out, = ctx.avals_out
-  return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results
+  if xc.mlir_api_version < 40:
+    return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results
+  else:
+    return mhlo.AfterAllOp(operands).results

 mlir.register_lowering(after_all_p, _after_all_lowering)

@@ -4252,11 +4258,17 @@ def _outfeed_abstract_eval(token, *xs, partitions):

 def _outfeed_lowering(ctx, token, *xs, partitions):
   token_aval = ctx.avals_in[0]
-  outfeed = mhlo.OutfeedOp(
-      mlir.aval_to_ir_type(token_aval),
-      mlir.flatten_lowering_ir_args(xs),
-      token,
-      outfeed_config=ir.StringAttr.get(''))
+  if xc.mlir_api_version < 40:
+    outfeed = mhlo.OutfeedOp(
+        mlir.aval_to_ir_type(token_aval),
+        mlir.flatten_lowering_ir_args(xs),
+        token,
+        outfeed_config=ir.StringAttr.get(''))
+  else:
+    outfeed = mhlo.OutfeedOp(
+        mlir.flatten_lowering_ir_args(xs),
+        token,
+        outfeed_config=ir.StringAttr.get(''))
   if partitions is not None:
     mlir.set_sharding(outfeed, xla.sharding_to_proto(partitions))
   return outfeed.results
jax/interpreters/mlir.py (30 additions & 10 deletions)
@@ -302,8 +302,12 @@ def _device_array_constant_handler(val, canonicalize_types):
 for t in device_array.device_array_types:
   register_constant_handler(t, _device_array_constant_handler)

-register_constant_handler(
-    core.Token, lambda _, __: [mhlo.CreateTokenOp(mhlo.TokenType.get()).result])
+def _token_constant_handler(val, canonicalize_types):
+  if mlir_api_version < 40:
+    return [mhlo.CreateTokenOp(mhlo.TokenType.get()).result]
+  else:
+    return [mhlo.CreateTokenOp().result]
+register_constant_handler(core.Token, _token_constant_handler)

 # Source locations

@@ -760,8 +764,11 @@ def token_type() -> Sequence[ir.Type]:
   return [mhlo.TokenType.get()]

 def create_token() -> Token:
-  return wrap_singleton_ir_values(
-      mhlo.CreateTokenOp(mhlo.TokenType.get()).result)
+  if mlir_api_version < 40:
+    return wrap_singleton_ir_values(
+        mhlo.CreateTokenOp(mhlo.TokenType.get()).result)
+  else:
+    return wrap_singleton_ir_values(mhlo.CreateTokenOp().result)

 class TokenSet:
   """An immutable container of tokens to be used to lower effectful jaxprs. When lowering
@@ -980,7 +987,10 @@ def aval_to_types(aval):
   args: List[List[ir.Value]] = []
   for aval, arg in zip(jaxpr.in_avals, unflattened_args):
     if replace_tokens_with_dummy and aval is core.abstract_token:
-      args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
+      if mlir_api_version < 40:
+        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
+      else:
+        args.append(mhlo.CreateTokenOp().results)
     else:
       args.append(arg)
   callee_name_stack = xla.extend_name_stack(ctx.name_stack,
@@ -1345,8 +1355,11 @@ def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
       ctx, aval_out, x, update, *start_indices)

   # TODO(necula): handle dynamic shapes
-  return mhlo.DynamicUpdateSliceOp(aval_to_ir_type(aval_out), x, update,
-                                   start_indices).result
+  if mlir_api_version < 40:
+    return mhlo.DynamicUpdateSliceOp(aval_to_ir_type(aval_out), x, update,
+                                     start_indices).result
+  else:
+    return mhlo.DynamicUpdateSliceOp(x, update, start_indices).result

 def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
   """Returns an IR constant shaped full of `value` shaped like `aval`."""
@@ -1602,8 +1615,12 @@ def send_to_host(channel: int, token: mhlo.TokenType, operand: Any,
                  aval: core.ShapedArray, name: str, *,
                  sharding: Optional[xc.OpSharding] = None) -> ir.Value:
   channel_handle = mhlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
-  send_op = mhlo.SendOp(mhlo.TokenType.get(), [operand], token, channel_handle,
-                        is_host_transfer=ir.BoolAttr.get(True))
+  if mlir_api_version < 40:
+    send_op = mhlo.SendOp(mhlo.TokenType.get(), [operand], token, channel_handle,
+                          is_host_transfer=ir.BoolAttr.get(True))
+  else:
+    send_op = mhlo.SendOp([operand], token, channel_handle,
+                          is_host_transfer=ir.BoolAttr.get(True))
   dtype_str = _dtype_to_xla_type_string(aval.dtype)
   if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
     raise NotImplementedError("64-bit types not supported.")
@@ -1652,7 +1669,10 @@ def _emit_tpu_python_callback(
     *,
     sharding: Optional[xc.OpSharding] = None
 ) -> Tuple[List[ir.Value], Any, Any]:
-  token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
+  if mlir_api_version < 40:
+    token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
+  else:
+    token = token or mhlo.CreateTokenOp().result
   _wrapped_callback = callback

   send_channels = []
jax/interpreters/pxla.py (6 additions & 2 deletions)
@@ -2258,8 +2258,12 @@ def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, p
     idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
     broadcast_result = mhlo.BroadcastOp(
         x, mlir.dense_int_elements([1])).result
-    padded = mhlo.DynamicUpdateSliceOp(
-        padded.type, padded, broadcast_result, idxs).result
+    if xc.mlir_api_version < 40:
+      padded = mhlo.DynamicUpdateSliceOp(
+          padded.type, padded, broadcast_result, idxs).result
+    else:
+      padded = mhlo.DynamicUpdateSliceOp(
+          padded, broadcast_result, idxs).result
     replica_groups = mlir.dense_int_elements(
         xla.axis_groups(axis_env, axis_env.names[-1]))
     out = mhlo.CrossReplicaSumOp(padded, replica_groups).result
jaxlib/gpu_solver.py (1 addition & 1 deletion)
@@ -498,7 +498,7 @@ def _sytrd_mhlo(platform, gpu_solver, dtype, a, *, lower):
       ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
   offsets = tuple(mhlo.ConstantOp(intattr(i))
                   for i in ((0,) * len(batch_dims) + (0, 1)))
-  a = mhlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result
+  a = mhlo.DynamicUpdateSliceOp(a, s, offsets).result

   return a, d, e, taus, info

tests/lax_test.py (6 additions & 2 deletions)
@@ -47,6 +47,7 @@
 from jax._src import lax_reference
 from jax._src.util import prod
 from jax._src.lax import lax as lax_internal
+from jax._src.lib import xla_client as xc

 from jax.config import config
 config.parse_flags_with_absl()
@@ -2883,8 +2884,11 @@ def dynamic_update_slice_mlir(ctx, aval_out, x, update, *start_indices):
     aval_out, = ctx.avals_out
     dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
     start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
-    return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
-                                     start_indices).result
+    if xc.mlir_api_version < 40:
+      return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
+                                       start_indices).result
+    else:
+      return mhlo.DynamicUpdateSliceOp(x, update, start_indices).result

   @staticmethod
   def broadcast_in_dim_mlir(ctx, aval_out, x, broadcast_dimensions):
