From ee92a6af929d076a2d113844dd734a2262d85683 Mon Sep 17 00:00:00 2001
From: Eugene Burmako
Date: Wed, 14 Dec 2022 09:29:05 -0800
Subject: [PATCH] Port type inference for 6 ops from StableHLO to MHLO

Ops:
1) AfterAllOp: https://github.com/openxla/stablehlo/pull/708.
2) CreateTokenOp: https://github.com/openxla/stablehlo/pull/711.
3) DynamicUpdateSliceOp: https://github.com/openxla/stablehlo/pull/686 and
   https://github.com/openxla/stablehlo/pull/757.
4) OptimizationBarrierOp: https://github.com/openxla/stablehlo/pull/575.
5) OutfeedOp: https://github.com/openxla/stablehlo/pull/713.
6) SendOp: https://github.com/openxla/stablehlo/pull/580.

This PR prepares for the migration from producing MHLO to producing StableHLO
by aligning type inference between the two dialects, so that switching from
one to the other doesn't require changes to calls to the Python builders.
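Concretely, the aligned builders no longer take an explicit result type; it
is inferred from the op's operands and attributes. A minimal before/after
sketch (illustrative only, using CreateTokenOp as in the hunks below):

  # Before (mlir_api_version < 40): the result type is passed explicitly.
  token = mhlo.CreateTokenOp(mhlo.TokenType.get()).result
  # After (mlir_api_version >= 40): the token type is inferred.
  token = mhlo.CreateTokenOp().result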
PiperOrigin-RevId: 495336100
---
 jax/_src/ad_checkpoint.py          |  6 ++++-
 jax/_src/lax/control_flow/loops.py |  6 ++++-
 jax/_src/lax/lax.py                | 26 +++++++++++++------
 jax/interpreters/mlir.py           | 40 ++++++++++++++++++++++--------
 jax/interpreters/pxla.py           |  7 ++++--
 jaxlib/gpu_solver.py               |  5 +++-
 tests/lax_test.py                  |  8 ++++--
 7 files changed, 74 insertions(+), 24 deletions(-)

diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index ab91df0253e2..a1b3c41e0093 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -33,6 +33,7 @@
 from jax._src.api_util import flatten_fun, shaped_abstractify
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import convolution as lax_convolution
+from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir.dialects import mhlo
 from jax._src.traceback_util import api_boundary
 from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
@@ -621,7 +622,10 @@ def _optimization_barrier_lowering_rule(ctx, *args):
   barrier_types = map(mlir.aval_to_ir_types, ctx.avals_in)
   flat_barrier_types = util.flatten(barrier_types)
   flat_args = mlir.flatten_lowering_ir_args(args)
-  barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
+  if xc.mlir_api_version < 40:
+    barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
+  else:
+    barrier_op = mhlo.OptimizationBarrierOp(flat_args)
   return util.unflatten(barrier_op.results, map(len, barrier_types))
 
 def _optimization_barrier(arg):
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index ffa97d752229..067f8c9abf8e 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -41,6 +41,7 @@
 from jax._src.lax import lax
 from jax._src.lax import slicing
 from jax._src.lax import windowed_reductions
+from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import mhlo
 from jax._src.numpy.ufuncs import logaddexp
@@ -1571,7 +1572,10 @@ def _pred_bcast_select_mhlo(ctx,
   if x_y_aval is core.abstract_token:
     x, = xs
     y, = ys
-    return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
+    if xc.mlir_api_version < 40:
+      return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
+    else:
+      return [mhlo.AfterAllOp([x, y]).result]
   else:
     assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
     x, = xs
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 339a360105a3..cdd1dd581eec 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -4133,7 +4133,10 @@ def create_token(_=None):
 
 def _create_token_lowering(ctx, *operands):
   aval_out, = ctx.avals_out
-  return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results
+  if xc.mlir_api_version < 40:
+    return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results
+  else:
+    return mhlo.CreateTokenOp().results
 
 mlir.register_lowering(create_token_p, _create_token_lowering)
 
@@ -4156,7 +4159,10 @@ def _after_all_abstract_eval(*operands):
 
 def _after_all_lowering(ctx, *operands):
   aval_out, = ctx.avals_out
-  return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results
+  if xc.mlir_api_version < 40:
+    return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results
+  else:
+    return mhlo.AfterAllOp(operands).results
 
 mlir.register_lowering(after_all_p, _after_all_lowering)
 
@@ -4252,11 +4258,17 @@ def _outfeed_abstract_eval(token, *xs, partitions):
 
 def _outfeed_lowering(ctx, token, *xs, partitions):
   token_aval = ctx.avals_in[0]
-  outfeed = mhlo.OutfeedOp(
-      mlir.aval_to_ir_type(token_aval),
-      mlir.flatten_lowering_ir_args(xs),
-      token,
-      outfeed_config=ir.StringAttr.get(''))
+  if xc.mlir_api_version < 40:
+    outfeed = mhlo.OutfeedOp(
+        mlir.aval_to_ir_type(token_aval),
+        mlir.flatten_lowering_ir_args(xs),
+        token,
+        outfeed_config=ir.StringAttr.get(''))
+  else:
+    outfeed = mhlo.OutfeedOp(
+        mlir.flatten_lowering_ir_args(xs),
+        token,
+        outfeed_config=ir.StringAttr.get(''))
   if partitions is not None:
     mlir.set_sharding(outfeed, xla.sharding_to_proto(partitions))
   return outfeed.results
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 4e58efec8432..a9992609641d 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -302,8 +302,12 @@ def _device_array_constant_handler(val, canonicalize_types):
 for t in device_array.device_array_types:
   register_constant_handler(t, _device_array_constant_handler)
 
-register_constant_handler(
-    core.Token, lambda _, __: [mhlo.CreateTokenOp(mhlo.TokenType.get()).result])
+def _token_constant_handler(val, canonicalize_types):
+  if mlir_api_version < 40:
+    return [mhlo.CreateTokenOp(mhlo.TokenType.get()).result]
+  else:
+    return [mhlo.CreateTokenOp().result]
+register_constant_handler(core.Token, _token_constant_handler)
 
 # Source locations
 
@@ -760,8 +764,11 @@ def token_type() -> Sequence[ir.Type]:
   return [mhlo.TokenType.get()]
 
 def create_token() -> Token:
-  return wrap_singleton_ir_values(
-      mhlo.CreateTokenOp(mhlo.TokenType.get()).result)
+  if mlir_api_version < 40:
+    return wrap_singleton_ir_values(
+        mhlo.CreateTokenOp(mhlo.TokenType.get()).result)
+  else:
+    return wrap_singleton_ir_values(mhlo.CreateTokenOp().result)
 
 class TokenSet:
   """An immutable container of tokens to be used to lower effectful jaxprs. When lowering
@@ -980,7 +987,10 @@ def aval_to_types(aval):
   args: List[List[ir.Value]] = []
   for aval, arg in zip(jaxpr.in_avals, unflattened_args):
     if replace_tokens_with_dummy and aval is core.abstract_token:
-      args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
+      if mlir_api_version < 40:
+        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
+      else:
+        args.append(mhlo.CreateTokenOp().results)
     else:
       args.append(arg)
   callee_name_stack = xla.extend_name_stack(ctx.name_stack,
@@ -1345,8 +1355,11 @@ def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
         ctx, aval_out, x, update, *start_indices)
 
   # TODO(necula): handle dynamic shapes
-  return mhlo.DynamicUpdateSliceOp(aval_to_ir_type(aval_out), x, update,
-                                   start_indices).result
+  if mlir_api_version < 40:
+    return mhlo.DynamicUpdateSliceOp(
+        aval_to_ir_type(aval_out), x, update, start_indices).result
+  else:
+    return mhlo.DynamicUpdateSliceOp(x, update, start_indices).result
 
 def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
   """Returns an IR constant shaped full of `value` shaped like `aval`."""
@@ -1602,8 +1615,12 @@ def send_to_host(channel: int, token: mhlo.TokenType, operand: Any,
                  aval: core.ShapedArray, name: str, *,
                  sharding: Optional[xc.OpSharding] = None) -> ir.Value:
   channel_handle = mhlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
-  send_op = mhlo.SendOp(mhlo.TokenType.get(), [operand], token, channel_handle,
-                        is_host_transfer=ir.BoolAttr.get(True))
+  if mlir_api_version < 40:
+    send_op = mhlo.SendOp(mhlo.TokenType.get(), [operand], token, channel_handle,
+                          is_host_transfer=ir.BoolAttr.get(True))
+  else:
+    send_op = mhlo.SendOp([operand], token, channel_handle,
+                          is_host_transfer=ir.BoolAttr.get(True))
   dtype_str = _dtype_to_xla_type_string(aval.dtype)
   if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
     raise NotImplementedError("64-bit types not supported.")
@@ -1652,7 +1669,10 @@ def _emit_tpu_python_callback(
     *,
     sharding: Optional[xc.OpSharding] = None
 ) -> Tuple[List[ir.Value], Any, Any]:
-  token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
+  if mlir_api_version < 40:
+    token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
+  else:
+    token = token or mhlo.CreateTokenOp().result
   _wrapped_callback = callback
 
   send_channels = []
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 167c19651c8d..737b5c3913a6 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -2258,8 +2258,11 @@ def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, p
     idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
     broadcast_result = mhlo.BroadcastOp(
         x, mlir.dense_int_elements([1])).result
-    padded = mhlo.DynamicUpdateSliceOp(
-        padded.type, padded, broadcast_result, idxs).result
+    if xc.mlir_api_version < 40:
+      padded = mhlo.DynamicUpdateSliceOp(padded.type, padded, broadcast_result,
+                                         idxs).result
+    else:
+      padded = mhlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
     replica_groups = mlir.dense_int_elements(
         xla.axis_groups(axis_env, axis_env.names[-1]))
     out = mhlo.CrossReplicaSumOp(padded, replica_groups).result
diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py
index 1cd2b4dc4e83..464e8d9a09b1 100644
--- a/jaxlib/gpu_solver.py
+++ b/jaxlib/gpu_solver.py
@@ -498,7 +498,10 @@ def _sytrd_mhlo(platform, gpu_solver, dtype, a, *, lower):
       ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
   offsets = tuple(mhlo.ConstantOp(intattr(i))
                   for i in ((0,) * len(batch_dims) + (0, 1)))
-  a = mhlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result
+  if xla_client.mlir_api_version < 40:
+    a = mhlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result
+  else:
+    a = mhlo.DynamicUpdateSliceOp(a, s, offsets).result
 
   return a, d, e, taus, info
 
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 407e28ba4844..5b62a508c2f1 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -47,6 +47,7 @@
 from jax._src import lax_reference
 from jax._src.util import prod
 from jax._src.lax import lax as lax_internal
+from jax._src.lib import xla_client as xc
 from jax.config import config
 
 config.parse_flags_with_absl()
@@ -2883,8 +2884,11 @@ def dynamic_update_slice_mlir(ctx, aval_out, x, update, *start_indices):
     aval_out, = ctx.avals_out
     dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
     start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
-    return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
-                                     start_indices).result
+    if xc.mlir_api_version < 40:
+      return mhlo.DynamicUpdateSliceOp(
+          mlir.aval_to_ir_type(aval_out), x, update, start_indices).result
+    else:
+      return mhlo.DynamicUpdateSliceOp(x, update, start_indices).result
 
   @staticmethod
   def broadcast_in_dim_mlir(ctx, aval_out, x, broadcast_dimensions):
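Every hunk above uses the same gating idiom on jaxlib's MLIR bindings
version. For reference, the idiom in isolation (an illustrative sketch, not
part of the patch; `_create_token` is a hypothetical helper name):

  from jax._src.lib import xla_client as xc
  from jax._src.lib.mlir.dialects import mhlo

  def _create_token():
    # Hypothetical helper showing the version gate used throughout the patch.
    if xc.mlir_api_version < 40:
      # Older bindings: the result type must be supplied explicitly.
      return mhlo.CreateTokenOp(mhlo.TokenType.get()).result
    # Newer bindings (>= 40): type inference derives the result type.
    return mhlo.CreateTokenOp().result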