[BUG] TensorDict with dynamic, input-dependent batch_size is not torch.export.export-able #1003
Comments
Good timing! I'm just merging PRs related to export now.
EDIT: After further investigation, this works
but using Line 2086 in 1b6f451
I'll dig into this!
Happy to hear that, I really appreciate the time and effort spent on this library. It's making my development process much neater and easier. Yes, I saw the activity and, with some hope, tried the nightly to see if it would resolve my issues.
How did I not think of doing that? 😆 But my happiness was short-lived since with
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([5, 100]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([5, 100]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
where the
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] Ignored guard Eq(s0, 5) == False, this could result in accuracy problems
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] Stack (most recent call last):
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./scripts/mwe.py", line 23, in <module>
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] print(result.module()(torch.zeros(5,100), torch.zeros(5,100)))
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 784, in call_wrapped
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return self._wrapped_call(self, *args, **kwargs)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 348, in __call__
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return self._call_impl(*args, **kwargs)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return inner()
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] result = forward_call(*args, **kwargs)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "<eval_with_key>.9", line 6, in forward
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return pytree.tree_unflatten((x, y), self._out_spec)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 887, in tree_unflatten
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return treespec.unflatten(leaves)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 824, in unflatten
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return unflatten_fn(child_pytrees, self.context)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/tensordict/_pytree.py", line 131, in _tensordict_unflatten
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] if any(_shape(tensor)[:batch_dims] != batch_size for tensor in values):
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/tensordict/_pytree.py", line 131, in <genexpr>
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] if any(_shape(tensor)[:batch_dims] != batch_size for tensor in values):
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/__init__.py", line 680, in __bool__
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return self.node.bool_()
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 511, in bool_
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return self.guard_bool("", 0)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 449, in guard_bool
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return retlog(fn(*args, **kwargs))
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5243, in evaluate_expr
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5369, in _evaluate_expr
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] self._check_frozen(expr, concrete_val)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5172, in _check_frozen
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val, stack_info=True)
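For anyone following the trace: the warning is logged when the `!=` comparison at `tensordict/_pytree.py` line 131 is converted to a bool (`__bool__` → `guard_bool` → `evaluate_expr` → `_check_frozen`). That line checks that each leaf's leading dimensions match the stored batch size. Below is a minimal sketch of the same check in plain Python, with tuples standing in for tensor shapes. The function name `has_batch_mismatch` is hypothetical and not part of tensordict; this only illustrates what the comparison evaluates, not how the library implements it.

```python
def has_batch_mismatch(shapes, batch_size):
    """Return True if any shape's leading dims differ from batch_size.

    Hypothetical stand-in for the check seen in the traceback:
    any(_shape(tensor)[:batch_dims] != batch_size for tensor in values)
    """
    batch_dims = len(batch_size)
    return any(tuple(shape[:batch_dims]) != tuple(batch_size) for shape in shapes)


# Shapes of the two leaves from the printed TensorDict above.
shapes = [(5, 100), (5, 100)]

print(has_batch_mismatch(shapes, (5,)))  # False: leading dim matches
print(has_batch_mismatch(shapes, (2,)))  # True: leading dim differs
print(has_batch_mismatch(shapes, ()))    # False: nothing to compare
```

With concrete integers the comparison is harmless. During export the leading dimension is symbolic (`s0` in the log), so evaluating the same comparison as a bool is what reaches the guard machinery shown in the trace. An empty batch size compares zero dimensions, so the check passes trivially.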
Ah, that's interesting, let me check that one!
Describe the bug
To Reproduce
Expected behavior
I would expect the module to be successfully converted into an ExportedProgram.
System info
Describe the characteristic of your environment:
Reason and Possible fixes
If you know or suspect the reason for this bug, paste the code lines and suggest modifications.
Checklist