diff --git a/jax/experimental/export/shape_poly.py b/jax/experimental/export/shape_poly.py index fa261d4ae79c..8459aa6dfbad 100644 --- a/jax/experimental/export/shape_poly.py +++ b/jax/experimental/export/shape_poly.py @@ -417,10 +417,9 @@ def normalize(cls, coeffs: dict[_DimMon, int]) -> DimSize: def normalize_floordiv_times_divisor(cls, coeffs: dict[_DimMon, int]) -> DimSize: # Look for floordiv(E, M) * M and turn into E - mod(E, M). This comes # up when handling strided convolution. - for dec in _decompose_expr(_DimExpr(coeffs), _DimAtom.FLOORDIV): + for dec in _decompose_expr(_DimExpr(coeffs), _DimAtom.FLOORDIV, + with_exp=1): # e = factor * floordiv(operands)^exp * rest_monomial + rest_expr - if dec.exp != 1: - continue if dec.rest_monomial == 1 and dec.factor == 1: continue m_trimmed, m_remainder = divmod(dec.factor * dec.rest_monomial, dec.operands[1]) @@ -472,11 +471,33 @@ def inconclusive_comparison(self, operation: str, op: Any) -> Exception: "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported.") def ge(self, other: DimSize) -> bool: - lb, ub = _ensure_poly(self - other, "ge").bounds() + self_minus_other = _ensure_poly(self - other, "ge") + lb, ub = self_minus_other.bounds() if lb >= 0: return True if ub < 0: return False + # Attempt to handle non_negative + for dec in _decompose_expr(self_minus_other, _DimAtom.NON_NEGATIVE): + # e = factor * non_negative(operands)^exp * rest_monomial + rest_expr + e1 = dec.rest_expr + e2 = dec.rest_expr + dec.factor * (dec.operands[0] ** dec.exp) * dec.rest_monomial + try: + if (e1 >= 0) and (e2 >= 0): + return True + except InconclusiveDimensionOperation: + continue + # Attempt to handle floordiv >= 0 + for dec in _decompose_expr(self_minus_other, _DimAtom.FLOORDIV, + with_exp=1, with_rest_monomial=1, + with_rest_expr=0): + # e = factor * floordiv(op1, op2)^1 * 1 + 0 + if dec.factor > 0: + try: + if (dec.operands[0] >= 0) and (dec.operands[1] >= 0): + return True + except InconclusiveDimensionOperation: + continue raise self.inconclusive_comparison(">=", other) def __hash__(self): @@ -680,9 +701,10 @@ def bounds(self) -> tuple[float, float]: # Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0 # TODO(necula): add more principled support for floordiv and mod # For example, this will miss "1 + a - mod(b, a)" - for dec in _decompose_expr(self, _DimAtom.MOD): - # E = factor*mod(op1, op2)^exp * rest_monomial + rest_expr - if dec.exp == 1 and dec.rest_monomial == 1 and dec.rest_expr == - dec.factor * dec.operands[1]: + for dec in _decompose_expr(self, _DimAtom.MOD, + with_exp=1, with_rest_monomial=1): + # E = factor*mod(op1, op2)^1 * 1 + rest_expr + if dec.rest_expr == - dec.factor * dec.operands[1]: try: if dec.operands[1] <= 0: continue @@ -729,7 +751,9 @@ def __jax_array__(self): class _Decomposition: """Decomposition of an expression around an operation atom. - E = factor * mod(*operands)^exp * rest_monomial + rest_expr + E.g., for decomposing around "mod": + + E = factor * mod(*operands)^exp * rest_monomial + rest_expr """ factor: int operands: Sequence[_DimExpr] @@ -738,19 +762,44 @@ class _Decomposition: rest_expr: _DimExpr -def _decompose_expr(e: _DimExpr, operation: str) -> Iterable[_Decomposition]: +def _decompose_expr(e: _DimExpr, operation: str, *, + with_factor: Optional[int] = None, + with_exp: Optional[int] = None, + with_rest_monomial: Optional[Union[_DimExpr, int]] = None, + with_rest_expr: Optional[Union[_DimExpr, int]] = None, + ) -> Iterable[_Decomposition]: + """Computes the decompositions of `e` into `_Decomposition`. + + Args: + e: the expression to decompose + operation: the operation atom around which to decompose + with_factor, with_exp, with_rest_monomial, with_rest_expr: if present, + keep only the decompositions that match. + """ for m, m_factor in e.monomials(): atoms = [(a, aexp) for a, aexp in m.items() if a.operation == operation] if atoms: e_minus_m_coeffs = e._coeffs.copy() del e_minus_m_coeffs[m] for a, aexp in atoms: + if with_factor is not None and with_factor != m_factor: + continue + if with_exp is not None and with_exp != aexp: + continue + rest_monomial = _DimExpr({m.divide(_DimMon.from_atom(a, aexp)): 1}) + if (with_rest_monomial is not None and + not core.definitely_equal(with_rest_monomial, rest_monomial)): + continue + rest_expr = _DimExpr(e_minus_m_coeffs) + if (with_rest_expr is not None and + not core.definitely_equal(with_rest_expr, rest_expr)): + continue yield _Decomposition( factor=m_factor, operands=a.operands, exp=aexp, - rest_monomial=_DimExpr({m.divide(_DimMon.from_atom(a, aexp)): 1}), - rest_expr=_DimExpr(e_minus_m_coeffs)) + rest_monomial=rest_monomial, + rest_expr=rest_expr) core.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval xla.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 28a900afbc03..f5ac929b775e 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2353,9 +2353,6 @@ def test_harness(self, harness: PolyHarness): if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]): raise unittest.SkipTest("JAX implements eig only on CPU.") - if harness.group_name == "indexing": - raise unittest.SkipTest("TODO(necula): fix the indexing tests") - prev_jax_config_flags = { fname: getattr(jax.config, fname) for fname, fvalue in harness.override_jax_config_flags.items()