Make all zips explicitly strict or non-strict #850
Conversation
Force-pushed from ada9880 to a5de1b6
@Armavica may be crazy work, but can we get a separate commit where we make the non-strict zips? That way it's easier to evaluate whether it's correct or there may be a bug somewhere. Or I guess I can just ctrl+f for it.
@ricardoV94 Yes, I was planning to present this PR(s) in several steps:
Sounds good @Armavica
Force-pushed from dc0aa6e to 36868a5
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #850 +/- ##
=======================================
Coverage 82.12% 82.12%
=======================================
Files 183 183
Lines 47986 47986
Branches 8644 8644
=======================================
Hits 39409 39409
Misses 6411 6411
Partials 2166 2166
Is this ready for review? Seems like all tests are passing now
There are still
Force-pushed from 5746335 to e3965ef
Actually, I find it difficult to make more progress here, so I am signalling this for review.
pytensor/scalar/loop.py
Outdated
@@ -93,7 +93,7 @@ def _validate_updates(
         )
     else:
         update = outputs
-    for i, u in zip(init, update, strict=False):
+    for i, u in zip(init[: len(update)], update, strict=True):
Can we use strict=False here?
@@ -1745,7 +1745,7 @@ def setup_method(self):
     self.random_stream = np.random.default_rng(utt.fetch_seed())

     self.inputs_shapes = [(8, 1, 12, 12), (1, 1, 5, 5), (1, 1, 5, 6), (1, 1, 6, 6)]
-    self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)]
+    self.filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)] * 2
Was this a bug?
Yes, it was zipping input_shapes and filter_shapes together, so the last two input_shapes were never being used in the tests.
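For illustration, a minimal self-contained sketch (not from the test suite, using the shapes from the diff above) of how the silent truncation happened and how strict=True surfaces it:

```python
# zip() silently stops at the shorter sequence, so the last two input shapes
# were never paired with a filter shape in the old test.
inputs_shapes = [(8, 1, 12, 12), (1, 1, 5, 5), (1, 1, 5, 6), (1, 1, 6, 6)]
filters_shapes = [(5, 1, 2, 2), (1, 1, 3, 3)]

print(len(list(zip(inputs_shapes, filters_shapes))))  # 2: two cases silently dropped

# strict=True turns the silent truncation into an error:
try:
    list(zip(inputs_shapes, filters_shapes, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1

# Duplicating the filter shapes, as in the fix, restores all four pairs:
assert len(list(zip(inputs_shapes, filters_shapes * 2, strict=True))) == 4
```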
@@ -648,7 +648,7 @@ def local_subtensor_of_alloc(fgraph, node):
     # Slices to take from val
     val_slices = []

-    for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
+    for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)):
Like my previous comment, I find this less readable. The strict=False indicates clearly that we don't expect the sequences to necessarily have the same length?
But if they're not the same length why are we zipping them? Are we sure they're always ordered correctly?
Because they are ordered correctly yes, and presumably what comes after doesn't matter. It's quite common in Subtensor operations / rewrites
I don't mind reverting this, but just to argue a bit in favor of strict=True, I think an additional advantage of this approach is that it makes it clearer which one of the two lists is supposed to be shorter. I personally find that I understand more about what is happening here when I read this version compared to strict=False.
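For reference, a toy sketch (hypothetical values, not PyTensor code) of the two spellings being debated here:

```python
slices = [slice(0, 2), slice(1, 3)]
dims = [5, 7, 9]  # longer than `slices`; the extra tail is intentionally ignored

# Spelling 1: rely on zip's truncation and state it explicitly with strict=False.
pairs_a = list(zip(slices, dims, strict=False))

# Spelling 2: truncate `dims` by hand so both sequences must match exactly.
pairs_b = list(zip(slices, dims[: len(slices)], strict=True))

assert pairs_a == pairs_b  # same result; the difference is only about stated intent
```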
pytensor/tensor/shape.py
Outdated
    if len(shape) != x.type.ndim:
        return _specify_shape(x, *shape)

    new_shape_matches = all(
        s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
    )
    if new_shape_matches:
This is awkward, the use of strict=False seems fine.
I am surprised, this really looks better to me. What do you think of:
if len(shape) != x.type.ndim:
    return _specify_shape(x, *shape)

if all(s in (None, xts) for (s, xts) in zip(shape, x.type.shape, strict=True)):
My problem is the double call to SpecifyShape.
Also, we already established in the comment that if the lengths are different the function is going to raise, so the strict=False follows naturally?
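As a side note, a quick toy check (made-up shape tuples) that the suggested `s in (None, xts)` predicate agrees with the original filtered comparison when the tuples have equal length:

```python
shape = (None, 3, 5)    # hypothetical user-specified shape
type_shape = (2, 3, 5)  # hypothetical static type shape

original = all(
    s == xts for (s, xts) in zip(shape, type_shape, strict=True) if s is not None
)
suggested = all(s in (None, xts) for (s, xts) in zip(shape, type_shape, strict=True))

assert original == suggested  # both True here; None entries count as "matches"
```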
Force-pushed from 6a15a26 to 9007ce1
Skimmed a bit. Generally good, so let's try to merge sooner rather than later. But there are several cases I want to push back on, as they run in hot loops: everything inside an Op.perform / Op.make_thunk (and the linker / function counterparts). This is the code we run in every pytensor function call, and we don't want to add any extra overhead; PyTensor function call overhead is already bad as is. I did not highlight everything, I can give another pass later.
Also we need to investigate Numba/Torch, as we don't want to introduce runtime checks in the compiled code, depending on what they do when they see the strict=True in the bytecode. JAX shouldn't be a problem since it just ignores any error checks at runtime by default.
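To make the overhead concern concrete, a rough standalone microbenchmark sketch (not from the codebase) of the kind of perform-style inner loop being discussed:

```python
import timeit

# A handful of output cells, as a typical Op.perform would see.
output_storage = [[None] for _ in range(5)]
outputs = list(range(5))

plain = timeit.timeit(
    "for s, o in zip(output_storage, outputs): s[0] = o",
    globals=globals(),
    number=1_000_000,
)
checked = timeit.timeit(
    "for s, o in zip(output_storage, outputs, strict=True): s[0] = o",
    globals=globals(),
    number=1_000_000,
)
print(f"plain zip: {plain:.3f}s   zip(strict=True): {checked:.3f}s")
```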
pytensor/compile/function/types.py
Outdated
@@ -999,7 +1003,7 @@ def __call__(self, *args, **kwargs):
     # output reference from the internal storage cells
     if getattr(self.vm, "allow_gc", False):
         for o_container, o_variable in zip(
-            self.output_storage, self.maker.fgraph.outputs
+            self.output_storage, self.maker.fgraph.outputs, strict=True
Let's not put any strict in the Function.__call__. This is a hot loop.
pytensor/compile/builders.py
Outdated
@@ -853,5 +863,5 @@ def clone(self):
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables):
+        for output, variable in zip(outputs, variables, strict=True):
Len is asserted above; either remove the strict or the assert.
pytensor/ifelse.py
Outdated
@@ -301,7 +305,7 @@ def thunk():
     if len(ls) > 0:
         return ls
     else:
-        for out, t in zip(outputs, input_true_branch):
+        for out, t in zip(outputs, input_true_branch, strict=True):
thunk is a hot loop, remove strict
pytensor/ifelse.py
Outdated
@@ -321,7 +325,7 @@ def thunk():
     if len(ls) > 0:
         return ls
     else:
-        for out, f in zip(outputs, inputs_false_branch):
+        for out, f in zip(outputs, inputs_false_branch, strict=True):
also here
@@ -158,7 +158,7 @@ def advancedincsubtensor1_inplace(x, val, idxs):
 def advancedincsubtensor1_inplace(x, vals, idxs):
     if not len(idxs) == len(vals):
         raise ValueError("The number of indices and values must match.")
-    for idx, val in zip(idxs, vals):
+    for idx, val in zip(idxs, vals, strict=True):
For all numba dispatch, we have to check whether strict=True adds overhead in the jitted code. If we're not sure, it's better to remove it.
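One hypothetical way to check (assuming numba is installed) is to jit-compile a small loop like the one above and see whether zip(..., strict=True) is accepted at all in nopython mode:

```python
import numpy as np
from numba import njit


@njit
def inplace_add(x, vals, idxs):
    # Same shape of loop as the dispatch code above, with the strict keyword.
    for idx, val in zip(idxs, vals, strict=True):
        x[idx] += val
    return x


try:
    print(inplace_add(np.zeros(4), np.ones(3), np.arange(3)))
except Exception as err:  # e.g. a TypingError if numba rejects the keyword
    print("zip(strict=True) not usable in nopython mode:", type(err).__name__)
```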
@@ -34,7 +34,7 @@ def shape_i(x):
 def pytorch_funcify_SpecifyShape(op, node, **kwargs):
     def specifyshape(x, *shape):
         assert x.ndim == len(shape)
-        for actual, expected in zip(x.shape, shape):
+        for actual, expected in zip(x.shape, shape, strict=True):
Even more critical for torch. There's also a risk it would introduce a graph break, which we definitely do not want.
pytensor/scalar/basic.py
Outdated
@@ -4324,7 +4328,7 @@ def make_node(self, *inputs):

     def perform(self, node, inputs, output_storage):
         outputs = self.py_perform_fn(*inputs)
-        for storage, out_val in zip(output_storage, outputs):
+        for storage, out_val in zip(output_storage, outputs, strict=True):
Hot loop. Everything inside a thunk / perform should be left as is
pytensor/scalar/loop.py
Outdated
@@ -207,7 +207,7 @@ def perform(self, node, inputs, output_storage):
     for i in range(n_steps):
         carry = inner_fn(*carry, *constant)

-    for storage, out_val in zip(output_storage, carry):
+    for storage, out_val in zip(output_storage, carry, strict=True):
hot loop
Force-pushed from 984adeb to 095a892
@ricardoV94 Fixed all your points, let me know if there are more
Did an exhaustive pass, hopefully that's all
pytensor/compile/function/types.py
Outdated
@@ -1009,7 +1014,7 @@ def __call__(self, *args, **kwargs):
     if getattr(self.vm, "need_update_inputs", True):
         # Update the inputs that have an update function
         for input, storage in reversed(
-            list(zip(self.maker.expanded_inputs, input_storage))
+            list(zip(self.maker.expanded_inputs, input_storage, strict=True))
hot loop
pytensor/compile/function/types.py
Outdated
@@ -1108,7 +1113,7 @@ def _pickle_Function(f):
     input_storage = []

     for (input, indices, inputs), (required, refeed, default) in zip(
-        f.indices, f.defaults
+        f.indices, f.defaults, strict=True
hot loop
pytensor/compile/function/types.py
Outdated
@@ -1040,7 +1045,7 @@ def __call__(self, *args, **kwargs):
     assert len(self.output_keys) == len(outputs)

     if output_subset is None:
-        return dict(zip(self.output_keys, outputs))
+        return dict(zip(self.output_keys, outputs, strict=True))
hot loop
pytensor/link/basic.py
Outdated
@@ -537,12 +539,13 @@ def make_thunk(self, **kwargs):

     def f():
         for inputs in input_lists[1:]:
-            for input1, input2 in zip(inputs0, inputs):
+            for input1, input2 in zip(inputs0, inputs, strict=True):
hot loop
@@ -403,7 +403,7 @@ def py_perform_return(inputs):
 def py_perform_return(inputs):
     return tuple(
         out_type.filter(out[0])
-        for out_type, out in zip(output_types, py_perform(inputs))
+        for out_type, out in zip(output_types, py_perform(inputs), strict=True)
hot loop
@@ -112,7 +112,9 @@ def broadcast_params(
     for p in params:
         param_shape = tuple(
             1 if bcast else s
-            for s, bcast in zip(p.shape, getattr(p, "broadcastable", (False,) * p.ndim))
+            for s, bcast in zip(
hot loop
@@ -123,7 +125,8 @@ def broadcast_params(
     broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to

     bcast_params = [
-        broadcast_to_fn(param, shape) for shape, param in zip(shapes, params)
+        broadcast_to_fn(param, shape)
hot loop
@@ -448,7 +448,9 @@ def perform(self, node, inp, out_):
         raise AssertionError(
             f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
         )
-    if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
+    if not all(
hot loop
pytensor/tensor/type.py
Outdated
@@ -250,7 +250,7 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:

     if not all(
         ds == ts if ts is not None else True
-        for ds, ts in zip(data.shape, self.shape)
+        for ds, ts in zip(data.shape, self.shape, strict=True)
hot loop, because filter is called during function evaluation calls
@@ -325,7 +325,10 @@ def is_super(self, otype):
     and otype.ndim == self.ndim
     # `otype` is allowed to be as or more shape-specific than `self`,
     # but not less
-    and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape))
+    and all(
same here
Force-pushed from 5cac2e6 to f8bc010
Is it ready @Armavica?
Yes, I think I changed everything you pointed out; if not, it's my mistake.
Thanks 🙏🙏
Thank you for your review!
Description
This PR adds the strict=True argument to all zips when it doesn't produce mistakes in the test suite (464 of them), and strict=False to the others (28 of them). There remain 10 non-strict zips that I find difficult to understand.
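For context, a plain-Python reminder (the strict keyword is available since Python 3.10) of the behaviour difference this PR makes explicit:

```python
a = [1, 2, 3]
b = ["x", "y"]

print(list(zip(a, b)))                # [(1, 'x'), (2, 'y')] -- the 3 is silently dropped
print(list(zip(a, b, strict=False)))  # same result, but the truncation is now intentional

try:
    list(zip(a, b, strict=True))
except ValueError as err:
    print(err)                        # zip() argument 2 is shorter than argument 1
```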
Related Issue
Checklist
Type of change