[BUG] TensorDict with dynamic, input-dependent batch_size is not torch.export.exportable #1003

Open
3 tasks done
egaznep opened this issue Sep 20, 2024 · 4 comments · May be fixed by #1004
Labels
bug Something isn't working

Comments

@egaznep
Contributor

egaznep commented Sep 20, 2024

Describe the bug

To Reproduce

import torch
import tensordict

class Test(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return tensordict.TensorDict(
            {'x': x, 'y': y},
            batch_size=x.shape[0],  # comment this line out and it works, but batch_size = [] and not x.shape[0]
        )

test = Test()
result = torch.export.export(
    test, 
    args=(torch.zeros(2,100), torch.zeros(2,100)),
    strict=False,
    dynamic_shapes={
        'x': {0: torch.export.Dim('batch'), 1: torch.export.Dim('time')},
        'y': {0: torch.export.Dim('batch'), 1: torch.export.Dim('time')}
        }
    )
print(result.module()(torch.zeros(5,100), torch.zeros(5,100)))
(myenv) egaznep@...@volta: $ python scripts/mwe.py
Traceback (most recent call last):
  File "./myenv/lib/python3.10/site-packages/tensordict/_td.py", line 2082, in _parse_batch_size
    return torch.Size(batch_size)
TypeError: 'SymInt' object is not iterable

During handling of the above exception, another exception occurred:

Traceback (without unrelated parts):
  File "./scripts/mwe.py", line 6, in forward
    return tensordict.TensorDict({
  File "./myenv/lib/python3.10/site-packages/tensordict/_td.py", line 285, in __init__
    self._batch_size = self._parse_batch_size(source, batch_size)
  File "./myenv/lib/python3.10/site-packages/tensordict/_td.py", line 2090, in _parse_batch_size
    raise ValueError(ERR)
ValueError: batch size was not specified when creating the TensorDict instance and it could not be retrieved from source.

Expected behavior

I would expect the module to be exported successfully as an ExportedProgram whose output TensorDict has batch_size equal to the (dynamic) first dimension of x.

System info

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
2024.09.19 2.0.2 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] linux 2.6.0.dev20240919

Reason and Possible fixes

If you know or suspect the reason for this bug, paste the code lines and suggest modifications.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@egaznep egaznep added the bug Something isn't working label Sep 20, 2024
@vmoens
Contributor

vmoens commented Sep 20, 2024

Good timing! I'm just merging PRs related to export now.
Have you tried with nightlies?

EDIT
Oh sorry, I just saw that this was already on nightly. Weird!
I'll fix that.

@vmoens
Contributor

vmoens commented Sep 20, 2024

After further investigation, this works:

batch_size=x.shape[:1]

but x.shape[0], which is a SymInt during export, is not detected as a number by this check:

elif isinstance(batch_size, Number):

I'll dig into this!
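The difference between the two spellings can be mimicked without torch. The minimal sketch below is a loose model, not tensordict's actual _parse_batch_size: SymbolicInt is a hypothetical stand-in for torch.SymInt (deliberately not a numbers.Number and not iterable), and parse_batch_size is a hypothetical normalizer. Slicing a shape yields a tuple, so it takes the iterable path, while a bare symbolic scalar fails both the Number check and tuple(). Raising the ValueError inside the except block chains the TypeError as __context__, which is consistent with the two-part traceback in the report. Widening scalar_types is shown only as one possible direction; it is not necessarily what #1004 does.

```python
from numbers import Number


class SymbolicInt:
    """Hypothetical stand-in for torch.SymInt: int-like, but NOT registered
    as a numbers.Number and not iterable."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return f"SymbolicInt({self.name!r})"


ERR = ("batch size was not specified when creating the TensorDict instance "
       "and it could not be retrieved from source.")


def parse_batch_size(batch_size, scalar_types=(Number,)):
    """Hypothetical normalizer, loosely modelled on the traceback above."""
    if isinstance(batch_size, scalar_types):
        # A scalar is wrapped into a one-element shape.
        return (batch_size,)
    try:
        # Anything else is treated as an iterable of dimensions.
        return tuple(batch_size)
    except TypeError:
        # Raising inside `except` stores the TypeError as __context__,
        # so Python prints both exceptions.
        raise ValueError(ERR)


s0 = SymbolicInt("s0")
shape = (s0, SymbolicInt("s1"))  # hypothetical stand-in for x.shape under export

print(parse_batch_size(3))          # (3,)
print(parse_batch_size(shape[:1]))  # (SymbolicInt('s0'),): the slice is a tuple
try:
    parse_batch_size(shape[0])      # bare symbolic scalar: not a Number, not iterable
except ValueError as exc:
    caught = exc
    print(type(exc.__context__).__name__)  # TypeError

# Widening the scalar check lets the bare symbolic value through.
print(parse_batch_size(shape[0], scalar_types=(Number, SymbolicInt)))  # (SymbolicInt('s0'),)
```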

@egaznep
Contributor Author

egaznep commented Sep 20, 2024

> Good timing! I'm just merging PRs related to export now. Have you tried with nightlies?
>
> EDIT Oh sorry just saw that this was nightly already. Weird! I'll fix that

Happy to hear that! I really appreciate the time and effort spent on this library; it's making my development process much neater and easier. Yes, I saw the activity and tried the nightly in the hope that it would resolve my issues.

> After further investigation, this works
>
> batch_size=x.shape[:1]
>
> but using x.shape[0] fails to be detected as a number here
>
> elif isinstance(batch_size, Number):
>
> I'll dig into this!

How did I not think of doing that? 😆 But my happiness was short-lived: with x.shape[:1], calling the exported module still gives me the following output (did you get the same?):

TensorDict(
    fields={
        x: Tensor(shape=torch.Size([5, 100]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([5, 100]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

where batch_size is still not [5]. I also get warnings about shape mismatches:

W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] Ignored guard Eq(s0, 5) == False, this could result in accuracy problems
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172] Stack (most recent call last):
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./scripts/mwe.py", line 23, in <module>
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     print(result.module()(torch.zeros(5,100), torch.zeros(5,100)))
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 784, in call_wrapped
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return self._wrapped_call(self, *args, **kwargs)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 348, in __call__
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return self._call_impl(*args, **kwargs)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return inner()
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     result = forward_call(*args, **kwargs)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "<eval_with_key>.9", line 6, in forward
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return pytree.tree_unflatten((x, y), self._out_spec)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 887, in tree_unflatten
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return treespec.unflatten(leaves)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 824, in unflatten
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return unflatten_fn(child_pytrees, self.context)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/tensordict/_pytree.py", line 131, in _tensordict_unflatten
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     if any(_shape(tensor)[:batch_dims] != batch_size for tensor in values):
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/tensordict/_pytree.py", line 131, in <genexpr>
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     if any(_shape(tensor)[:batch_dims] != batch_size for tensor in values):
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/__init__.py", line 680, in __bool__
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return self.node.bool_()
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 511, in bool_
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return self.guard_bool("", 0)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 449, in guard_bool
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return retlog(fn(*args, **kwargs))
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5243, in evaluate_expr
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5369, in _evaluate_expr
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     self._check_frozen(expr, concrete_val)
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]   File "./myenv/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5172, in _check_frozen
W0920 18:01:47.759000 279364 site-packages/torch/fx/experimental/symbolic_shapes.py:5172]     log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val, stack_info=True)
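The warning above is raised from the prefix comparison in _tensordict_unflatten shown in the stack. One plausible reading, not confirmed in this thread, is that the batch size recorded in the output spec at export time is still tied to the example batch of 2 (the guard reads Eq(s0, 5)), so it disagrees with the run-time batch of 5 and the returned TensorDict ends up with an empty batch size. The sketch below models only that comparison with plain tuples; restore_batch_size and its fall-back-to-empty behaviour are hypothetical, chosen to match the printed batch_size=torch.Size([]).

```python
from typing import Sequence, Tuple


def restore_batch_size(recorded: Tuple[int, ...],
                       leaf_shapes: Sequence[Tuple[int, ...]]) -> Tuple[int, ...]:
    """Hypothetical unflatten-time check.

    `recorded` is the batch size stored in the spec when the program was
    exported; `leaf_shapes` are the shapes of the tensors seen at run time.
    If any leaf's leading dims disagree with `recorded`, fall back to an
    empty batch size (an assumption, not confirmed tensordict behaviour).
    """
    batch_dims = len(recorded)
    if any(tuple(shape[:batch_dims]) != tuple(recorded) for shape in leaf_shapes):
        return ()
    return tuple(recorded)


# Batch size tied to the example inputs used while exporting (batch of 2) ...
recorded = (2,)
# ... versus the leaves produced when the exported module runs on a batch of 5.
runtime = [(5, 100), (5, 100)]

print(restore_batch_size(recorded, runtime))  # ()
print(restore_batch_size((5,), runtime))      # (5,)
```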

@vmoens vmoens linked a pull request Sep 20, 2024 that will close this issue
@vmoens
Contributor

vmoens commented Sep 20, 2024

Ah, that's interesting, let me check that one!
