Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shape_unsafe tag to rewrites that can hide shape errors #381

Merged
merged 9 commits into from
Aug 7, 2023

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jul 13, 2023

  • Remove BroadcastTo in favor of Alloc
  • Address spoken inconsistencies between Second / Alloc in rewrites
  • Add tag to easily exclude rewrites that can hide shape errors
  • Simplify such rewrites
  • Improve static shape of Alloc

Note that from a user standpoint, providing static shapes (via vector("x", shape=(5,)) or specify_shapes) will many times reveal shape errors immediately (this is the case for 99% of PyMC models). In this case users should feel pretty safe about "shape_unsafe" rewrites because they aren't really masking anything that wasn't checked before already.

Alloc.make_node now also raises early when it can see the provided shape is inconsistent. Alloc and Elemwise make up all of the tagged "shape_unsafe" rewrites so far.

With this PR, users can also do mode=get_default_mode().excluding("shape_unsafe") or add shape_unsafe to the excluding config to skip these rewrites at the cost of less optimizations.

Closes #367

@ricardoV94 ricardoV94 changed the title Remove BroadcastTo Remove BroadcastTo and add shape_unsafe tag to rewrites that make shape assumptions Jul 13, 2023
@ricardoV94 ricardoV94 changed the title Remove BroadcastTo and add shape_unsafe tag to rewrites that make shape assumptions Remove BroadcastTo and add shape_unsafe tag to rewrites that can hide shape errors Jul 13, 2023
@ricardoV94 ricardoV94 changed the title Remove BroadcastTo and add shape_unsafe tag to rewrites that can hide shape errors Add shape_unsafe tag to rewrites that can hide shape errors Jul 13, 2023
@@ -1757,7 +1616,19 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
The arrays to broadcast.

"""
return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)

def broadcast_with_others(a, others):
Copy link
Member Author

@ricardoV94 ricardoV94 Jul 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed with @aseyboldt that it may make sense to generalize Second so that it accepts arbitrary many inputs and returns every variable as output. This would become a flat broadcast_arrays once Elemwised, and make rewrites easier to read. By overriding the __str__ we can also make it much more readable in debug_print than the current nested Second

@codecov-commenter
Copy link

Codecov Report

Merging #381 (2a3adbe) into main (7218431) will decrease coverage by 0.03%.
The diff coverage is 90.90%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #381      +/-   ##
==========================================
- Coverage   80.44%   80.41%   -0.03%     
==========================================
  Files         156      156              
  Lines       45470    45413      -57     
  Branches    11136    11119      -17     
==========================================
- Hits        36578    36520      -58     
- Misses       6687     6693       +6     
+ Partials     2205     2200       -5     
Impacted Files Coverage Δ
pytensor/link/jax/dispatch/extra_ops.py 72.58% <ø> (-3.48%) ⬇️
pytensor/link/numba/dispatch/extra_ops.py 91.78% <ø> (-0.47%) ⬇️
pytensor/tensor/rewriting/elemwise.py 88.99% <ø> (ø)
pytensor/tensor/rewriting/math.py 86.32% <72.72%> (-0.05%) ⬇️
pytensor/tensor/rewriting/basic.py 94.01% <95.74%> (+0.48%) ⬆️
pytensor/configdefaults.py 65.92% <100.00%> (-0.10%) ⬇️
pytensor/tensor/basic.py 90.82% <100.00%> (+0.05%) ⬆️
pytensor/tensor/extra_ops.py 88.53% <100.00%> (-0.48%) ⬇️
pytensor/tensor/rewriting/extra_ops.py 88.23% <100.00%> (-1.02%) ⬇️

... and 9 files with indirect coverage changes

@ricardoV94 ricardoV94 marked this pull request as ready for review July 14, 2023 12:43
@ricardoV94
Copy link
Member Author

Fixing #379 should also help with the "unsafety" concerns

pytensor/tensor/basic.py Outdated Show resolved Hide resolved
@@ -1561,141 +1561,6 @@ def broadcast_shape_iter(
return tuple(result_dims)


class BroadcastTo(COp):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BroadcastTo is imported in pymc a couple of times. Maybe we should leave an empty Op here, that is deprecated and doesn't do anything?

Copy link
Member Author

@ricardoV94 ricardoV94 Aug 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of the removed rewrites are also directly imported.

This shouldn't be a problem however. I marked this PR as a major release so we will bump the version above the upper-bound pinned by PyMC. When we update the pin on PyMC I'll address the changes. They require some manual review anyway to see if the logic that depended on BroadcastTo was valid per our new rules and can be transferred to Alloc.

This was all on the logprob inference module AFAICT so impact should be pretty contained.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good

@aseyboldt
Copy link
Member

Other than the two suggestions above this looks good :-)

@ricardoV94 ricardoV94 merged commit c6b0858 into pymc-devs:main Aug 7, 2023
51 of 52 checks passed
@ricardoV94 ricardoV94 deleted the cleanup_broadcast branch August 7, 2023 12:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Alloc vs BroadcastTo vs Second
3 participants