Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate redundant utilities for extracting constants #1046

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions pytensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# pytensor code, since this code may want to log some messages.
import logging
import sys
import warnings
from functools import singledispatch
from pathlib import Path
from typing import Any, NoReturn, Optional
Expand Down Expand Up @@ -148,12 +149,10 @@ def get_underlying_scalar_constant(v):
If `v` is not some view of constant data, then raise a
`NotScalarConstantError`.
"""
# Is it necessary to test for presence of pytensor.sparse at runtime?
sparse = globals().get("sparse")
if sparse and isinstance(v.type, sparse.SparseTensorType):
if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
data = v.owner.inputs[0]
return tensor.get_underlying_scalar_constant_value(data)
warnings.warn(
"get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead.",
DeprecationWarning,
)
return tensor.get_underlying_scalar_constant_value(v)


Expand Down
4 changes: 2 additions & 2 deletions pytensor/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,7 +1330,7 @@ def try_to_copy_if_needed(var):
f" {i}. Since this input is only connected "
"to integer-valued outputs, it should "
"evaluate to zeros, but it evaluates to"
f"{pytensor.get_underlying_scalar_constant(term)}."
f"{pytensor.get_underlying_scalar_constant_value(term)}."
)
raise ValueError(msg)

Expand Down Expand Up @@ -2172,7 +2172,7 @@ def _is_zero(x):

no_constant_value = True
try:
constant_value = pytensor.get_underlying_scalar_constant(x)
constant_value = pytensor.get_underlying_scalar_constant_value(x)
no_constant_value = False
except pytensor.tensor.exceptions.NotScalarConstantError:
pass
Expand Down
6 changes: 3 additions & 3 deletions pytensor/link/jax/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Split,
TensorFromScalar,
Tri,
get_underlying_scalar_constant_value,
get_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i
Expand Down Expand Up @@ -103,7 +103,7 @@ def join(axis, *tensors):
def jax_funcify_Split(op: Split, node, **kwargs):
_, axis, splits = node.inputs
try:
constant_axis = get_underlying_scalar_constant_value(axis)
constant_axis = get_scalar_constant_value(axis)
except NotScalarConstantError:
constant_axis = None
warnings.warn(
Expand All @@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
try:
constant_splits = np.array(
[
get_underlying_scalar_constant_value(splits[i])
get_scalar_constant_value(splits[i])
for i in range(get_vector_length(splits))
]
)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/scan/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def wrap_into_list(x):
n_fixed_steps = int(n_steps)
else:
try:
n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps)
n_fixed_steps = pt.get_scalar_constant_value(n_steps)
except NotScalarConstantError:
n_fixed_steps = None

Expand Down
42 changes: 26 additions & 16 deletions pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
get_underlying_scalar_constant_value,
get_scalar_constant_value,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
Expand All @@ -71,7 +71,7 @@
get_slice_elements,
set_subtensor,
)
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
from pytensor.tensor.variable import TensorConstant


list_opt_slice = [
Expand Down Expand Up @@ -136,10 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
all_ins = list(graph_inputs(op_outs))
for idx in range(op_info.n_seqs):
node_inp = node.inputs[idx + 1]
if (
isinstance(node_inp, TensorConstant)
and get_unique_constant_value(node_inp) is not None
):
if isinstance(node_inp, TensorConstant) and node_inp.unique_value is not None:
try:
# This works if input is a constant that has all entries
# equal
Expand Down Expand Up @@ -668,8 +665,10 @@ def inner_sitsot_only_last_step_used(
client = fgraph.clients[outer_var][0][0]
if isinstance(client, Apply) and isinstance(client.op, Subtensor):
lst = get_idx_list(client.inputs, client.op.idx_list)
if len(lst) == 1 and pt.extract_constant(lst[0]) == -1:
return True
return (
len(lst) == 1
and get_scalar_constant_value(idx[0], raise_not_constant=False) == -1
)

return False

Expand Down Expand Up @@ -1344,10 +1343,17 @@ def scan_save_mem(fgraph, node):
if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
global_nsteps = None
if isinstance(cf_slice[0], slice):
stop = pt.extract_constant(cf_slice[0].stop)
stop = get_scalar_constant_value(
cf_slice[0].stop, raise_not_constant=False
)
else:
stop = pt.extract_constant(cf_slice[0]) + 1
if stop == maxsize or stop == pt.extract_constant(length):
stop = (
get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
+ 1
)
if stop == maxsize or stop == get_scalar_constant_value(
length, raise_not_constant=False
):
stop = None
else:
# there is a **gotcha** here ! Namely, scan returns an
Expand Down Expand Up @@ -1451,9 +1457,13 @@ def scan_save_mem(fgraph, node):
cf_slice = get_canonical_form_slice(this_slice[0], length)

if isinstance(cf_slice[0], slice):
start = pt.extract_constant(cf_slice[0].start)
start = pt.get_scalar_constant_value(
cf_slice[0].start, raise_not_constant=False
)
else:
start = pt.extract_constant(cf_slice[0])
start = pt.get_scalar_constant_value(
cf_slice[0], raise_not_constant=False
)

if start == 0 or store_steps[i] == 0:
store_steps[i] = 0
Expand Down Expand Up @@ -1628,7 +1638,7 @@ def scan_save_mem(fgraph, node):
# 3.6 Compose the new scan
# TODO: currently we don't support scan with 0 step. So
# don't create one.
if pt.extract_constant(node_ins[0]) == 0:
if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0:
return False

# Do not call make_node for test_value
Expand Down Expand Up @@ -1965,13 +1975,13 @@ def belongs_to_set(self, node, set_nodes):

nsteps = node.inputs[0]
try:
nsteps = int(get_underlying_scalar_constant_value(nsteps))
nsteps = int(get_scalar_constant_value(nsteps))
except NotScalarConstantError:
pass

rep_nsteps = rep_node.inputs[0]
try:
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
except NotScalarConstantError:
pass

Expand Down
Loading
Loading