From 7462fdfad6e2405bcf701e6cacd30ff3dca138c6 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 21 Oct 2024 12:14:01 +0200 Subject: [PATCH 1/6] Cache unique value of TensorConstants and deprecate `get_unique_constant_value` --- pytensor/scan/rewriting.py | 7 ++--- pytensor/tensor/basic.py | 32 ++++++--------------- pytensor/tensor/rewriting/elemwise.py | 13 +++++---- pytensor/tensor/rewriting/math.py | 13 ++------- pytensor/tensor/variable.py | 41 ++++++++++++++++++++------- tests/tensor/test_basic.py | 5 ++-- 6 files changed, 54 insertions(+), 57 deletions(-) diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index ab4f5b6a77..90a572406a 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -71,7 +71,7 @@ get_slice_elements, set_subtensor, ) -from pytensor.tensor.variable import TensorConstant, get_unique_constant_value +from pytensor.tensor.variable import TensorConstant list_opt_slice = [ @@ -136,10 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): all_ins = list(graph_inputs(op_outs)) for idx in range(op_info.n_seqs): node_inp = node.inputs[idx + 1] - if ( - isinstance(node_inp, TensorConstant) - and get_unique_constant_value(node_inp) is not None - ): + if isinstance(node_inp, TensorConstant) and node_inp.unique_value is not None: try: # This works if input is a constant that has all entries # equal diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 7d5236d04a..1aced5a338 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -19,7 +19,7 @@ import pytensor import pytensor.scalar.sharedvar -from pytensor import compile, config, printing +from pytensor import config, printing from pytensor import scalar as ps from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType, grad_undefined @@ -74,7 +74,6 @@ from pytensor.tensor.variable import ( TensorConstant, TensorVariable, - get_unique_constant_value, ) @@ -319,6 +318,8 @@ def get_underlying_scalar_constant_value( but I'm not sure where it is. 
""" + from pytensor.compile.ops import DeepCopyOp, OutputGuard + v = orig_v while True: if v is None: @@ -336,34 +337,19 @@ def get_underlying_scalar_constant_value( raise NotScalarConstantError() if isinstance(v, Constant): - unique_value = get_unique_constant_value(v) - if unique_value is not None: - data = unique_value - else: - data = v.data - - if isinstance(data, np.ndarray): - try: - return np.array(data.item(), dtype=v.dtype) - except ValueError: - raise NotScalarConstantError() + if isinstance(v, TensorConstant) and v.unique_value is not None: + return v.unique_value - from pytensor.sparse.type import SparseTensorType + elif isinstance(v, ScalarConstant): + return v.data - if isinstance(v.type, SparseTensorType): - raise NotScalarConstantError() - - return data + raise NotScalarConstantError() if not only_process_constants and getattr(v, "owner", None) and max_recur > 0: max_recur -= 1 if isinstance( v.owner.op, - Alloc - | DimShuffle - | Unbroadcast - | compile.ops.OutputGuard - | compile.DeepCopyOp, + Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp, ): # OutputGuard is only used in debugmode but we # keep it here to avoid problems with old pickles diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 277b8bdb55..8c91ac6e60 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -41,7 +41,7 @@ register_specialize, ) from pytensor.tensor.shape import shape_padleft -from pytensor.tensor.variable import TensorConstant, get_unique_constant_value +from pytensor.tensor.variable import TensorConstant class InplaceElemwiseOptimizer(GraphRewriter): @@ -513,7 +513,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): new_inputs.append(i) else: try: - # works only for scalars cval_i = get_underlying_scalar_constant_value( i, only_process_constants=True ) @@ -1216,11 +1215,13 @@ def local_inline_composite_constants(fgraph, node): inner_replacements = {} for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs): # Complex variables don't have a `c_literal` that can be inlined - if "complex" not in outer_inp.type.dtype: - unique_value = get_unique_constant_value(outer_inp) - if unique_value is not None: + if ( + isinstance(outer_inp, TensorConstant) + and "complex" not in outer_inp.type.dtype + ): + if outer_inp.unique_value is not None: inner_replacements[inner_inp] = ps.constant( - unique_value, dtype=inner_inp.dtype + outer_inp.unique_value, dtype=inner_inp.dtype ) continue new_outer_inputs.append(outer_inp) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 2e30e1399b..37a6f1424f 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -105,7 +105,6 @@ from pytensor.tensor.variable import ( TensorConstant, TensorVariable, - get_unique_constant_value, ) @@ -137,16 +136,8 @@ def get_constant(v): numeric constant. If v is a plain Variable, returns None. 
""" - if isinstance(v, Constant): - unique_value = get_unique_constant_value(v) - if unique_value is not None: - data = unique_value - else: - data = v.data - if data.ndim == 0: - return data - else: - return None + if isinstance(v, TensorConstant): + return v.unique_value elif isinstance(v, Variable): return None else: diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 261a8bbc4a..3f729e141e 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -11,7 +11,10 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Constant, OptionalApplyType, Variable from pytensor.graph.utils import MetaType -from pytensor.scalar import ComplexError, IntegerDivisionError +from pytensor.scalar import ( + ComplexError, + IntegerDivisionError, +) from pytensor.tensor import _get_vector_length from pytensor.tensor.exceptions import AdvancedIndexingError from pytensor.tensor.type import TensorType @@ -1042,15 +1045,9 @@ def no_nan(self): def get_unique_constant_value(x: TensorVariable) -> Number | None: """Return the unique value of a tensor, if there is one""" - if isinstance(x, Constant): - data = x.data - - if isinstance(data, np.ndarray) and data.ndim > 0: - flat_data = data.ravel() - if flat_data.shape[0]: - if (flat_data == flat_data[0]).all(): - return flat_data[0] - + warnings.warn("get_unique_constant_value is deprecated.", FutureWarning) + if isinstance(x, TensorConstant): + return x.unique_value return None @@ -1077,6 +1074,30 @@ def __init__(self, type: _TensorTypeType, data, name=None): def signature(self): return TensorConstantSignature((self.type, self.data)) + @property + def unique_value(self) -> Number | None: + """Return the unique value of a tensor, if there is one""" + try: + return self._unique_value + except AttributeError: + data = self.data + if np.ndim(data) == 0: + unique_value = data + else: + flat_data = data.ravel() + if (flat_data == flat_data[0]).all(): + unique_value = flat_data[0] + else: + unique_value = None + + if unique_value is not None: + # Don't allow the unique value to be changed + unique_value.setflags(write=False) + + self._unique_value = unique_value + + return self._unique_value + def equals(self, other): # Override Constant.equals to allow to compare with # numpy.ndarray, and python type. diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 323d401f42..72763d05db 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3564,11 +3564,12 @@ def test_second(self): assert get_underlying_scalar_constant_value(s) == c.data def test_copy(self): - # Make sure we do not return the internal storage of a constant, + # Make sure we do not return a writeable internal storage of a constant, # so we cannot change the value of a constant by mistake. 
c = constant(3) d = extract_constant(c) - d += 1 + with pytest.raises(ValueError, match="output array is read-only"): + d += 1 e = extract_constant(c) assert e == 3, (c, d, e) From 029fc881517905237f5c30b9effb02ba3a920bd7 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 21 Oct 2024 12:16:06 +0200 Subject: [PATCH 2/6] Deprecate `pytensor.get_underlying_scalar_constant` --- pytensor/__init__.py | 11 +++++------ pytensor/gradient.py | 4 ++-- pytensor/tensor/basic.py | 31 +++++++++++++++++++------------ tests/tensor/test_elemwise.py | 4 ++-- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/pytensor/__init__.py b/pytensor/__init__.py index dd6117c527..720fa7d741 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -24,6 +24,7 @@ # pytensor code, since this code may want to log some messages. import logging import sys +import warnings from functools import singledispatch from pathlib import Path from typing import Any, NoReturn, Optional @@ -148,12 +149,10 @@ def get_underlying_scalar_constant(v): If `v` is not some view of constant data, then raise a `NotScalarConstantError`. """ - # Is it necessary to test for presence of pytensor.sparse at runtime? - sparse = globals().get("sparse") - if sparse and isinstance(v.type, sparse.SparseTensorType): - if v.owner is not None and isinstance(v.owner.op, sparse.CSM): - data = v.owner.inputs[0] - return tensor.get_underlying_scalar_constant_value(data) + warnings.warn( + "get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead.", + DeprecationWarning, + ) return tensor.get_underlying_scalar_constant_value(v) diff --git a/pytensor/gradient.py b/pytensor/gradient.py index f9c393b512..446bd9de96 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -1330,7 +1330,7 @@ def try_to_copy_if_needed(var): f" {i}. Since this input is only connected " "to integer-valued outputs, it should " "evaluate to zeros, but it evaluates to" - f"{pytensor.get_underlying_scalar_constant(term)}." + f"{pytensor.get_underlying_scalar_constant_value(term)}." ) raise ValueError(msg) @@ -2172,7 +2172,7 @@ def _is_zero(x): no_constant_value = True try: - constant_value = pytensor.get_underlying_scalar_constant(x) + constant_value = pytensor.get_underlying_scalar_constant_value(x) no_constant_value = False except pytensor.tensor.exceptions.NotScalarConstantError: pass diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 1aced5a338..3959094756 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -319,6 +319,8 @@ def get_underlying_scalar_constant_value( """ from pytensor.compile.ops import DeepCopyOp, OutputGuard + from pytensor.sparse import CSM + from pytensor.tensor.subtensor import Subtensor v = orig_v while True: @@ -346,16 +348,16 @@ def get_underlying_scalar_constant_value( raise NotScalarConstantError() if not only_process_constants and getattr(v, "owner", None) and max_recur > 0: + op = v.owner.op max_recur -= 1 if isinstance( - v.owner.op, - Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp, + op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp ): # OutputGuard is only used in debugmode but we # keep it here to avoid problems with old pickles v = v.owner.inputs[0] continue - elif isinstance(v.owner.op, Shape_i): + elif isinstance(op, Shape_i): i = v.owner.op.i inp = v.owner.inputs[0] if isinstance(inp, Constant): @@ -369,10 +371,10 @@ def get_underlying_scalar_constant_value( # mess with the stabilization optimization and be too slow. 
# We put all the scalar Ops used by get_canonical_form_slice() # to allow it to determine the broadcast pattern correctly. - elif isinstance(v.owner.op, ScalarFromTensor | TensorFromScalar): + elif isinstance(op, ScalarFromTensor | TensorFromScalar): v = v.owner.inputs[0] continue - elif isinstance(v.owner.op, CheckAndRaise): + elif isinstance(op, CheckAndRaise): # check if all conditions are constant and true conds = [ get_underlying_scalar_constant_value(c, max_recur=max_recur) @@ -381,7 +383,7 @@ def get_underlying_scalar_constant_value( if builtins.all(0 == c.ndim and c != 0 for c in conds): v = v.owner.inputs[0] continue - elif isinstance(v.owner.op, ps.ScalarOp): + elif isinstance(op, ps.ScalarOp): if isinstance(v.owner.op, ps.Second): # We don't need both input to be constant for second shp, val = v.owner.inputs @@ -398,7 +400,7 @@ def get_underlying_scalar_constant_value( # In fast_compile, we don't enable local_fill_to_alloc, so # we need to investigate Second as Alloc. So elemwise # don't disable the check for Second. - elif isinstance(v.owner.op, Elemwise): + elif isinstance(op, Elemwise): if isinstance(v.owner.op.scalar_op, ps.Second): # We don't need both input to be constant for second shp, val = v.owner.inputs @@ -414,10 +416,7 @@ def get_underlying_scalar_constant_value( ret = [[None]] v.owner.op.perform(v.owner, const, ret) return np.asarray(ret[0][0].copy()) - elif ( - isinstance(v.owner.op, pytensor.tensor.subtensor.Subtensor) - and v.ndim == 0 - ): + elif isinstance(op, Subtensor) and v.ndim == 0: if isinstance(v.owner.inputs[0], TensorConstant): from pytensor.tensor.subtensor import get_constant_idx @@ -541,6 +540,14 @@ def get_underlying_scalar_constant_value( if isinstance(grandparent, Constant): return np.asarray(np.shape(grandparent.data)[idx]) + elif isinstance(op, CSM): + data = get_underlying_scalar_constant_value( + v.owner.inputs, elemwise=elemwise, max_recur=max_recur + ) + # Sparse variable can only be constant if zero (or I guess if homogeneously dense) + if data == 0: + return data + break raise NotScalarConstantError() @@ -4064,7 +4071,7 @@ def make_node(self, a, choices): static_out_shape = () for s in out_shape: try: - s_val = pytensor.get_underlying_scalar_constant(s) + s_val = get_underlying_scalar_constant_value(s) except (NotScalarConstantError, AttributeError): s_val = None diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 76906232af..54246918d4 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -804,8 +804,8 @@ def test_partial_static_shape_info(self): assert len(res_shape) == 1 assert len(res_shape[0]) == 2 - assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1 - assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1 + assert pytensor.get_underlying_scalar_constant_value(res_shape[0][0]) == 1 + assert pytensor.get_underlying_scalar_constant_value(res_shape[0][1]) == 1 def test_infer_shape_multi_output(self): class CustomElemwise(Elemwise): From 40e001189c66c156ed4b6093c3c24bf60ae4a29c Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 21 Oct 2024 17:19:36 +0200 Subject: [PATCH 3/6] Deprecate `extract_constant` --- pytensor/scan/rewriting.py | 30 +++- pytensor/tensor/basic.py | 154 +++++++++++++------ pytensor/tensor/rewriting/basic.py | 202 +++++++++++++------------ pytensor/tensor/rewriting/math.py | 193 +++++++++++++---------- pytensor/tensor/rewriting/shape.py | 44 ++++-- pytensor/tensor/rewriting/subtensor.py | 27 +++- 
tests/tensor/rewriting/test_math.py | 4 +- tests/tensor/test_basic.py | 5 +- 8 files changed, 406 insertions(+), 253 deletions(-) diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 90a572406a..7ee512a549 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -54,6 +54,7 @@ from pytensor.tensor.basic import ( Alloc, AllocEmpty, + get_scalar_constant_value, get_underlying_scalar_constant_value, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise @@ -665,8 +666,10 @@ def inner_sitsot_only_last_step_used( client = fgraph.clients[outer_var][0][0] if isinstance(client, Apply) and isinstance(client.op, Subtensor): lst = get_idx_list(client.inputs, client.op.idx_list) - if len(lst) == 1 and pt.extract_constant(lst[0]) == -1: - return True + return ( + len(lst) == 1 + and get_scalar_constant_value(idx[0], raise_not_constant=False) == -1 + ) return False @@ -1341,10 +1344,17 @@ def scan_save_mem(fgraph, node): if isinstance(this_slice[0], slice) and this_slice[0].stop is None: global_nsteps = None if isinstance(cf_slice[0], slice): - stop = pt.extract_constant(cf_slice[0].stop) + stop = get_scalar_constant_value( + cf_slice[0].stop, raise_not_constant=False + ) else: - stop = pt.extract_constant(cf_slice[0]) + 1 - if stop == maxsize or stop == pt.extract_constant(length): + stop = ( + get_scalar_constant_value(cf_slice[0], raise_not_constant=False) + + 1 + ) + if stop == maxsize or stop == get_scalar_constant_value( + length, raise_not_constant=False + ): stop = None else: # there is a **gotcha** here ! Namely, scan returns an @@ -1448,9 +1458,13 @@ def scan_save_mem(fgraph, node): cf_slice = get_canonical_form_slice(this_slice[0], length) if isinstance(cf_slice[0], slice): - start = pt.extract_constant(cf_slice[0].start) + start = pt.get_scalar_constant_value( + cf_slice[0].start, raise_not_constant=False + ) else: - start = pt.extract_constant(cf_slice[0]) + start = pt.get_scalar_constant_value( + cf_slice[0], raise_not_constant=False + ) if start == 0 or store_steps[i] == 0: store_steps[i] = 0 @@ -1625,7 +1639,7 @@ def scan_save_mem(fgraph, node): # 3.6 Compose the new scan # TODO: currently we don't support scan with 0 step. So # don't create one. - if pt.extract_constant(node_ins[0]) == 0: + if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0: return False # Do not call make_node for test_value diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 3959094756..93b59df319 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -267,27 +267,7 @@ def _obj_is_wrappable_as_tensor(x): ) -def get_scalar_constant_value( - v, elemwise=True, only_process_constants=False, max_recur=10 -): - """ - Checks whether 'v' is a scalar (ndim = 0). - - If 'v' is a scalar then this function fetches the underlying constant by calling - 'get_underlying_scalar_constant_value()'. - - If 'v' is not a scalar, it raises a NotScalarConstantError. - - """ - if isinstance(v, Variable | np.ndarray): - if v.ndim != 0: - raise NotScalarConstantError() - return get_underlying_scalar_constant_value( - v, elemwise, only_process_constants, max_recur - ) - - -def get_underlying_scalar_constant_value( +def _get_underlying_scalar_constant_value( orig_v, elemwise=True, only_process_constants=False, max_recur=10 ): """Return the constant scalar(0-D) value underlying variable `v`. 
@@ -377,7 +357,7 @@ def get_underlying_scalar_constant_value( elif isinstance(op, CheckAndRaise): # check if all conditions are constant and true conds = [ - get_underlying_scalar_constant_value(c, max_recur=max_recur) + _get_underlying_scalar_constant_value(c, max_recur=max_recur) for c in v.owner.inputs[1:] ] if builtins.all(0 == c.ndim and c != 0 for c in conds): @@ -391,7 +371,7 @@ def get_underlying_scalar_constant_value( continue if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops): const = [ - get_underlying_scalar_constant_value(i, max_recur=max_recur) + _get_underlying_scalar_constant_value(i, max_recur=max_recur) for i in v.owner.inputs ] ret = [[None]] @@ -410,7 +390,7 @@ def get_underlying_scalar_constant_value( v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops ): const = [ - get_underlying_scalar_constant_value(i, max_recur=max_recur) + _get_underlying_scalar_constant_value(i, max_recur=max_recur) for i in v.owner.inputs ] ret = [[None]] @@ -453,7 +433,7 @@ def get_underlying_scalar_constant_value( ): idx = v.owner.op.idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant_value( + idx = _get_underlying_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) try: @@ -487,14 +467,13 @@ def get_underlying_scalar_constant_value( ): idx = v.owner.op.idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant_value( + idx = _get_underlying_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) - # Python 2.4 does not support indexing with numpy.integer - # So we cast it. - idx = int(idx) ret = v.owner.inputs[0].owner.inputs[idx] - ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur) + ret = _get_underlying_scalar_constant_value( + ret, max_recur=max_recur + ) # MakeVector can cast implicitly its input in some case. return np.asarray(ret, dtype=v.type.dtype) @@ -509,7 +488,7 @@ def get_underlying_scalar_constant_value( idx_list = op.idx_list idx = idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant_value( + idx = _get_underlying_scalar_constant_value( owner.inputs[1], max_recur=max_recur ) grandparent = leftmost_parent.owner.inputs[0] @@ -519,7 +498,9 @@ def get_underlying_scalar_constant_value( grandparent.owner.op, Unbroadcast ): ggp_shape = grandparent.owner.inputs[0].type.shape - l = [get_underlying_scalar_constant_value(s) for s in ggp_shape] + l = [ + _get_underlying_scalar_constant_value(s) for s in ggp_shape + ] gp_shape = tuple(l) if not (idx < ndim): @@ -541,7 +522,7 @@ def get_underlying_scalar_constant_value( if isinstance(grandparent, Constant): return np.asarray(np.shape(grandparent.data)[idx]) elif isinstance(op, CSM): - data = get_underlying_scalar_constant_value( + data = _get_underlying_scalar_constant_value( v.owner.inputs, elemwise=elemwise, max_recur=max_recur ) # Sparse variable can only be constant if zero (or I guess if homogeneously dense) @@ -552,6 +533,92 @@ def get_underlying_scalar_constant_value( raise NotScalarConstantError() +def get_underlying_scalar_constant_value( + v, + *, + elemwise=True, + only_process_constants=False, + max_recur=10, + raise_not_constant=True, +): + """Return the unique constant scalar(0-D) value underlying variable `v`. + + If `v` is the output of dimshuffles, fills, allocs, etc, + cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise + and some pattern with Subtensor, this function digs through them. + + If `v` is not some view of constant scalar data, then raise a + NotScalarConstantError. 
+ + This function performs symbolic reasoning about the value of `v`, as opposed to numerical reasoning by + constant folding the inputs of `v`. + + Parameters + ---------- + v: Variable + elemwise : bool + If False, we won't try to go into elemwise. So this call is faster. + But we still investigate in Second Elemwise (as this is a substitute + for Alloc) + only_process_constants : bool + If True, we only attempt to obtain the value of `orig_v` if it's + directly constant and don't try to dig through dimshuffles, fills, + allocs, and other to figure out its value. + max_recur : int + The maximum number of recursion. + raise_not_constant: bool, default True + If True, raise a NotScalarConstantError if `v` does not have an + underlying constant scalar value. If False, return `v` as is. + + + Raises + ------ + NotScalarConstantError + `v` does not have an underlying constant scalar value. + Only rasise if raise_not_constant is True. + + """ + try: + return _get_underlying_scalar_constant_value( + v, + elemwise=elemwise, + only_process_constants=only_process_constants, + max_recur=max_recur, + ) + except NotScalarConstantError: + if raise_not_constant: + raise + return v + + +def get_scalar_constant_value( + v, + elemwise=True, + only_process_constants=False, + max_recur=10, + raise_not_constant: bool = True, +): + """ + Checks whether 'v' is a scalar (ndim = 0). + + If 'v' is a scalar then this function fetches the underlying constant by calling + 'get_underlying_scalar_constant_value()'. + + If 'v' is not a scalar, it raises a TypeError. + + """ + if isinstance(v, Variable | np.ndarray): + if v.ndim != 0: + raise TypeError() + return get_underlying_scalar_constant_value( + v, + elemwise=elemwise, + only_process_constants=only_process_constants, + max_recur=max_recur, + raise_not_constant=raise_not_constant, + ) + + class TensorFromScalar(COp): __props__ = () @@ -2006,16 +2073,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False): ScalarVariable, we convert it to a tensor with tensor_from_scalar. """ - try: - x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants) - except NotScalarConstantError: - pass - if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable): - if x.owner and isinstance(x.owner.op, ScalarFromTensor): - x = x.owner.inputs[0] - else: - x = tensor_from_scalar(x) - return x + warnings.warn( + "extract_constant is deprecated. 
Use `get_underlying_scalar_constant_value(..., raise_not_constant=False)`", + FutureWarning, + ) + return get_underlying_scalar_constant_value( + x, + elemwise=elemwise, + only_process_constants=only_process_constants, + raise_not_constant=False, + ) def transpose(x, axes=None): @@ -4394,7 +4461,6 @@ def ix_(*args): "split", "transpose", "matrix_transpose", - "extract_constant", "default", "tensor_copy", "transfer", diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 78d00790ac..d9b82b0ad5 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -30,7 +30,7 @@ from pytensor import compile, config from pytensor.compile.ops import ViewOp from pytensor.graph import FunctionGraph -from pytensor.graph.basic import Constant, Variable +from pytensor.graph.basic import Constant from pytensor.graph.rewriting.basic import ( NodeRewriter, RemovalNodeRewriter, @@ -54,8 +54,8 @@ as_tensor_variable, atleast_Nd, cast, - extract_constant, fill, + get_scalar_constant_value, get_underlying_scalar_constant_value, join, ones_like, @@ -477,7 +477,12 @@ def local_alloc_sink_dimshuffle(fgraph, node): output_shape = node.inputs[1:] num_dims_with_size_1_added_to_left = 0 for i in range(len(output_shape) - inp.ndim): - if extract_constant(output_shape[i], only_process_constants=True) == 1: + if ( + get_scalar_constant_value( + output_shape[i], only_process_constants=True, raise_not_constant=False + ) + == 1 + ): num_dims_with_size_1_added_to_left += 1 else: break @@ -537,93 +542,90 @@ def local_useless_elemwise(fgraph, node): xor(x, x) -> zeros_like(x) TODO: This implementation is painfully redundant. + TODO: Allow rewrite when useless input broadcasts output """ - if isinstance(node.op, Elemwise): - # We call zeros_like and one_like with opt=True to generate a - # cleaner graph. - dtype = node.outputs[0].dtype - - if node.op.scalar_op == ps.eq and len(node.inputs) == 2: - if node.inputs[0] == node.inputs[1]: - # it is the same var in the graph. That will always be true - ret = ones_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy stack trace from input to constant output - copy_stack_trace(node.outputs[0], ret) - return [ret] - elif node.op.scalar_op == ps.neq and len(node.inputs) == 2: - if node.inputs[0] == node.inputs[1]: - # it is the same var in the graph. 
That will always be false - ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy stack trace from input to constant output - copy_stack_trace(node.outputs[0], ret) - return [ret] - - elif node.op.scalar_op == ps.mul and len(node.inputs) == 1: - # No need to copy over any stack trace - return [node.inputs[0]] - - elif node.op.scalar_op == ps.add and len(node.inputs) == 1: - # No need to copy over any stack trace - return [node.inputs[0]] - elif node.op.scalar_op == ps.identity and len(node.inputs) == 1: - return [node.inputs[0]] - - elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2: - if isinstance(node.inputs[0], TensorConstant): - const_val = extract_constant( - node.inputs[0], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise AND, - # and this rewrite would be wrong - return [node.inputs[1].astype(node.outputs[0].dtype)] - - if isinstance(node.inputs[1], TensorConstant): - const_val = extract_constant( - node.inputs[1], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise AND, - # and this rewrite would be wrong - return [node.inputs[0].astype(node.outputs[0].dtype)] - - elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2: - if isinstance(node.inputs[0], TensorConstant): - const_val = extract_constant( - node.inputs[0], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [node.inputs[1].astype(node.outputs[0].dtype)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise OR, - # and this rewrite would be wrong - return [ones_like(node.inputs[1], dtype=dtype, opt=True)] - - if isinstance(node.inputs[1], TensorConstant): - const_val = extract_constant( - node.inputs[1], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [node.inputs[0].astype(node.outputs[0].dtype)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise OR, - # and this rewrite would be wrong - return [ones_like(node.inputs[0], dtype=dtype, opt=True)] - - elif isinstance(node.op.scalar_op, ps.XOR) and len(node.inputs) == 2: - if node.inputs[0] is node.inputs[1]: - return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] + out_bcast = node.outputs[0].type.broadcastable + dtype = node.outputs[0].type.dtype + scalar_op = node.op.scalar_op + + if isinstance(scalar_op, ps.EQ) and len(node.inputs) == 2: + if node.inputs[0] is node.inputs[1]: + # it is the same var in the graph. That will always be true + ret = ones_like(node.inputs[0], dtype=dtype, opt=True) + + # Copy stack trace from input to constant output + copy_stack_trace(node.outputs[0], ret) + return [ret] + elif isinstance(scalar_op, ps.NEQ | ps.XOR) and len(node.inputs) == 2: + if node.inputs[0] is node.inputs[1]: + # it is the same var in the graph. 
That will always be false + ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) + + # Copy stack trace from input to constant output + copy_stack_trace(node.outputs[0], ret) + return [ret] + + elif ( + isinstance(node.op.scalar_op, ps.Mul | ps.Add | ps.Identity) + and len(node.inputs) == 1 + ): + # No need to copy over any stack trace + return [node.inputs[0]] + + elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2: + if ( + isinstance(node.inputs[0], TensorConstant) + and node.inputs[1].type.broadcastable == out_bcast + ): + const_val = node.inputs[0].unique_value + if const_val is not None: + if const_val == 0: + return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise AND, + # and this rewrite would be wrong + return [node.inputs[1].astype(node.outputs[0].dtype)] + + if ( + isinstance(node.inputs[1], TensorConstant) + and node.inputs[0].type.broadcastable == out_bcast + ): + const_val = node.inputs[1].unique_value + if const_val is not None: + if const_val == 0: + return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise AND, + # and this rewrite would be wrong + return [node.inputs[0].astype(node.outputs[0].dtype)] + + elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2: + if ( + isinstance(node.inputs[0], TensorConstant) + and node.inputs[1].type.broadcastable == out_bcast + ): + const_val = node.inputs[0].unique_value + if const_val is not None: + if const_val == 0: + return [node.inputs[1].astype(node.outputs[0].dtype)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise OR, + # and this rewrite would be wrong + return [ones_like(node.inputs[1], dtype=dtype, opt=True)] + + if ( + isinstance(node.inputs[1], TensorConstant) + and node.inputs[0].type.broadcastable == out_bcast + ): + const_val = node.inputs[1].unique_value + if const_val is not None: + if const_val == 0: + return [node.inputs[0].astype(node.outputs[0].dtype)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise OR, + # and this rewrite would be wrong + return [ones_like(node.inputs[0], dtype=dtype, opt=True)] @register_specialize @@ -987,13 +989,10 @@ def local_useless_switch(fgraph, node): left = node.inputs[1] right = node.inputs[2] cond_var = node.inputs[0] - cond = extract_constant(cond_var, only_process_constants=True) out_bcast = node.outputs[0].type.broadcastable - if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance( - cond, np.number | np.bool_ - ): - if cond == 0: + if isinstance(cond_var, TensorConstant) and cond_var.unique_value is not None: + if cond_var.unique_value == 0: correct_out = right else: correct_out = left @@ -1013,7 +1012,7 @@ def local_useless_switch(fgraph, node): # if left is right -> left if equivalent_up_to_constant_casting(left, right): if left.type.broadcastable != out_bcast: - left, _ = broadcast_arrays(left, cond) + left, _ = broadcast_arrays(left, cond_var) out_dtype = node.outputs[0].type.dtype if left.type.dtype != out_dtype: @@ -1025,13 +1024,22 @@ def local_useless_switch(fgraph, node): # This case happens with scan. 
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) if ( - cond_var.owner + node.outputs[0].type.ndim == 0 + and cond_var.owner and isinstance(cond_var.owner.op, Elemwise) and isinstance(cond_var.owner.op.scalar_op, ps.LE) and cond_var.owner.inputs[0].owner and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) - and extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0 - and extract_constant(left, only_process_constants=True) == 0 + and get_scalar_constant_value( + cond_var.owner.inputs[1], + only_process_constants=True, + raise_not_constant=False, + ) + == 0 + and get_scalar_constant_value( + left, only_process_constants=True, raise_not_constant=False + ) + == 0 and right == cond_var.owner.inputs[0] ): assert node.outputs[0].type.is_super(right.type) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 37a6f1424f..572b105680 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -28,7 +28,6 @@ as_tensor_variable, cast, constant, - extract_constant, get_underlying_scalar_constant_value, moveaxis, ones_like, @@ -565,11 +564,14 @@ def local_expm1(fgraph, node): in1.owner and isinstance(in1.owner.op, Elemwise) and isinstance(in1.owner.op.scalar_op, ps.Exp) - and extract_constant(in2, only_process_constants=False) == 1 + and get_underlying_scalar_constant_value(in2, raise_not_constant=False) == 1 ): in11 = in1.owner.inputs[0] new_out = expm1(in11) + if new_out.type.broadcastable != out.type.broadcastable: + new_out = broadcast_arrays(in11, in2)[0] + if new_out.dtype != out.dtype: new_out = cast(new_out, dtype=out.dtype) @@ -1370,12 +1372,13 @@ def local_useless_elemwise_comparison(fgraph, node): the graph easier to read. """ + # TODO: Refactor this function. So much repeated code! + if node.op.scalar_op.nin != 2: return - # We call zeros_like and one_like with opt=True to generate a - # cleaner graph. - dtype = node.outputs[0].dtype + dtype = node.outputs[0].type.dtype + out_bcast = node.outputs[0].type.broadcastable # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) if ( @@ -1386,6 +1389,7 @@ def local_useless_elemwise_comparison(fgraph, node): # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) if ( isinstance(node.op.scalar_op, ps.LE | ps.GE) @@ -1396,6 +1400,7 @@ def local_useless_elemwise_comparison(fgraph, node): # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[{minimum,maximum}](X, X) -> X if ( isinstance(node.op.scalar_op, ps.ScalarMinimum | ps.ScalarMaximum) @@ -1411,64 +1416,72 @@ def local_useless_elemwise_comparison(fgraph, node): isinstance(node.op.scalar_op, ps.LT) and node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. 
copy_stack_trace(node.outputs, res) return [res] + # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) if ( isinstance(node.op.scalar_op, ps.GE) and node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = ones_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[maximum](X.shape[i], 0) -> X.shape[i] - if ( - isinstance(node.op.scalar_op, ps.ScalarMaximum) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - # No need to copy over stacktrace. - return [node.inputs[0]] - # Elemwise[maximum](0, X.shape[i]) -> X.shape[i] - if ( - isinstance(node.op.scalar_op, ps.ScalarMaximum) - and extract_constant(node.inputs[0], only_process_constants=True) == 0 - and node.inputs[1].owner - and isinstance(node.inputs[1].owner.op, Shape_i) - ): - # No need to copy over stacktrace. - return [node.inputs[1]] - # Elemwise[minimum](X.shape[i], 0) -> 0 - if ( - isinstance(node.op.scalar_op, ps.ScalarMinimum) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] + if isinstance(node.op.scalar_op, ps.ScalarMaximum): + for idx in range(2): + if ( + node.inputs[idx].owner + and isinstance(node.inputs[idx].owner.op, Shape_i) + and get_underlying_scalar_constant_value( + node.inputs[1 - idx], + only_process_constants=True, + raise_not_constant=False, + ) + == 0 + ): + res = node.inputs[idx] + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1 - idx])[0] + # No need to copy over stacktrace. + return [res] - # Elemwise[minimum](0, X.shape[i]) -> 0 - if ( - isinstance(node.op.scalar_op, ps.ScalarMinimum) - and extract_constant(node.inputs[0], only_process_constants=True) == 0 - and node.inputs[1].owner - and isinstance(node.inputs[1].owner.op, Shape_i) - ): - res = zeros_like(node.inputs[1], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] + # Elemwise[minimum](X.shape[i], 0) -> 0 + if isinstance(node.op.scalar_op, ps.ScalarMinimum): + for idx in range(2): + if ( + node.inputs[idx].owner + and isinstance(node.inputs[idx].owner.op, Shape_i) + and get_underlying_scalar_constant_value( + node.inputs[1 - idx], + only_process_constants=True, + raise_not_constant=False, + ) + == 0 + ): + res = zeros_like(node.inputs[idx], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1 - idx])[0] + # No need to copy over stacktrace. 
+ return [res] # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) if ( @@ -1480,12 +1493,18 @@ def local_useless_elemwise_comparison(fgraph, node): isinstance(var.owner and var.owner.op, Shape_i) for var in node.inputs[0].owner.inputs ) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] + # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) if ( isinstance(node.op.scalar_op, ps.GE) @@ -1496,57 +1515,60 @@ def local_useless_elemwise_comparison(fgraph, node): isinstance(var.owner and var.owner.op, Shape_i) for var in node.inputs[0].owner.inputs ) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 + and get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) + == 0 ): res = ones_like(node.inputs[0], dtype=dtype, opt=True) - + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] # Copy over stacktrace from previous output. copy_stack_trace(node.outputs, res) return [res] - # Elemwise[EQ](Subtensor(Shape(x)), -N) - # Elemwise[EQ](somegraph that only depend of shape, -N) - # TODO: handle the case where the -N is on either side - """ - |Elemwise{eq,no_inplace} [id B] '' - | |Subtensor{int64} [id C] '' - | | |Join [id D] '' - | | | |TensorConstant{0} [id E] - | | | |Subtensor{int64:int64:} [id F] '' - | | | | |Shape [id G] '' - """ + # Elemwise[EQ](Subtensor(Shape(x)), -N) + # Elemwise[EQ](somegraph that only depend of shape, -N) + # TODO: handle the case where the -N is on either side + """ +|Elemwise{eq,no_inplace} [id B] '' +| |Subtensor{int64} [id C] '' +| | |Join [id D] '' +| | | |TensorConstant{0} [id E] +| | | |Subtensor{int64:int64:} [id F] '' +| | | | |Shape [id G] '' + """ - def investigate(node): + def investigate_if_shape(node) -> bool: "Return True if values will be shapes, so >= 0" if isinstance(node.op, Shape | Shape_i): return True elif isinstance(node.op, Subtensor) and node.inputs[0].owner: - return investigate(node.inputs[0].owner) + return investigate_if_shape(node.inputs[0].owner) elif isinstance(node.op, Join): - return all(v.owner and investigate(v.owner) for v in node.inputs[1:]) + return all( + v.owner and investigate_if_shape(v.owner) for v in node.inputs[1:] + ) elif isinstance(node.op, MakeVector): - return all(v.owner and investigate(v.owner) for v in node.inputs) + return all(v.owner and investigate_if_shape(v.owner) for v in node.inputs) if ( isinstance(node.op.scalar_op, ps.EQ) and node.inputs[0].owner - and investigate(node.inputs[0].owner) + and investigate_if_shape(node.inputs[0].owner) + and ( + isinstance(node.inputs[1], TensorConstant) + and node.inputs[1].unique_value is not None + and node.inputs[1].unique_value < 0 + ) ): - try: - cst = get_underlying_scalar_constant_value( - node.inputs[1], only_process_constants=True - ) - - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - - if cst < 0: - # Copy over stacktrace from previous output. 
- copy_stack_trace(node.outputs, res) - - return [res] + res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + if res.type.broadcastable != out_bcast: + res = broadcast_arrays(res, node.inputs[1])[0] + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] - except NotScalarConstantError: - pass return @@ -2248,12 +2270,19 @@ def local_log1p(fgraph, node): return [alloc_like(log1p(ninp), node.outputs[0], fgraph)] elif log_arg.owner and log_arg.owner.op == sub: - one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) + one, other = node.inputs + one = get_underlying_scalar_constant_value( + one, only_process_constants=True, raise_not_constant=False + ) if one != 1: return - other = log_arg.owner.inputs[1] - if other.dtype != log_arg.dtype: + + if other.type.broadcastable != log_arg.type.broadcastable: + other = broadcast_arrays(other, one)[0] + + if other.type.dtype != log_arg.type.dtype: other = other.astype(log_arg.dtype) + return [log1p(neg(other))] diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 91c731a4ff..8b1e483ee4 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -22,7 +22,7 @@ as_tensor_variable, cast, constant, - extract_constant, + get_scalar_constant_value, get_underlying_scalar_constant_value, register_infer_shape, stack, @@ -354,7 +354,9 @@ def set_shape(self, r, s, override=False): not hasattr(r.type, "shape") or r.type.shape[i] != 1 or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals(extract_constant(shape_vars[i])) + or self.lscalar_one.equals( + get_scalar_constant_value(shape_vars[i], raise_not_constant=False) + ) for i in range(r.type.ndim) ) self.shape_of[r] = tuple(shape_vars) @@ -450,7 +452,11 @@ def update_shape(self, r, other_r): ) or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals( - extract_constant(merged_shape[i], only_process_constants=True) + get_underlying_scalar_constant_value( + merged_shape[i], + only_process_constants=True, + raise_not_constant=False, + ) ) for i in range(r.type.ndim) ) @@ -474,7 +480,11 @@ def set_shape_i(self, r, i, s_i): not hasattr(r.type, "shape") or r.type.shape[idx] != 1 or self.lscalar_one.equals(new_shape[idx]) - or self.lscalar_one.equals(extract_constant(new_shape[idx])) + or self.lscalar_one.equals( + get_underlying_scalar_constant_value( + new_shape[idx], raise_not_constant=False + ) + ) for idx in range(r.type.ndim) ) self.shape_of[r] = tuple(new_shape) @@ -847,7 +857,10 @@ def local_useless_reshape(fgraph, node): outshp_i.owner and isinstance(outshp_i.owner.op, Subtensor) and len(outshp_i.owner.inputs) == 2 - and extract_constant(outshp_i.owner.inputs[1]) == dim + and get_scalar_constant_value( + outshp_i.owner.inputs[1], raise_not_constant=False + ) + == dim ): subtensor_inp = outshp_i.owner.inputs[0] if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): @@ -857,7 +870,9 @@ def local_useless_reshape(fgraph, node): continue # Match constant if input.type.shape[dim] == constant - cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) + cst_outshp_i = get_scalar_constant_value( + outshp_i, only_process_constants=True, raise_not_constant=False + ) if inp.type.shape[dim] == cst_outshp_i: shape_match[dim] = True continue @@ -872,8 +887,12 @@ def local_useless_reshape(fgraph, node): if shape_feature: inpshp_i = shape_feature.get_shape(inp, dim) if inpshp_i == outshp_i or ( - extract_constant(inpshp_i, 
only_process_constants=True) - == extract_constant(outshp_i, only_process_constants=True) + get_scalar_constant_value( + inpshp_i, only_process_constants=True, raise_not_constant=True + ) + == get_scalar_constant_value( + outshp_i, only_process_constants=True, raise_not_constant=True + ) ): shape_match[dim] = True continue @@ -909,11 +928,14 @@ def local_reshape_to_dimshuffle(fgraph, node): new_output_shape = [] index = 0 # index over the output of the new reshape for i in range(output.ndim): - # Since output_shape is a symbolic vector, we trust extract_constant + # Since output_shape is a symbolic vector, we trust get_scalar_constant_value # to go through however it is formed to see if its i-th element is 1. # We need only_process_constants=False for that. - dim = extract_constant( - output_shape[i], only_process_constants=False, elemwise=False + dim = get_scalar_constant_value( + output_shape[i], + only_process_constants=False, + elemwise=False, + raise_not_constant=False, ) if dim == 1: dimshuffle_new_order.append("x") diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 7699169143..2b79a57212 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -26,7 +26,7 @@ as_tensor, cast, concatenate, - extract_constant, + get_scalar_constant_value, get_underlying_scalar_constant_value, register_infer_shape, switch, @@ -373,7 +373,10 @@ def local_useless_slice(fgraph, node): step = s.step if ( start is not None - and extract_constant(start, only_process_constants=True) == 0 + and get_scalar_constant_value( + start, only_process_constants=True, raise_not_constant=False + ) + == 0 ): change_flag = True start = None @@ -381,14 +384,20 @@ def local_useless_slice(fgraph, node): if ( stop is not None and x.type.shape[dim] is not None - and extract_constant(stop, only_process_constants=True) == x.type.shape[dim] + and get_scalar_constant_value( + stop, only_process_constants=True, raise_not_constant=False + ) + == x.type.shape[dim] ): change_flag = True stop = None if ( step is not None - and extract_constant(step, only_process_constants=True) == 1 + and get_scalar_constant_value( + step, only_process_constants=True, raise_not_constant=False + ) + == 1 ): change_flag = True step = None @@ -878,7 +887,10 @@ def local_useless_inc_subtensor(fgraph, node): and e.stop is None and ( e.step is None - or extract_constant(e.step, only_process_constants=True) == -1 + or get_scalar_constant_value( + e.step, only_process_constants=True, raise_not_constant=False + ) + == -1 ) for e in idx_cst ): @@ -1479,7 +1491,10 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node): and # Don't use only_process_constants=True. We need to # investigate Alloc of 0s but with non constant shape. 
- extract_constant(x, elemwise=False) != 0 + get_underlying_scalar_constant_value( + x, elemwise=False, raise_not_constant=False + ) + != 0 ): return diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 019833a9d5..da7675b296 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -1379,11 +1379,11 @@ def assert_eqs_const(self, f, val, op=deep_copy_op): if op == deep_copy_op: assert len(elem.inputs) == 1, elem.inputs assert isinstance(elem.inputs[0], TensorConstant), elem - assert pt.extract_constant(elem.inputs[0]) == val, val + assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val else: assert len(elem.inputs) == 2, elem.inputs assert isinstance(elem.inputs[0], TensorConstant), elem - assert pt.extract_constant(elem.inputs[0]) == val, val + assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val def assert_identity(self, f): topo = f.maker.fgraph.toposort() diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 72763d05db..8717b9b44d 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -46,7 +46,6 @@ default, diag, expand_dims, - extract_constant, eye, fill, flatnonzero, @@ -3567,10 +3566,10 @@ def test_copy(self): # Make sure we do not return a writeable internal storage of a constant, # so we cannot change the value of a constant by mistake. c = constant(3) - d = extract_constant(c) + d = get_scalar_constant_value(c) with pytest.raises(ValueError, match="output array is read-only"): d += 1 - e = extract_constant(c) + e = get_scalar_constant_value(c) assert e == 3, (c, d, e) @pytest.mark.parametrize("only_process_constants", (True, False)) From 96dbda42d5c4ecca291ef0193165d06dc6818e06 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 21 Oct 2024 17:31:28 +0200 Subject: [PATCH 4/6] Use more strict `get_scalar_constant_value` when the input must be a scalar --- pytensor/link/jax/dispatch/tensor_basic.py | 6 ++--- pytensor/scan/basic.py | 2 +- pytensor/scan/rewriting.py | 5 ++-- pytensor/tensor/basic.py | 10 ++++---- pytensor/tensor/conv/abstract_conv.py | 16 +++++------- pytensor/tensor/extra_ops.py | 2 +- pytensor/tensor/rewriting/basic.py | 5 ++-- pytensor/tensor/rewriting/math.py | 30 ++++++++++------------ pytensor/tensor/rewriting/shape.py | 11 +++----- pytensor/tensor/rewriting/subtensor.py | 13 +++------- pytensor/tensor/shape.py | 17 +++++++----- pytensor/tensor/subtensor.py | 12 ++++----- 12 files changed, 58 insertions(+), 71 deletions(-) diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index bf1a93ce5b..79344b2275 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -18,7 +18,7 @@ Split, TensorFromScalar, Tri, - get_underlying_scalar_constant_value, + get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import Shape_i @@ -103,7 +103,7 @@ def join(axis, *tensors): def jax_funcify_Split(op: Split, node, **kwargs): _, axis, splits = node.inputs try: - constant_axis = get_underlying_scalar_constant_value(axis) + constant_axis = get_scalar_constant_value(axis) except NotScalarConstantError: constant_axis = None warnings.warn( @@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs): try: constant_splits = np.array( [ - get_underlying_scalar_constant_value(splits[i]) + get_scalar_constant_value(splits[i]) for i in 
range(get_vector_length(splits)) ] ) diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index 931e105597..775f9b57e6 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -484,7 +484,7 @@ def wrap_into_list(x): n_fixed_steps = int(n_steps) else: try: - n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps) + n_fixed_steps = pt.get_scalar_constant_value(n_steps) except NotScalarConstantError: n_fixed_steps = None diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 7ee512a549..5919247b58 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -55,7 +55,6 @@ Alloc, AllocEmpty, get_scalar_constant_value, - get_underlying_scalar_constant_value, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError @@ -1976,13 +1975,13 @@ def belongs_to_set(self, node, set_nodes): nsteps = node.inputs[0] try: - nsteps = int(get_underlying_scalar_constant_value(nsteps)) + nsteps = int(get_scalar_constant_value(nsteps)) except NotScalarConstantError: pass rep_nsteps = rep_node.inputs[0] try: - rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps)) + rep_nsteps = int(get_scalar_constant_value(rep_nsteps)) except NotScalarConstantError: pass diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 93b59df319..d9af4cacad 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -609,7 +609,7 @@ def get_scalar_constant_value( """ if isinstance(v, Variable | np.ndarray): if v.ndim != 0: - raise TypeError() + raise TypeError("Input is not a scalar") return get_underlying_scalar_constant_value( v, elemwise=elemwise, @@ -1801,7 +1801,7 @@ def do_constant_folding(self, fgraph, node): @_get_vector_length.register(Alloc) def _get_vector_length_Alloc(var_inst, var): try: - return get_underlying_scalar_constant_value(var.owner.inputs[1]) + return get_scalar_constant_value(var.owner.inputs[1]) except NotScalarConstantError: raise ValueError(f"Length of {var} cannot be determined") @@ -2502,7 +2502,7 @@ def make_node(self, axis, *tensors): if not isinstance(axis, int): try: - axis = int(get_underlying_scalar_constant_value(axis)) + axis = int(get_scalar_constant_value(axis)) except NotScalarConstantError: pass @@ -2746,7 +2746,7 @@ def infer_shape(self, fgraph, node, ishapes): def _get_vector_length_Join(op, var): axis, *arrays = var.owner.inputs try: - axis = get_underlying_scalar_constant_value(axis) + axis = get_scalar_constant_value(axis) assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays) return builtins.sum(get_vector_length(a) for a in arrays) except NotScalarConstantError: @@ -4138,7 +4138,7 @@ def make_node(self, a, choices): static_out_shape = () for s in out_shape: try: - s_val = get_underlying_scalar_constant_value(s) + s_val = get_scalar_constant_value(s) except (NotScalarConstantError, AttributeError): s_val = None diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index 73d402cfca..82406b9c57 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -25,7 +25,7 @@ from pytensor.raise_op import Assert from pytensor.tensor.basic import ( as_tensor_variable, - get_underlying_scalar_constant_value, + get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -497,8 +497,8 @@ def check_dim(given, computed): if given is None or computed is None: 
return True try: - given = get_underlying_scalar_constant_value(given) - computed = get_underlying_scalar_constant_value(computed) + given = get_scalar_constant_value(given) + computed = get_scalar_constant_value(computed) return int(given) == int(computed) except NotScalarConstantError: # no answer possible, accept for now @@ -534,7 +534,7 @@ def assert_conv_shape(shape): out_shape = [] for i, n in enumerate(shape): try: - const_n = get_underlying_scalar_constant_value(n) + const_n = get_scalar_constant_value(n) if i < 2: if const_n < 0: raise ValueError( @@ -2203,9 +2203,7 @@ def __init__( if imshp_i is not None: # Components of imshp should be constant or ints try: - get_underlying_scalar_constant_value( - imshp_i, only_process_constants=True - ) + get_scalar_constant_value(imshp_i, only_process_constants=True) except NotScalarConstantError: raise ValueError( "imshp should be None or a tuple of constant int values" @@ -2218,9 +2216,7 @@ def __init__( if kshp_i is not None: # Components of kshp should be constant or ints try: - get_underlying_scalar_constant_value( - kshp_i, only_process_constants=True - ) + get_scalar_constant_value(kshp_i, only_process_constants=True) except NotScalarConstantError: raise ValueError( "kshp should be None or a tuple of constant int values" diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 9de2b3f938..76d853a71f 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -678,7 +678,7 @@ def make_node(self, x, repeats): out_shape = [None] else: try: - const_reps = ptb.get_underlying_scalar_constant_value(repeats) + const_reps = ptb.get_scalar_constant_value(repeats) except NotScalarConstantError: const_reps = None if const_reps == 1: diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index d9b82b0ad5..0cea718927 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -56,7 +56,6 @@ cast, fill, get_scalar_constant_value, - get_underlying_scalar_constant_value, join, ones_like, register_infer_shape, @@ -738,7 +737,7 @@ def local_remove_useless_assert(fgraph, node): n_conds = len(node.inputs[1:]) for c in node.inputs[1:]: try: - const = get_underlying_scalar_constant_value(c) + const = get_scalar_constant_value(c) if 0 != const.ndim or const == 0: # Should we raise an error here? 
How to be sure it @@ -833,7 +832,7 @@ def local_join_empty(fgraph, node): return new_inputs = [] try: - join_idx = get_underlying_scalar_constant_value( + join_idx = get_scalar_constant_value( node.inputs[0], only_process_constants=True ) except NotScalarConstantError: diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 572b105680..18e3bb88c4 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -152,18 +152,16 @@ def local_0_dot_x(fgraph, node): x = node.inputs[0] y = node.inputs[1] - replace = False - try: - if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0: - replace = True - except NotScalarConstantError: - pass - - try: - if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0: - replace = True - except NotScalarConstantError: - pass + replace = ( + get_underlying_scalar_constant_value( + x, only_process_constants=True, raise_not_constant=False + ) + == 0 + or get_underlying_scalar_constant_value( + y, only_process_constants=True, raise_not_constant=False + ) + == 0 + ) if replace: constant_zero = constant(0, dtype=node.outputs[0].type.dtype) @@ -2135,7 +2133,7 @@ def local_add_remove_zeros(fgraph, node): y = get_underlying_scalar_constant_value(inp) except NotScalarConstantError: y = inp - if np.all(y == 0.0): + if y == 0.0: continue new_inputs.append(inp) @@ -2233,7 +2231,7 @@ def local_abs_merge(fgraph, node): ) except NotScalarConstantError: return False - if not (const >= 0).all(): + if not const >= 0: return False inputs.append(i) else: @@ -2881,7 +2879,7 @@ def _is_1(expr): """ try: v = get_underlying_scalar_constant_value(expr) - return np.allclose(v, 1) + return np.isclose(v, 1) except NotScalarConstantError: return False @@ -3049,7 +3047,7 @@ def is_neg(var): for idx, mul_input in enumerate(var_node.inputs): try: constant = get_underlying_scalar_constant_value(mul_input) - is_minus_1 = np.allclose(constant, -1) + is_minus_1 = np.isclose(constant, -1) except NotScalarConstantError: is_minus_1 = False if is_minus_1: diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 8b1e483ee4..27b1415dff 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -23,7 +23,6 @@ cast, constant, get_scalar_constant_value, - get_underlying_scalar_constant_value, register_infer_shape, stack, ) @@ -213,7 +212,7 @@ def shape_ir(self, i, r): # Do not call make_node for test_value s = Shape_i(i)(r) try: - s = get_underlying_scalar_constant_value(s) + s = get_scalar_constant_value(s) except NotScalarConstantError: pass return s @@ -297,7 +296,7 @@ def unpack(self, s_i, var): assert len(idx) == 1 idx = idx[0] try: - i = get_underlying_scalar_constant_value(idx) + i = get_scalar_constant_value(idx) except NotScalarConstantError: pass else: @@ -452,7 +451,7 @@ def update_shape(self, r, other_r): ) or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals( - get_underlying_scalar_constant_value( + get_scalar_constant_value( merged_shape[i], only_process_constants=True, raise_not_constant=False, @@ -481,9 +480,7 @@ def set_shape_i(self, r, i, s_i): or r.type.shape[idx] != 1 or self.lscalar_one.equals(new_shape[idx]) or self.lscalar_one.equals( - get_underlying_scalar_constant_value( - new_shape[idx], raise_not_constant=False - ) + get_scalar_constant_value(new_shape[idx], raise_not_constant=False) ) for idx in range(r.type.ndim) ) diff --git a/pytensor/tensor/rewriting/subtensor.py 
b/pytensor/tensor/rewriting/subtensor.py index 2b79a57212..cdfcedc36d 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -995,7 +995,7 @@ def local_useless_subtensor(fgraph, node): if isinstance(idx.stop, int | np.integer): length_pos_data = sys.maxsize try: - length_pos_data = get_underlying_scalar_constant_value( + length_pos_data = get_scalar_constant_value( length_pos, only_process_constants=True ) except NotScalarConstantError: @@ -1060,7 +1060,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node): # get length of the indexed tensor along the first axis try: - length = get_underlying_scalar_constant_value( + length = get_scalar_constant_value( shape_of[node.inputs[0]][0], only_process_constants=True ) except NotScalarConstantError: @@ -1732,7 +1732,7 @@ def local_join_subtensors(fgraph, node): axis, tensors = node.inputs[0], node.inputs[1:] try: - axis = get_underlying_scalar_constant_value(axis) + axis = get_scalar_constant_value(axis) except NotScalarConstantError: return @@ -1793,12 +1793,7 @@ def local_join_subtensors(fgraph, node): if step is None: continue try: - if ( - get_underlying_scalar_constant_value( - step, only_process_constants=True - ) - != 1 - ): + if get_scalar_constant_value(step, only_process_constants=True) != 1: return None except NotScalarConstantError: return None diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 2193c11575..4c753de8f4 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -15,7 +15,12 @@ from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.scalar import int32 -from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length +from pytensor.tensor import ( + _get_vector_length, + as_tensor_variable, + get_scalar_constant_value, + get_vector_length, +) from pytensor.tensor import basic as ptb from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.exceptions import NotScalarConstantError @@ -401,8 +406,6 @@ class SpecifyShape(COp): _output_type_depends_on_input_value = True def make_node(self, x, *shape): - from pytensor.tensor.basic import get_underlying_scalar_constant_value - x = ptb.as_tensor_variable(x) shape = tuple( @@ -430,7 +433,7 @@ def make_node(self, x, *shape): type_shape[i] = xts else: try: - type_s = get_underlying_scalar_constant_value(s) + type_s = get_scalar_constant_value(s) if type_s is not None: type_shape[i] = int(type_s) except NotScalarConstantError: @@ -461,7 +464,7 @@ def infer_shape(self, fgraph, node, shapes): for dim in range(node.inputs[0].type.ndim): s = shape[dim] try: - s = ptb.get_underlying_scalar_constant_value(s) + s = ptb.get_scalar_constant_value(s) # We assume that `None` shapes are always retrieved by # `get_underlying_scalar_constant_value`, and only in that case do we default to # the shape of the input variable @@ -587,7 +590,7 @@ def specify_shape( @_get_vector_length.register(SpecifyShape) # type: ignore def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int: try: - return int(ptb.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()) + return int(ptb.get_scalar_constant_value(var.owner.inputs[1]).item()) except NotScalarConstantError: raise ValueError(f"Length of {var} cannot be determined") @@ -668,7 +671,7 @@ def make_node(self, x, shp): y = shp_list[index] y = ptb.as_tensor_variable(y) try: - s_val = ptb.get_underlying_scalar_constant_value(y).item() + s_val = 
ptb.get_scalar_constant_value(y).item() if s_val >= 0: out_shape[index] = s_val except NotScalarConstantError: diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index b0f4aaf9fc..806a008931 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -29,7 +29,7 @@ from pytensor.tensor.basic import ( ScalarFromTensor, alloc, - get_underlying_scalar_constant_value, + get_scalar_constant_value, nonzero, scalar_from_tensor, ) @@ -778,7 +778,7 @@ def conv(val): return slice(conv(val.start), conv(val.stop), conv(val.step)) else: try: - return get_underlying_scalar_constant_value( + return get_scalar_constant_value( val, only_process_constants=only_process_constants, elemwise=elemwise, @@ -855,7 +855,7 @@ def extract_const(value): if value is None: return value, True try: - value = get_underlying_scalar_constant_value(value) + value = get_scalar_constant_value(value) return value, True except NotScalarConstantError: return value, False @@ -2989,17 +2989,17 @@ def _get_vector_length_Subtensor(op, var): start = ( None if indices[0].start is None - else get_underlying_scalar_constant_value(indices[0].start) + else get_scalar_constant_value(indices[0].start) ) stop = ( None if indices[0].stop is None - else get_underlying_scalar_constant_value(indices[0].stop) + else get_scalar_constant_value(indices[0].stop) ) step = ( None if indices[0].step is None - else get_underlying_scalar_constant_value(indices[0].step) + else get_scalar_constant_value(indices[0].step) ) if start == stop: From 70c64a0b674eb2fe51761dd759a9d76dc1232c2d Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 21 Oct 2024 17:38:55 +0200 Subject: [PATCH 5/6] Use more strict `get_scalar_constant_value` when the input must be a scalar --- pytensor/tensor/shape.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 4c753de8f4..455d791e1d 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -18,7 +18,6 @@ from pytensor.tensor import ( _get_vector_length, as_tensor_variable, - get_scalar_constant_value, get_vector_length, ) from pytensor.tensor import basic as ptb @@ -433,7 +432,7 @@ def make_node(self, x, *shape): type_shape[i] = xts else: try: - type_s = get_scalar_constant_value(s) + type_s = ptb.get_scalar_constant_value(s) if type_s is not None: type_shape[i] = int(type_s) except NotScalarConstantError: From 1461a39e6f5433d0cede5ed15d34c4c09418a66e Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 21 Oct 2024 17:53:10 +0200 Subject: [PATCH 6/6] Remove internal get_constant helper Fixes bug in `local_add_neg_to_sub` reported in https://github.com/pymc-devs/pytensor/issues/584 --- pytensor/tensor/rewriting/math.py | 83 +++++++++++++++-------------- tests/tensor/rewriting/test_math.py | 8 +-- 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 18e3bb88c4..a3adb23fd2 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -125,24 +125,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): return consts, origconsts, nonconsts -def get_constant(v): - """ - - Returns - ------- - object - A numeric constant if v is a Constant or, well, a - numeric constant. If v is a plain Variable, returns None. 
- - """ - if isinstance(v, TensorConstant): - return v.unique_value - elif isinstance(v, Variable): - return None - else: - return v - - @register_canonicalize @register_stabilize @node_rewriter([Dot]) @@ -1021,8 +1003,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): """ Find all constants and put them together into a single constant. - Finds all constants in orig_num and orig_denum (using - get_constant) and puts them together into a single + Finds all constants in orig_num and orig_denum + and puts them together into a single constant. The constant is inserted as the first element of the numerator. If the constant is the neutral element, it is removed from the numerator. @@ -1043,17 +1025,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): numct, denumct = [], [] for v in orig_num: - ct = get_constant(v) - if ct is not None: + if isinstance(v, TensorConstant) and v.unique_value is not None: # We found a constant in the numerator! # We add it to numct - numct.append(ct) + numct.append(v.unique_value) else: num.append(v) for v in orig_denum: - ct = get_constant(v) - if ct is not None: - denumct.append(ct) + if isinstance(v, TensorConstant) and v.unique_value is not None: + denumct.append(v.unique_value) else: denum.append(v) @@ -1077,11 +1057,13 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None): if orig_num and len(numct) == 1 and len(denumct) == 0 and ct: # In that case we should only have one constant in `ct`. - assert len(ct) == 1 - first_num_ct = get_constant(orig_num[0]) - if first_num_ct is not None and ct[0].type.values_eq( - ct[0].data, first_num_ct - ): + [var_ct] = ct + + num_ct = None + if isinstance(var_ct, TensorConstant): + num_ct = var_ct.unique_value + + if num_ct is not None and var_ct.type.values_eq(var_ct.data, num_ct): # This is an important trick :( if it so happens that: # * there's exactly one constant on the numerator and none on # the denominator @@ -1864,9 +1846,12 @@ def local_add_neg_to_sub(fgraph, node): return [new_out] # Check if it is a negative constant - const = get_constant(second) - if const is not None and const < 0: - new_out = sub(first, np.abs(const)) + if ( + isinstance(second, TensorConstant) + and second.unique_value is not None + and second.unique_value < 0 + ): + new_out = sub(first, np.abs(second.data)) return [new_out] @@ -1895,7 +1880,12 @@ def local_mul_zero(fgraph, node): @register_specialize @node_rewriter([true_div]) def local_div_to_reciprocal(fgraph, node): - if np.all(get_constant(node.inputs[0]) == 1.0): + if ( + get_underlying_scalar_constant_value( + node.inputs[0], only_process_constants=True, raise_not_constant=False + ) + == 1.0 + ): out = node.outputs[0] new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], [])) # The ones could have forced upcasting @@ -1916,7 +1906,9 @@ def local_reciprocal_canon(fgraph, node): @register_canonicalize @node_rewriter([pt_pow]) def local_pow_canonicalize(fgraph, node): - cst = get_constant(node.inputs[1]) + cst = get_underlying_scalar_constant_value( + node.inputs[1], only_process_constants=True, raise_not_constant=False + ) if cst == 0: return [alloc_like(1, node.outputs[0], fgraph)] if cst == 1: @@ -1947,7 +1939,12 @@ def local_intdiv_by_one(fgraph, node): @node_rewriter([int_div, true_div]) def local_zero_div(fgraph, node): """0 / x -> 0""" - if get_constant(node.inputs[0]) == 0: + if ( + get_underlying_scalar_constant_value( + node.inputs[0], only_process_constants=True, raise_not_constant=False + ) + == 0 + ): 
ret = alloc_like(0, node.outputs[0], fgraph) ret.tag.values_eq_approx = values_eq_approx_remove_nan return [ret] @@ -1960,7 +1957,9 @@ def local_pow_specialize(fgraph, node): odtype = node.outputs[0].dtype xsym = node.inputs[0] ysym = node.inputs[1] - y = get_constant(ysym) + y = get_underlying_scalar_constant_value( + ysym, only_process_constants=True, raise_not_constant=False + ) if (y is not None) and not broadcasted_by(xsym, ysym): rval = None @@ -1998,7 +1997,9 @@ def local_pow_to_nested_squaring(fgraph, node): odtype = node.outputs[0].dtype xsym = node.inputs[0] ysym = node.inputs[1] - y = get_constant(ysym) + y = get_underlying_scalar_constant_value( + ysym, only_process_constants=True, raise_not_constant=False + ) # the next line is needed to fix a strange case that I don't # know how to make a separate test. @@ -2081,7 +2082,9 @@ def local_mul_specialize(fgraph, node): nb_neg_node += 1 # remove special case arguments of 1, -1 or 0 - y = get_constant(inp) + y = get_underlying_scalar_constant_value( + inp, raise_not_constant=False, only_process_constants=True + ) if y == 1.0: nb_cst += 1 elif y == -1.0: diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index da7675b296..306bc4f455 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4376,11 +4376,13 @@ def test_local_add_neg_to_sub(first_negative): assert np.allclose(f(x_test, y_test), exp) -def test_local_add_neg_to_sub_const(): +@pytest.mark.parametrize("const_left", (True, False)) +def test_local_add_neg_to_sub_const(const_left): x = vector("x") - const = 5.0 + const = np.full((3, 2), 5.0) + out = -const + x if const_left else x + (-const) - f = function([x], x + (-const), mode=Mode("py")) + f = function([x], out, mode=Mode("py")) nodes = [ node.op
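
As a quick illustration of the distinction this series standardizes on, the sketch below (not part of any patch above; the exact behaviour is inferred from the diffs, so treat it as an assumption rather than the released API) contrasts the strict `get_scalar_constant_value`, which only accepts 0-d inputs, with `TensorConstant.unique_value`, which rewrites now use when they want "the single value this constant is filled with":

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.basic import get_scalar_constant_value
    from pytensor.tensor.exceptions import NotScalarConstantError

    axis = pt.constant(0)                       # 0-d constant: a genuine scalar
    filled = pt.constant(np.full((3, 2), 5.0))  # non-scalar constant with one repeated value

    # The strict helper resolves true scalars; callers keep the same
    # try/except idiom used throughout the patches above.
    try:
        print(int(get_scalar_constant_value(axis)))  # -> 0
    except NotScalarConstantError:
        pass

    # It refuses non-scalar inputs outright (the basic.py hunk shows it
    # raising TypeError("Input is not a scalar")) instead of silently
    # collapsing them to a single element.
    try:
        get_scalar_constant_value(filled)
    except (TypeError, NotScalarConstantError):
        pass

    # Rewrites that want the repeated fill value (e.g. simplify_constants,
    # local_add_neg_to_sub) read the cached property instead of the removed
    # get_constant() helper.
    print(filled.unique_value)  # -> 5.0

This also motivates the fix for issue #584: `local_add_neg_to_sub` now subtracts `second.data` (the full array) rather than the scalar unique value, so the broadcasted shape of the rewritten graph matches the original.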