diff --git a/tests/export_test.py b/tests/export_test.py index 0d87d0c52d36..4ad0f5f50699 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -946,22 +946,22 @@ def f_jax_inner(x): exp = export.export(f_jax)(x) if exp.serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - self.assertSetEqual({"TestingOrderedEffect1", "TestingOrderedEffect2"}, - {str(e) for e in exp.ordered_effects}) - self.assertEqual({"TestingUnorderedEffect1"}, - {str(e) for e in exp.unordered_effects}) + self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"], + sorted(str(e) for e in exp.ordered_effects)) + self.assertEqual(["TestingUnorderedEffect1"], + [str(e) for e in exp.unordered_effects]) else: - self.assertSetEqual(set(), {str(e) for e in exp.ordered_effects}) - self.assertSetEqual(set(), {str(e) for e in exp.unordered_effects}) + self.assertEqual([], [str(e) for e in exp.ordered_effects]) + self.assertEqual([], [str(e) for e in exp.unordered_effects]) mlir_module_str = str(exp.mlir_module()) # Inner functions use stablehlo.token for all versions inner_fun_expected_re = ( r"func.func private @f_jax_inner\(" - r"%arg0: !stablehlo.token {jax.token = true}.*" + r"%arg0: !stablehlo.token .*jax.token = true.*" r"%arg1: tensor<3xf32>.*->.*" # Results - r"!stablehlo.token {jax.token = true}.*" + r"!stablehlo.token .*jax.token = true.*" r"tensor<3xf32>" ) self.assertRegex(mlir_module_str, inner_fun_expected_re) @@ -970,11 +970,11 @@ def f_jax_inner(x): # i1[0] before version 9. wrapped_main_expected_re = ( r"@_wrapped_jax_export_main\(" - r"%arg0: !stablehlo.token {jax.token = true}.*" - r"%arg1: !stablehlo.token {jax.token = true}.*->.*" + r"%arg0: !stablehlo.token .*jax.token = true.*" + r"%arg1: !stablehlo.token .*jax.token = true.*->.*" # Results - r"!stablehlo.token {jax.token = true}.*" - r"!stablehlo.token {jax.token = true}.*") + r"!stablehlo.token .*jax.token = true.*" + r"!stablehlo.token .*jax.token = true.*") if exp.serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") self.assertRegex(mlir_module_str, wrapped_main_expected_re) @@ -987,20 +987,32 @@ def f_jax_inner(x): main_expected_re = wrapped_main_expected_re.replace("@_wrapped_jax_export_main", "@main") self.assertRegex(mlir_module_str, main_expected_re) - lowered = jax.jit(export.call_exported(exp)).lower(x) + # Now call the exported from a function that uses its own effects + def f_outer(x): + return ( + testing_primitive_with_effect_p.bind( + x, effect_class_name="TestingOrderedEffect2") + + testing_primitive_with_effect_p.bind( + x, effect_class_name="TestingUnorderedEffect1") + + export.call_exported(exp)(x)) + + lowered_outer = jax.jit(f_outer).lower(x) if exp.serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - self.assertSetEqual(set(), - {str(e) for e in lowered._lowering.compile_args["ordered_effects"]}) - self.assertSetEqual(set(), - {str(e) for e in lowered._lowering.compile_args["unordered_effects"]}) + self.assertEqual(["TestingOrderedEffect2"], + [str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]]) else: - self.assertSetEqual({"TestingOrderedEffect1", "TestingOrderedEffect2"}, - {str(e) for e in lowered._lowering.compile_args["ordered_effects"]}) - self.assertSetEqual({"TestingUnorderedEffect1"}, - {str(e) for e in lowered._lowering.compile_args["unordered_effects"]}) + self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"], + sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"])) + self.assertEqual(["TestingUnorderedEffect1"], + sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]])) - res = export.call_exported(exp)(x) - self.assertAllClose(10. + 4. * 2. * x, res) + mlir_outer_module_str = str(lowered_outer.compiler_ir()) + if exp.serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") + self.assertRegex(mlir_outer_module_str, main_expected_re) + + res = jax.jit(f_outer)(x) + self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res) @jtu.parameterized_filterable( kwargs=[