[Feature] PyTree compatibility #501
Conversation
tensordict/_pytree.py (Outdated)
```python
if not str_spec.startswith("D"):
    return None
context_strings, child_strings = _str_to_dict(str_spec)
return TensorDict, context_strings, child_strings
```
cc @zou3519
We're going to delete these APIs soon, so you don't have to implement it.
Indeed, I saw that the tests were breaking; I guess that's why.
Looks reasonable! @zou3519, for the non-monkeypatched vmap support, maybe we need some way for non-Tensor objects to say "hey, I need special batch dim handling"?
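A purely hypothetical shape for such a mechanism (none of these names exist in torch; this only sketches the suggestion above):

```python
import torch

# Hypothetical registry mapping a class to custom vmap batch-dim hooks.
_BATCH_DIM_HANDLERS = {}

def register_batch_dim_handler(cls, add_fn, remove_fn):
    # add_fn(obj, batch_dim, level) -> batched obj
    # remove_fn(obj, batch_dim, level) -> unbatched obj
    _BATCH_DIM_HANDLERS[cls] = (add_fn, remove_fn)

def add_batch_dim(obj, batch_dim, level):
    # vmap would consult the registry before falling back to pytree
    # flattening or raising on unsupported types.
    handler = _BATCH_DIM_HANDLERS.get(type(obj))
    if handler is not None:
        return handler[0](obj, batch_dim, level)
    if isinstance(obj, torch.Tensor):
        # torch's private per-tensor helper used by functorch internals
        return torch._C._functorch._add_batch_dim(obj, batch_dim, level)
    raise TypeError(f"no batch-dim handler registered for {type(obj)}")
```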
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
---|---|---|---|---|---
test_plain_set_nested | 44.6010μs | 20.2688μs | 49.3369 KOps/s | 49.8803 KOps/s | |
test_plain_set_stack_nested | 0.2301ms | 0.1883ms | 5.3120 KOps/s | 5.3984 KOps/s | |
test_plain_set_nested_inplace | 59.3010μs | 23.4181μs | 42.7020 KOps/s | 42.5577 KOps/s | |
test_plain_set_stack_nested_inplace | 0.2612ms | 0.2210ms | 4.5255 KOps/s | 4.4899 KOps/s | |
test_items | 41.1010μs | 3.4122μs | 293.0650 KOps/s | 289.7257 KOps/s | |
test_items_nested | 0.5769ms | 0.3669ms | 2.7258 KOps/s | 2.7219 KOps/s | |
test_items_nested_locked | 4.3453ms | 0.3875ms | 2.5805 KOps/s | 2.7379 KOps/s | |
test_items_nested_leaf | 0.2606ms | 0.2237ms | 4.4703 KOps/s | 4.2662 KOps/s | |
test_items_stack_nested | 2.1157ms | 1.9783ms | 505.4859 Ops/s | 504.2864 Ops/s | |
test_items_stack_nested_leaf | 1.9003ms | 1.7995ms | 555.7114 Ops/s | 554.8192 Ops/s | |
test_items_stack_nested_locked | 1.0379ms | 0.9743ms | 1.0264 KOps/s | 989.9899 Ops/s | |
test_keys | 32.7010μs | 5.0751μs | 197.0395 KOps/s | 196.4536 KOps/s | |
test_keys_nested | 1.1068ms | 0.1823ms | 5.4846 KOps/s | 5.3869 KOps/s | |
test_keys_nested_locked | 0.2608ms | 0.1807ms | 5.5343 KOps/s | 5.4840 KOps/s | |
test_keys_nested_leaf | 0.3980ms | 0.1737ms | 5.7567 KOps/s | 5.1978 KOps/s | |
test_keys_stack_nested | 1.8710ms | 1.7556ms | 569.5995 Ops/s | 564.0670 Ops/s | |
test_keys_stack_nested_leaf | 1.8362ms | 1.7523ms | 570.6943 Ops/s | 566.4699 Ops/s | |
test_keys_stack_nested_locked | 0.8207ms | 0.7493ms | 1.3346 KOps/s | 1.3301 KOps/s | |
test_values | 32.9000μs | 1.5546μs | 643.2658 KOps/s | 651.8493 KOps/s | |
test_values_nested | 0.1091ms | 66.8753μs | 14.9532 KOps/s | 15.0137 KOps/s | |
test_values_nested_locked | 0.1365ms | 67.1748μs | 14.8865 KOps/s | 15.1439 KOps/s | |
test_values_nested_leaf | 0.1275ms | 59.7038μs | 16.7493 KOps/s | 16.9419 KOps/s | |
test_values_stack_nested | 1.6632ms | 1.5928ms | 627.8230 Ops/s | 627.0963 Ops/s | |
test_values_stack_nested_leaf | 1.6554ms | 1.5836ms | 631.4562 Ops/s | 632.6455 Ops/s | |
test_values_stack_nested_locked | 0.7407ms | 0.6447ms | 1.5511 KOps/s | 1.5537 KOps/s | |
test_membership | 30.7010μs | 1.8376μs | 544.1747 KOps/s | 537.6081 KOps/s | |
test_membership_nested | 32.2010μs | 3.5907μs | 278.4950 KOps/s | 282.6569 KOps/s | |
test_membership_nested_leaf | 25.7000μs | 3.5279μs | 283.4548 KOps/s | 281.5054 KOps/s | |
test_membership_stacked_nested | 71.6010μs | 14.2601μs | 70.1258 KOps/s | 71.1296 KOps/s | |
test_membership_stacked_nested_leaf | 51.6000μs | 14.2674μs | 70.0897 KOps/s | 70.7645 KOps/s | |
test_membership_nested_last | 64.1010μs | 7.5723μs | 132.0595 KOps/s | 132.2411 KOps/s | |
test_membership_nested_leaf_last | 40.2000μs | 7.4535μs | 134.1645 KOps/s | 132.9676 KOps/s | |
test_membership_stacked_nested_last | 0.2539ms | 0.2245ms | 4.4540 KOps/s | 4.3544 KOps/s | |
test_membership_stacked_nested_leaf_last | 45.5000μs | 16.7499μs | 59.7019 KOps/s | 60.7078 KOps/s | |
test_nested_getleaf | 45.1000μs | 15.6577μs | 63.8665 KOps/s | 63.7428 KOps/s | |
test_nested_get | 42.1000μs | 14.8558μs | 67.3138 KOps/s | 67.2367 KOps/s | |
test_stacked_getleaf | 1.0263ms | 0.8760ms | 1.1415 KOps/s | 1.1551 KOps/s | |
test_stacked_get | 0.8810ms | 0.8364ms | 1.1956 KOps/s | 1.2028 KOps/s | |
test_nested_getitemleaf | 44.7000μs | 15.7353μs | 63.5515 KOps/s | 63.5436 KOps/s | |
test_nested_getitem | 71.2010μs | 14.8488μs | 67.3454 KOps/s | 66.8066 KOps/s | |
test_stacked_getitemleaf | 1.0428ms | 0.8745ms | 1.1435 KOps/s | 1.1479 KOps/s | |
test_stacked_getitem | 0.8794ms | 0.8335ms | 1.1997 KOps/s | 1.1969 KOps/s | |
test_lock_nested | 95.3401ms | 1.5180ms | 658.7729 Ops/s | 705.4411 Ops/s | |
test_lock_stack_nested | 0.1160s | 21.5536ms | 46.3959 Ops/s | 50.4398 Ops/s | |
test_unlock_nested | 92.0168ms | 1.5342ms | 651.8134 Ops/s | 654.3776 Ops/s | |
test_unlock_stack_nested | 0.1166s | 22.1299ms | 45.1877 Ops/s | 48.5799 Ops/s | |
test_flatten_speed | 1.0695ms | 1.0216ms | 978.8857 Ops/s | 994.6758 Ops/s | |
test_unflatten_speed | 1.8929ms | 1.8450ms | 542.0056 Ops/s | 552.8322 Ops/s | |
test_common_ops | 1.3877ms | 1.1042ms | 905.6108 Ops/s | 907.3379 Ops/s | |
test_creation | 34.0000μs | 6.2189μs | 160.7989 KOps/s | 161.4770 KOps/s | |
test_creation_empty | 45.8010μs | 14.0264μs | 71.2944 KOps/s | 73.1127 KOps/s | |
test_creation_nested_1 | 61.4010μs | 25.2920μs | 39.5383 KOps/s | 39.9138 KOps/s | |
test_creation_nested_2 | 63.2010μs | 27.8803μs | 35.8676 KOps/s | 36.2566 KOps/s | |
test_clone | 0.2116ms | 24.5299μs | 40.7666 KOps/s | 40.0808 KOps/s | |
test_getitem[int] | 0.1318ms | 27.1777μs | 36.7948 KOps/s | 36.3342 KOps/s | |
test_getitem[slice_int] | 0.1286ms | 53.0185μs | 18.8614 KOps/s | 18.6978 KOps/s | |
test_getitem[range] | 0.1216ms | 81.4685μs | 12.2747 KOps/s | 12.3332 KOps/s | |
test_getitem[tuple] | 89.7010μs | 44.3156μs | 22.5654 KOps/s | 22.1311 KOps/s | |
test_getitem[list] | 0.4586ms | 77.7527μs | 12.8613 KOps/s | 13.0758 KOps/s | |
test_setitem_dim[int] | 57.2010μs | 32.7084μs | 30.5732 KOps/s | 30.8314 KOps/s | |
test_setitem_dim[slice_int] | 87.8010μs | 57.7233μs | 17.3240 KOps/s | 17.4016 KOps/s | |
test_setitem_dim[range] | 0.1257ms | 78.5358μs | 12.7331 KOps/s | 12.6851 KOps/s | |
test_setitem_dim[tuple] | 89.5010μs | 48.1727μs | 20.7586 KOps/s | 20.8239 KOps/s | |
test_setitem | 0.2418ms | 32.8247μs | 30.4649 KOps/s | 31.2731 KOps/s | |
test_set | 0.2017ms | 31.4842μs | 31.7620 KOps/s | 32.4902 KOps/s | |
test_set_shared | 0.3877ms | 0.1799ms | 5.5589 KOps/s | 5.5508 KOps/s | |
test_update | 0.2070ms | 35.5973μs | 28.0920 KOps/s | 28.2225 KOps/s | |
test_update_nested | 0.3075ms | 52.8315μs | 18.9281 KOps/s | 19.1800 KOps/s | |
test_set_nested | 0.2068ms | 34.6172μs | 28.8874 KOps/s | 29.2524 KOps/s | |
test_set_nested_new | 0.2560ms | 52.9621μs | 18.8814 KOps/s | 18.9136 KOps/s | |
test_select | 0.2704ms | 96.9350μs | 10.3162 KOps/s | 10.3404 KOps/s | |
test_unbind_speed | 0.7058ms | 0.6462ms | 1.5475 KOps/s | 1.5463 KOps/s | |
test_unbind_speed_stack0 | 0.1064s | 9.4912ms | 105.3610 Ops/s | 106.5492 Ops/s | |
test_unbind_speed_stack1 | 36.9000μs | 1.1459μs | 872.6525 KOps/s | 1.0696 MOps/s | |
test_creation[device0] | 0.5663ms | 0.4588ms | 2.1798 KOps/s | 2.1934 KOps/s | |
test_creation_from_tensor | 3.3405ms | 0.5156ms | 1.9396 KOps/s | 1.9503 KOps/s | |
test_add_one[memmap_tensor0] | 1.7185ms | 32.2600μs | 30.9981 KOps/s | 29.9758 KOps/s | |
test_contiguous[memmap_tensor0] | 37.9000μs | 8.5536μs | 116.9098 KOps/s | 113.6463 KOps/s | |
test_stack[memmap_tensor0] | 94.3010μs | 26.8073μs | 37.3032 KOps/s | 36.8955 KOps/s | |
test_memmaptd_index | 0.4086ms | 0.3105ms | 3.2201 KOps/s | 3.1616 KOps/s | |
test_memmaptd_index_astensor | 1.4785ms | 1.3473ms | 742.2295 Ops/s | 732.8135 Ops/s | |
test_memmaptd_index_op | 2.8614ms | 2.6146ms | 382.4704 Ops/s | 377.2368 Ops/s | |
test_reshape_pytree | 0.1050ms | 37.7285μs | 26.5052 KOps/s | 26.0908 KOps/s | |
test_reshape_td | 92.0010μs | 45.6639μs | 21.8991 KOps/s | 22.3851 KOps/s | |
test_view_pytree | 99.0020μs | 35.2183μs | 28.3944 KOps/s | 28.2752 KOps/s | |
test_view_td | 35.7000μs | 8.8340μs | 113.1994 KOps/s | 114.7216 KOps/s | |
test_unbind_pytree | 78.2010μs | 39.0483μs | 25.6093 KOps/s | 24.6807 KOps/s | |
test_unbind_td | 0.1787ms | 96.2214μs | 10.3927 KOps/s | 10.4074 KOps/s | |
test_split_pytree | 96.2010μs | 45.0152μs | 22.2147 KOps/s | 22.0971 KOps/s | |
test_split_td | 0.9764ms | 0.1184ms | 8.4451 KOps/s | 8.6110 KOps/s | |
test_add_pytree | 98.8010μs | 48.1959μs | 20.7487 KOps/s | 20.8750 KOps/s | |
test_add_td | 0.1063ms | 76.4733μs | 13.0765 KOps/s | 13.4798 KOps/s | |
test_distributed | 35.0010μs | 9.0257μs | 110.7949 KOps/s | 108.7756 KOps/s | |
test_tdmodule | 0.2045ms | 29.1121μs | 34.3500 KOps/s | 34.8120 KOps/s | |
test_tdmodule_dispatch | 0.3035ms | 55.8213μs | 17.9143 KOps/s | 17.8242 KOps/s | |
test_tdseq | 0.6298ms | 33.3325μs | 30.0008 KOps/s | 29.9102 KOps/s | |
test_tdseq_dispatch | 0.2218ms | 66.9781μs | 14.9303 KOps/s | 14.9125 KOps/s | |
test_instantiation_functorch | 2.1735ms | 1.6357ms | 611.3405 Ops/s | 609.0728 Ops/s | |
test_instantiation_td | 2.1300ms | 1.3599ms | 735.3740 Ops/s | 729.6130 Ops/s | |
test_exec_functorch | 0.2590ms | 0.1872ms | 5.3413 KOps/s | 5.2366 KOps/s | |
test_exec_td | 0.2231ms | 0.1769ms | 5.6541 KOps/s | 5.5350 KOps/s | |
test_vmap_mlp_speed[True-True] | 2.2949ms | 1.2167ms | 821.8927 Ops/s | 753.9204 Ops/s | |
test_vmap_mlp_speed[True-False] | 1.2352ms | 0.6247ms | 1.6008 KOps/s | 1.6338 KOps/s | |
test_vmap_mlp_speed[False-True] | 2.1746ms | 1.0348ms | 966.3385 Ops/s | 990.4335 Ops/s | |
test_vmap_mlp_speed[False-False] | 7.9722ms | 0.4735ms | 2.1121 KOps/s | 2.1827 KOps/s | |
test_vmap_transformer_speed[True-True] | 15.3061ms | 14.1775ms | 70.5342 Ops/s | 68.7471 Ops/s | |
test_vmap_transformer_speed[True-False] | 10.0372ms | 9.1490ms | 109.3021 Ops/s | 102.9905 Ops/s | |
test_vmap_transformer_speed[False-True] | 15.0934ms | 13.9873ms | 71.4933 Ops/s | 71.4979 Ops/s | |
test_vmap_transformer_speed[False-False] | 9.8064ms | 8.9806ms | 111.3513 Ops/s | 111.2151 Ops/s |
There are 2 things we monkey patch in tensordict: the `add_batch_dims` / `remove_batch_dims` used by vmap, and pytree's flattening of tensordict nodes within vmap, so that a call like the following works:

```python
fun = lambda x, td: td.set("y", x + 1)
td = TensorDict({}, torch.Size([3]))
x = torch.randn(4)
vmap(fun, (None, 0))(x, td)
```

We "hide" the tensordict node from pytree's flattening within vmap and rely on `TensorDict._add_batch_dims` instead.

Having a registration mechanism for vmap as you suggest would be awesome, but that would require looking at both of the things I mentioned here: how to have a custom `add_batch_dims` / `remove_batch_dims` for a specific class, and how to skip pytree's flatten ops if there is a custom dim operation that will follow.
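For context, a minimal sketch of what registering a container with torch's pytree utilities looks like, using a toy dict subclass rather than TensorDict itself (tensordict performs its own registration); the helper is `register_pytree_node` in recent torch releases and was `_register_pytree_node` in older ones:

```python
import torch
from torch.utils import _pytree as pytree

# Toy container standing in for TensorDict (illustrative only).
class MyContainer(dict):
    pass

def _flatten(c):
    keys = sorted(c.keys())
    return [c[k] for k in keys], keys          # (children, context)

def _unflatten(values, keys):
    return MyContainer(zip(keys, values))

pytree.register_pytree_node(MyContainer, _flatten, _unflatten)

# tree_map now recurses into MyContainer like any registered node:
c = MyContainer(a=torch.ones(3), b=torch.zeros(3))
print(pytree.tree_map(lambda x: x + 1, c))
```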
LGTM
I'm not sure how I feel about extending custom object support in vmap. The current contract is that vmap does not support custom objects unless you have a pytree.

One workaround I can brainstorm that doesn't involve monkeypatching is to store TensorDict's batch size as a meta tensor. For example, if we construct a `TensorDict({}, batch_size=[3, 4])`, then we store a meta tensor of size (3, 4) somewhere in the TensorDict. Whenever TensorDict needs to know its batch size, it can query the shape of that meta tensor. When one vmaps over that, assuming TensorDict is a pytree, we should end up with the correct semantics (under one vmap, the meta tensor's shape would be [4], which is what we want).

This doesn't quite solve the performance problem, though. We've been moving towards "use torch.compile to eliminate overhead", so in a long-term state we'd just use torch.compile to make things faster.
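A minimal sketch of that meta-tensor idea, assuming `torch.vmap` and treating the marker as an ordinary pytree leaf (the names here are illustrative):

```python
import torch
from torch import vmap

# Hypothetical marker: a storage-free tensor whose shape encodes the
# batch size of a TensorDict built with batch_size=[3, 4].
batch_size_marker = torch.empty(3, 4, device="meta")

def report_batch_size(marker):
    # Under one level of vmap, the leading batch dim is hidden, so the
    # container would see a batch size of [4].
    print(marker.shape)
    return marker

vmap(report_batch_size)(batch_size_marker)  # prints torch.Size([4])
```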
Thanks for the suggestion! There is another issue with this: I would need to provide the in_dim / out_dim for each leaf of the tensordict, whereas now I can do `vmap(fun, (0, 0))(tensor, tensordict)`. Imagine I have a deeply nested tensordict with many parameters; it is much harder to provide each in_dim (especially when they all match).
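To illustrate the point with plain dicts (a sketch; `fun` and the tensors are made up), with an ordinary pytree `in_dims` must mirror the nested structure leaf by leaf:

```python
import torch
from torch import vmap

def fun(d):
    return d["a"] + d["inner"]["b"]

d = {"a": torch.randn(3, 4), "inner": {"b": torch.randn(3, 4)}}

# Each leaf needs its own entry in the in_dims pytree...
out = vmap(fun, in_dims=({"a": 0, "inner": {"b": 0}},))(d)
print(out.shape)  # torch.Size([3, 4])

# ...although a single integer does broadcast to all leaves when the
# batched dims are uniform:
out = vmap(fun, in_dims=(0,))(d)
```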
Do all the Tensors in the TensorDict have the same batched-ness? E.g. if batch_size is (3, 4), do all the Tensors have shape (3, 4, *)?
Yes, but that does not mean that each level has the same batch size. You could have:

```python
td = TensorDict(
    {
        "a": TensorDict(
            {
                "b": TensorDict({"c": torch.ones(2, 3, 4)}, batch_size=[2, 3]),
                "d": torch.ones(2),
            },
            batch_size=[2],
        ),
        "e": 1,
    },
    batch_size=[],
)
```

where each level of the structure has one more batch dimension.
Description
Makes tensordict compatible with pytree.
Solves the conflict in vmap between the custom `_add_batch_dims` and the torch version.
We ignore the tensordict node in pytree within vmap and rely on TensorDict._add_batch_dims instead.
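A sketch of the resulting usage under this design (assuming `torch.vmap` is available; older setups used `functorch.vmap`; shapes are illustrative):

```python
import torch
from tensordict import TensorDict
from torch import vmap

td = TensorDict({"x": torch.randn(3, 4)}, batch_size=[3])

def fun(td):
    # Inside vmap the leading batch dim is hidden: td sees batch_size []
    # and td["x"] has shape (4,).
    return td["x"].sum()

out = vmap(fun, (0,))(td)
print(out.shape)  # torch.Size([3])
```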
cc @NicolasHug @ezyang @zou3519