Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow traceable tensor subclasses #1270

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from

Conversation

crcrpar
Copy link
Collaborator

@crcrpar crcrpar commented Oct 8, 2024

What does this PR do?

Conservatively errors out when any of tensor inputs are a traceable tensor subclass and any non-pytorch executor is enabled.
Ideally we should interpret, translate, and disable __torch_dispatch__ eventually.

@t-vi
Copy link
Collaborator

t-vi commented Oct 8, 2024

Conservatively errors out when any of tensor inputs are a traceable tensor subclass and any non-pytorch executor is enabled.

I think it would be good to just error out and not depend on the executors here. WDYT?

@crcrpar crcrpar marked this pull request as draft October 8, 2024 07:57
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Eligible tensor subclasses are one that implements both
`__tensor_flatten__` and `__tensor_unflatten__`.

Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
@crcrpar crcrpar force-pushed the crpa/ban-traceable-subclass-for-exs branch from 2b9b4c1 to 02341a0 Compare October 17, 2024 05:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants