Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shape_poly] Add heuristics for deciding >= 0 #18762

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 60 additions & 11 deletions jax/experimental/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,9 @@ def normalize(cls, coeffs: dict[_DimMon, int]) -> DimSize:
def normalize_floordiv_times_divisor(cls, coeffs: dict[_DimMon, int]) -> DimSize:
# Look for floordiv(E, M) * M and turn into E - mod(E, M). This comes
# up when handling strided convolution.
for dec in _decompose_expr(_DimExpr(coeffs), _DimAtom.FLOORDIV):
for dec in _decompose_expr(_DimExpr(coeffs), _DimAtom.FLOORDIV,
with_exp=1):
# e = factor * floordiv(operands)^exp * rest_monomial + rest_expr
if dec.exp != 1:
continue
if dec.rest_monomial == 1 and dec.factor == 1:
continue
m_trimmed, m_remainder = divmod(dec.factor * dec.rest_monomial, dec.operands[1])
Expand Down Expand Up @@ -472,11 +471,33 @@ def inconclusive_comparison(self, operation: str, op: Any) -> Exception:
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported.")

def ge(self, other: DimSize) -> bool:
lb, ub = _ensure_poly(self - other, "ge").bounds()
self_minus_other = _ensure_poly(self - other, "ge")
lb, ub = self_minus_other.bounds()
if lb >= 0:
return True
if ub < 0:
return False
# Attempt to handle non_negative
for dec in _decompose_expr(self_minus_other, _DimAtom.NON_NEGATIVE):
# e = factor * non_negative(operands)^exp * rest_monomial + rest_expr
e1 = dec.rest_expr
e2 = dec.rest_expr + dec.factor * (dec.operands[0] ** dec.exp) * dec.rest_monomial
try:
if (e1 >= 0) and (e2 >= 0):
return True
except InconclusiveDimensionOperation:
continue
# Attempt to handle floordiv >= 0
for dec in _decompose_expr(self_minus_other, _DimAtom.FLOORDIV,
with_exp=1, with_rest_monomial=1,
with_rest_expr=0):
# e = factor * floordiv(op1, op2)^1 * 1 + 0
if dec.factor > 0:
try:
if (dec.operands[0] >= 0) and (dec.operands[1] >= 0):
return True
except InconclusiveDimensionOperation:
continue
raise self.inconclusive_comparison(">=", other)

def __hash__(self):
Expand Down Expand Up @@ -680,9 +701,10 @@ def bounds(self) -> tuple[float, float]:
# Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0
# TODO(necula): add more principled support for floordiv and mod
# For example, this will miss "1 + a - mod(b, a)"
for dec in _decompose_expr(self, _DimAtom.MOD):
# E = factor*mod(op1, op2)^exp * rest_monomial + rest_expr
if dec.exp == 1 and dec.rest_monomial == 1 and dec.rest_expr == - dec.factor * dec.operands[1]:
for dec in _decompose_expr(self, _DimAtom.MOD,
with_exp=1, with_rest_monomial=1):
# E = factor*mod(op1, op2)^1 * 1 + rest_expr
if dec.rest_expr == - dec.factor * dec.operands[1]:
try:
if dec.operands[1] <= 0:
continue
Expand Down Expand Up @@ -729,7 +751,9 @@ def __jax_array__(self):
class _Decomposition:
"""Decomposition of an expression around an operation atom.

E = factor * mod(*operands)^exp * rest_monomial + rest_expr
E.g., for decomposing around "mod":

E = factor * mod(*operands)^exp * rest_monomial + rest_expr
"""
factor: int
operands: Sequence[_DimExpr]
Expand All @@ -738,19 +762,44 @@ class _Decomposition:
rest_expr: _DimExpr


def _decompose_expr(e: _DimExpr, operation: str) -> Iterable[_Decomposition]:
def _decompose_expr(e: _DimExpr, operation: str, *,
with_factor: Optional[int] = None,
with_exp: Optional[int] = None,
with_rest_monomial: Optional[Union[_DimExpr, int]] = None,
with_rest_expr: Optional[Union[_DimExpr, int]] = None,
) -> Iterable[_Decomposition]:
"""Computes the decompositions of `e` into `_Decomposition`.

Args:
e: the expression to decompose
operation: the operation atom around which to decompose
with_factor, with_exp, with_rest_monomial, with_rest_expr: if present,
keep only the decompositions that match.
"""
for m, m_factor in e.monomials():
atoms = [(a, aexp) for a, aexp in m.items() if a.operation == operation]
if atoms:
e_minus_m_coeffs = e._coeffs.copy()
del e_minus_m_coeffs[m]
for a, aexp in atoms:
if with_factor is not None and with_factor != m_factor:
continue
if with_exp is not None and with_exp != aexp:
continue
rest_monomial = _DimExpr({m.divide(_DimMon.from_atom(a, aexp)): 1})
if (with_rest_monomial is not None and
not core.definitely_equal(with_rest_monomial, rest_monomial)):
continue
rest_expr = _DimExpr(e_minus_m_coeffs)
if (with_rest_expr is not None and
not core.definitely_equal(with_rest_expr, rest_expr)):
continue
yield _Decomposition(
factor=m_factor,
operands=a.operands,
exp=aexp,
rest_monomial=_DimExpr({m.divide(_DimMon.from_atom(a, aexp)): 1}),
rest_expr=_DimExpr(e_minus_m_coeffs))
rest_monomial=rest_monomial,
rest_expr=rest_expr)

core.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval
xla.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval
Expand Down
3 changes: 0 additions & 3 deletions tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2353,9 +2353,6 @@ def test_harness(self, harness: PolyHarness):
if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("JAX implements eig only on CPU.")

if harness.group_name == "indexing":
raise unittest.SkipTest("TODO(necula): fix the indexing tests")

prev_jax_config_flags = {
fname: getattr(jax.config, fname)
for fname, fvalue in harness.override_jax_config_flags.items()
Expand Down