
Make all zips explicitly strict or non-strict #850

Merged — 13 commits merged into main on Nov 19, 2024

Conversation

@Armavica (Member) commented Jun 24, 2024

Description

  • First commit: add a strict=True argument to every zip where it doesn't break the test suite (464 of them), and strict=False to the others (28 of them)
  • Second commit: enable the ruff rule requiring an explicit strict argument for all zips
  • Rest of the commits: turn the non-strict zips into strict zips (18 of them so far)

There remain 10 non-strict zips that I find difficult to understand.
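
For context (an editorial sketch, not part of the original PR text), this is the behaviour the PR makes explicit; the ruff rule referred to above is presumably B905 (zip-without-explicit-strict):

```python
# Minimal illustration: strict=True turns silently-truncating zips into loud failures.
a = [1, 2, 3]
b = ["x", "y"]

print(list(zip(a, b)))                # [(1, 'x'), (2, 'y')] -- the 3 is silently dropped
print(list(zip(a, b, strict=False)))  # same truncation, but now it is explicit and intentional
try:
    list(zip(a, b, strict=True))
except ValueError as err:
    print(err)                        # "zip() argument 2 is shorter than argument 1"
```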

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@Armavica force-pushed the zip-strict branch 2 times, most recently from ada9880 to a5de1b6 on June 26, 2024 22:40
@ricardoV94 (Member) commented Jun 27, 2024

@Armavica it may be crazy work, but can we get a separate commit for the zips that are made non-strict? That way it's easier to evaluate whether each one is correct or whether there may be a bug somewhere.

Or I guess I can just ctrl+f for it

@Armavica (Member, Author) commented Jun 28, 2024

@ricardoV94 Yes, I was planning to present this PR(s) in several steps:

  • Commits that add strict=True to zips without tests failing
  • Commits that add strict=True to zips and fix their failures (bugs)
  • Commits that add strict=False to the remaining zips that need it, or rewrite them so they can be made strict

How does that sound to you?

@ricardoV94 (Member):
Sounds good @Armavica

@Armavica force-pushed the zip-strict branch 5 times, most recently from dc0aa6e to 36868a5 on June 29, 2024 06:32
codecov bot commented Jun 29, 2024

Codecov Report

Attention: Patch coverage is 91.55556% with 19 lines in your changes missing coverage. Please review.

Project coverage is 82.12%. Comparing base (6de3151) to head (a4d7bb4).
Report is 13 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/scan/rewriting.py | 61.53% | 5 Missing ⚠️ |
| pytensor/d3viz/formatting.py | 0.00% | 2 Missing ⚠️ |
| pytensor/ifelse.py | 60.00% | 2 Missing ⚠️ |
| pytensor/scan/op.py | 87.50% | 0 Missing and 2 partials ⚠️ |
| pytensor/compile/debugmode.py | 50.00% | 1 Missing ⚠️ |
| pytensor/link/basic.py | 83.33% | 1 Missing ⚠️ |
| pytensor/link/utils.py | 66.66% | 1 Missing ⚠️ |
| pytensor/printing.py | 50.00% | 1 Missing ⚠️ |
| pytensor/scalar/basic.py | 80.00% | 1 Missing ⚠️ |
| pytensor/tensor/blockwise.py | 88.88% | 1 Missing ⚠️ |
| ... and 2 more | | |
Additional details and impacted files


@@           Coverage Diff           @@
##             main     #850   +/-   ##
=======================================
  Coverage   82.12%   82.12%           
=======================================
  Files         183      183           
  Lines       47986    47986           
  Branches     8644     8644           
=======================================
  Hits        39409    39409           
  Misses       6411     6411           
  Partials     2166     2166           
| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/compile/builders.py | 88.66% <100.00%> (ø) |
| pytensor/compile/function/pfunc.py | 82.92% <100.00%> (ø) |
| pytensor/compile/function/types.py | 80.48% <100.00%> (ø) |
| pytensor/gradient.py | 77.57% <100.00%> (ø) |
| pytensor/graph/basic.py | 88.69% <100.00%> (ø) |
| pytensor/graph/op.py | 88.08% <ø> (ø) |
| pytensor/graph/replace.py | 84.21% <100.00%> (ø) |
| pytensor/graph/rewriting/basic.py | 70.43% <100.00%> (ø) |
| pytensor/link/c/basic.py | 87.48% <100.00%> (ø) |
| pytensor/link/c/cmodule.py | 60.48% <100.00%> (ø) |
| ... and 59 more | |

@jessegrabowski (Member):
Is this ready for review? It seems like all tests are passing now.

@Armavica (Member, Author) commented Jul 7, 2024

> Is this ready for review? It seems like all tests are passing now.

There are still 10 instances of non-strict zips that produce errors if I make them strict; I need to investigate them one by one to see whether that's expected behaviour or not.

@Armavica force-pushed the zip-strict branch 2 times, most recently from 5746335 to e3965ef on July 7, 2024 13:05
@Armavica (Member, Author) commented Jul 7, 2024

Actually, I find it difficult to make more progress here, so I am flagging this for review.
I added 464 easy strict=True, 18 less immediate ones, and there are still 10 strict=False that I find the most difficult to understand. I think those could be handled in another PR. This one introduces 464 + 18 = 482 safeguards, which I think is a good score :)

@Armavica marked this pull request as ready for review on July 7, 2024 13:36
@@ -93,7 +93,7 @@ def _validate_updates(
)
else:
update = outputs
for i, u in zip(init, update, strict=False):
for i, u in zip(init[: len(update)], update, strict=True):
Member:

Can we use strict=False here?

@@ -1745,7 +1745,7 @@ def setup_method(self):
self.random_stream = np.random.default_rng(utt.fetch_seed())

self.inputs_shapes = [(8, 1, 12, 12), (1, 1, 5, 5), (1, 1, 5, 6), (1, 1, 6, 6)]
self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)]
self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)] * 2
Member:

Was this a bug?

Member (Author):

Yes, it was zipping inputs_shapes and filters_shapes together, so the last two inputs_shapes were never being used in the tests.

@@ -648,7 +648,7 @@ def local_subtensor_of_alloc(fgraph, node):
# Slices to take from val
val_slices = []

for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)):
Member:

As in my previous comment, I find this less readable. Doesn't the strict=False indicate clearly that we don't expect the sequences to necessarily have the same length?

Member:

But if they're not the same length, why are we zipping them? Are we sure they're always ordered correctly?

Member:

Because they are ordered correctly, yes, and presumably what comes after doesn't matter. It's quite common in Subtensor operations/rewrites.

@Armavica (Member, Author) commented Jul 8, 2024:

I don't mind reverting this, but just to argue a bit in favor of strict=True: I think an additional advantage of this approach is that it makes it clearer which of the two lists is supposed to be shorter. I personally find that I understand more about what is happening here when I read this version compared to the strict=False one.
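
For illustration only (an editorial sketch with made-up lists, not the PR's code): the two patterns being debated are equivalent when the second sequence is the longer one, but the sliced strict=True form also fails loudly if that assumption is ever violated.

```python
slices = [slice(0, 2), slice(1, 3)]
dims = [5, 7, 9]  # longer on purpose; the trailing entry is intentionally ignored

loose = list(zip(slices, dims, strict=False))                   # truncates silently
explicit = list(zip(slices, dims[: len(slices)], strict=True))  # truncation is spelled out
assert loose == explicit

# If `slices` were unexpectedly the longer sequence, the strict version would
# raise a ValueError instead of silently dropping the extra slices.
```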

Comment on lines 594 to 589
if len(shape) != x.type.ndim:
return _specify_shape(x, *shape)

new_shape_matches = all(
s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
)
if new_shape_matches:
@ricardoV94 (Member) commented Jul 8, 2024:

This is awkward, the use of strict=False seems fine

Member (Author):

I am surprised, this really looks better to me. What do you think of:

    if len(shape) != x.type.ndim:
        return _specify_shape(x, *shape)

    if all(s in (None, xts) for (s, xts) in zip(shape, x.type.shape, strict=True)):

Member:

My problem is the double call to SpecifyShape.

Also, we already established in the comment above that the function is going to raise if the lengths differ, so strict=False follows naturally?

@ricardoV94 (Member) left a review comment:

Skimmed a bit. Generally good, so let's try to merge sooner rather than later. But there are several cases I want to push back on, as they run in hot loops: everything inside an Op.perform / Op.make_thunk (and the linker / function counterparts). This is the code we run on every PyTensor function call, and we don't want to add any extra overhead; PyTensor function overhead is already bad as is. I did not highlight everything, I can give another pass later.

We also need to investigate Numba/Torch, as we don't want to introduce runtime checks in the compiled code; it depends on what they do when they see strict=True in the bytecode. JAX shouldn't be a problem since it just ignores any error checks at runtime by default.
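
As an aside (not from the PR), the hot-loop concern is easy to quantify informally; a rough timeit sketch like the following shows how one could measure whether the strict check is noticeable on small, fixed-length sequences:

```python
import timeit

outputs = list(range(8))
variables = list(range(8))

# Time a tight loop with and without strict; the absolute numbers are
# machine-dependent, this only illustrates how to measure the overhead.
plain = timeit.timeit(lambda: [o + v for o, v in zip(outputs, variables)], number=1_000_000)
strict = timeit.timeit(
    lambda: [o + v for o, v in zip(outputs, variables, strict=True)], number=1_000_000
)
print(f"plain:  {plain:.3f}s")
print(f"strict: {strict:.3f}s")
```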

@@ -999,7 +1003,7 @@ def __call__(self, *args, **kwargs):
# output reference from the internal storage cells
if getattr(self.vm, "allow_gc", False):
for o_container, o_variable in zip(
self.output_storage, self.maker.fgraph.outputs
self.output_storage, self.maker.fgraph.outputs, strict=True
Member:

Let's not put any strict in the Function.__call__. This is a hot loop

@@ -853,5 +863,5 @@ def clone(self):
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables):
for output, variable in zip(outputs, variables, strict=True):
Member:

The length is already asserted above; either remove the strict or the assert.

@@ -301,7 +305,7 @@ def thunk():
if len(ls) > 0:
return ls
else:
for out, t in zip(outputs, input_true_branch):
for out, t in zip(outputs, input_true_branch, strict=True):
Member:

thunk is a hot loop, remove strict

@@ -321,7 +325,7 @@ def thunk():
if len(ls) > 0:
return ls
else:
for out, f in zip(outputs, inputs_false_branch):
for out, f in zip(outputs, inputs_false_branch, strict=True):
Member:

also here

(resolved conversation on pytensor/link/basic.py)
@@ -158,7 +158,7 @@ def advancedincsubtensor1_inplace(x, val, idxs):
def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
for idx, val in zip(idxs, vals):
for idx, val in zip(idxs, vals, strict=True):
Member:

For all the Numba dispatch, we have to check whether strict=True adds overhead in the jitted code. If we're not sure, better to remove it.
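
An informal way to probe this (a sketch with a hypothetical stand-in function, not the actual PyTensor dispatch code; whether Numba even accepts the strict keyword is exactly what's being checked):

```python
import numpy as np
from numba import njit

@njit
def dot_pairs(xs, ys):
    # Stand-in for a numba-dispatched loop that zips two sequences with strict=True.
    total = 0.0
    for x, y in zip(xs, ys, strict=True):
        total += x * y
    return total

try:
    print(dot_pairs(np.arange(5.0), np.arange(5.0)))  # compiles on first call
except Exception as err:
    # If Numba rejects the keyword (or lowers it with extra runtime checks),
    # that would be a reason to drop strict=True from the dispatch code.
    print(f"strict zip not supported/efficient under Numba: {err}")
```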

@@ -34,7 +34,7 @@ def shape_i(x):
def pytorch_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
for actual, expected in zip(x.shape, shape):
for actual, expected in zip(x.shape, shape, strict=True):
Member:

Even more critical for Torch. There's also a risk it would introduce a graph break, which we definitely do not want.
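
A rough way to probe the graph-break concern (again a hypothetical stand-in, not the actual PyTorch dispatch code) is to compile with fullgraph=True, which errors out if Dynamo has to break the graph anywhere:

```python
import torch

def specify_shape_like(x, shape=(2, 3)):
    # Stand-in containing the strict zip pattern under discussion.
    for actual, expected in zip(x.shape, shape, strict=True):
        assert actual == expected
    return x * 2

compiled = torch.compile(specify_shape_like, fullgraph=True)
# Raises at compile time if the strict zip (or the asserts) forces a graph break.
print(compiled(torch.ones(2, 3)))
```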

(resolved conversation on pytensor/scalar/basic.py)
@@ -4324,7 +4328,7 @@ def make_node(self, *inputs):

def perform(self, node, inputs, output_storage):
outputs = self.py_perform_fn(*inputs)
for storage, out_val in zip(output_storage, outputs):
for storage, out_val in zip(output_storage, outputs, strict=True):
Member:

Hot loop. Everything inside a thunk / perform should be left as is

@@ -207,7 +207,7 @@ def perform(self, node, inputs, output_storage):
for i in range(n_steps):
carry = inner_fn(*carry, *constant)

for storage, out_val in zip(output_storage, carry):
for storage, out_val in zip(output_storage, carry, strict=True):
Member:

hot loop

@Armavica force-pushed the zip-strict branch 2 times, most recently from 984adeb to 095a892 on November 19, 2024 08:40
@Armavica (Member, Author):
@ricardoV94 Fixed all your points, let me know if there are more

@ricardoV94 (Member) left a review comment:

Did an exhaustive pass, hopefully that's all

@@ -1009,7 +1014,7 @@ def __call__(self, *args, **kwargs):
if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function
for input, storage in reversed(
list(zip(self.maker.expanded_inputs, input_storage))
list(zip(self.maker.expanded_inputs, input_storage, strict=True))
Member:

hot loop

@@ -1108,7 +1113,7 @@ def _pickle_Function(f):
input_storage = []

for (input, indices, inputs), (required, refeed, default) in zip(
f.indices, f.defaults
f.indices, f.defaults, strict=True
Member:

hot loop

@@ -1040,7 +1045,7 @@ def __call__(self, *args, **kwargs):
assert len(self.output_keys) == len(outputs)

if output_subset is None:
return dict(zip(self.output_keys, outputs))
return dict(zip(self.output_keys, outputs, strict=True))
Member:

hot loop

@@ -537,12 +539,13 @@ def make_thunk(self, **kwargs):

def f():
for inputs in input_lists[1:]:
for input1, input2 in zip(inputs0, inputs):
for input1, input2 in zip(inputs0, inputs, strict=True):
Member:

hot loop

@@ -403,7 +403,7 @@ def py_perform_return(inputs):
def py_perform_return(inputs):
return tuple(
out_type.filter(out[0])
for out_type, out in zip(output_types, py_perform(inputs))
for out_type, out in zip(output_types, py_perform(inputs), strict=True)
Member:

hot loop

@@ -112,7 +112,9 @@ def broadcast_params(
for p in params:
param_shape = tuple(
1 if bcast else s
for s, bcast in zip(p.shape, getattr(p, "broadcastable", (False,) * p.ndim))
for s, bcast in zip(
Member:

hot loop

@@ -123,7 +125,8 @@ def broadcast_params(
broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to

bcast_params = [
broadcast_to_fn(param, shape) for shape, param in zip(shapes, params)
broadcast_to_fn(param, shape)
Member:

hot loop

@@ -448,7 +448,9 @@ def perform(self, node, inp, out_):
raise AssertionError(
f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
)
if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
if not all(
Member:

hot loop

@@ -250,7 +250,7 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:

if not all(
ds == ts if ts is not None else True
for ds, ts in zip(data.shape, self.shape)
for ds, ts in zip(data.shape, self.shape, strict=True)
Member:

hot loop, because filter is called during function evaluation

@@ -325,7 +325,10 @@ def is_super(self, otype):
and otype.ndim == self.ndim
# `otype` is allowed to be as or more shape-specific than `self`,
# but not less
and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape))
and all(
Member:

same here

@Armavica force-pushed the zip-strict branch 2 times, most recently from 5cac2e6 to f8bc010 on November 19, 2024 12:51
@ricardoV94 (Member):
Is it ready, @Armavica?

@Armavica (Member, Author):
Yes, I think I changed everything you pointed out; if not, it's my mistake.

@ricardoV94 merged commit 4b41e09 into main on Nov 19, 2024
61 of 62 checks passed
@ricardoV94 deleted the zip-strict branch on November 19, 2024 14:54
@ricardoV94 (Member):
Thanks 🙏🙏

@Armavica (Member, Author):
Thank you for your review!
