
[Feature] PyTree compatibility #501

Merged: 7 commits merged from the pytree branch into main on Jul 31, 2023
Conversation

@vmoens (Contributor) commented Jul 26, 2023

Description

Makes tensordict compatible with pytree.
Resolves the conflict in vmap between the custom _add_batch_dims and the torch version: within vmap, the tensordict node is ignored by pytree and TensorDict._add_batch_dims is used instead.

cc @NicolasHug @ezyang @zou3519
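
As background, here is a minimal sketch of what registering a dict-like container with torch.utils._pytree can look like. It is illustrative only: the helper names `_td_flatten` / `_td_unflatten` are made up, and the PR's actual registration logic may differ.

```python
# Illustrative sketch only: the PR's actual flatten/unflatten logic lives inside tensordict,
# and the helper names below (_td_flatten, _td_unflatten) are made up for this example.
import torch
from torch.utils import _pytree as pytree
from tensordict import TensorDict

def _td_flatten(td):
    # leaves are the stored values; keys, batch size and device become the context
    keys = list(td.keys())
    return [td.get(k) for k in keys], (keys, td.batch_size, td.device)

def _td_unflatten(values, context):
    keys, batch_size, device = context
    return TensorDict(dict(zip(keys, values)), batch_size=batch_size, device=device)

# recent tensordict releases already register TensorDict as a pytree node, hence the guard
if TensorDict not in pytree.SUPPORTED_NODES:
    pytree._register_pytree_node(TensorDict, _td_flatten, _td_unflatten)

td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3])
leaves, spec = pytree.tree_flatten(td)
assert pytree.tree_unflatten(leaves, spec)["a"].shape == torch.Size([3, 4])
```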

@facebook-github-bot added the CLA Signed label on Jul 26, 2023
@vmoens added the enhancement label on Jul 26, 2023
```python
if not str_spec.startswith("D"):
    return None
context_strings, child_strings = _str_to_dict(str_spec)
return TensorDict, context_strings, child_strings
```
A review comment on the snippet above:
cc @zou3519

We're going to delete these APIs soon, you don't have to implement it.

@vmoens (author) replied:

Indeed, I saw that the tests were breaking; I guess that's the cause.

@ezyang commented Jul 26, 2023

Looks reasonable!

@zou3519 for the non-monkeypatched vmap support, maybe we need some way for non-Tensor objects to say "hey, I need special batch dim handling"?

@github-actions (bot) commented Jul 27, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 109. Improved: $\large\color{#35bf28}3$. Worsened: $\large\color{#d91a1a}5$.

Expand to view detailed results
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
test_plain_set_nested 44.6010μs 20.2688μs 49.3369 KOps/s 49.8803 KOps/s $\color{#d91a1a}-1.09\%$
test_plain_set_stack_nested 0.2301ms 0.1883ms 5.3120 KOps/s 5.3984 KOps/s $\color{#d91a1a}-1.60\%$
test_plain_set_nested_inplace 59.3010μs 23.4181μs 42.7020 KOps/s 42.5577 KOps/s $\color{#35bf28}+0.34\%$
test_plain_set_stack_nested_inplace 0.2612ms 0.2210ms 4.5255 KOps/s 4.4899 KOps/s $\color{#35bf28}+0.79\%$
test_items 41.1010μs 3.4122μs 293.0650 KOps/s 289.7257 KOps/s $\color{#35bf28}+1.15\%$
test_items_nested 0.5769ms 0.3669ms 2.7258 KOps/s 2.7219 KOps/s $\color{#35bf28}+0.14\%$
test_items_nested_locked 4.3453ms 0.3875ms 2.5805 KOps/s 2.7379 KOps/s $\textbf{\color{#d91a1a}-5.75\%}$
test_items_nested_leaf 0.2606ms 0.2237ms 4.4703 KOps/s 4.2662 KOps/s $\color{#35bf28}+4.78\%$
test_items_stack_nested 2.1157ms 1.9783ms 505.4859 Ops/s 504.2864 Ops/s $\color{#35bf28}+0.24\%$
test_items_stack_nested_leaf 1.9003ms 1.7995ms 555.7114 Ops/s 554.8192 Ops/s $\color{#35bf28}+0.16\%$
test_items_stack_nested_locked 1.0379ms 0.9743ms 1.0264 KOps/s 989.9899 Ops/s $\color{#35bf28}+3.67\%$
test_keys 32.7010μs 5.0751μs 197.0395 KOps/s 196.4536 KOps/s $\color{#35bf28}+0.30\%$
test_keys_nested 1.1068ms 0.1823ms 5.4846 KOps/s 5.3869 KOps/s $\color{#35bf28}+1.81\%$
test_keys_nested_locked 0.2608ms 0.1807ms 5.5343 KOps/s 5.4840 KOps/s $\color{#35bf28}+0.92\%$
test_keys_nested_leaf 0.3980ms 0.1737ms 5.7567 KOps/s 5.1978 KOps/s $\textbf{\color{#35bf28}+10.75\%}$
test_keys_stack_nested 1.8710ms 1.7556ms 569.5995 Ops/s 564.0670 Ops/s $\color{#35bf28}+0.98\%$
test_keys_stack_nested_leaf 1.8362ms 1.7523ms 570.6943 Ops/s 566.4699 Ops/s $\color{#35bf28}+0.75\%$
test_keys_stack_nested_locked 0.8207ms 0.7493ms 1.3346 KOps/s 1.3301 KOps/s $\color{#35bf28}+0.34\%$
test_values 32.9000μs 1.5546μs 643.2658 KOps/s 651.8493 KOps/s $\color{#d91a1a}-1.32\%$
test_values_nested 0.1091ms 66.8753μs 14.9532 KOps/s 15.0137 KOps/s $\color{#d91a1a}-0.40\%$
test_values_nested_locked 0.1365ms 67.1748μs 14.8865 KOps/s 15.1439 KOps/s $\color{#d91a1a}-1.70\%$
test_values_nested_leaf 0.1275ms 59.7038μs 16.7493 KOps/s 16.9419 KOps/s $\color{#d91a1a}-1.14\%$
test_values_stack_nested 1.6632ms 1.5928ms 627.8230 Ops/s 627.0963 Ops/s $\color{#35bf28}+0.12\%$
test_values_stack_nested_leaf 1.6554ms 1.5836ms 631.4562 Ops/s 632.6455 Ops/s $\color{#d91a1a}-0.19\%$
test_values_stack_nested_locked 0.7407ms 0.6447ms 1.5511 KOps/s 1.5537 KOps/s $\color{#d91a1a}-0.16\%$
test_membership 30.7010μs 1.8376μs 544.1747 KOps/s 537.6081 KOps/s $\color{#35bf28}+1.22\%$
test_membership_nested 32.2010μs 3.5907μs 278.4950 KOps/s 282.6569 KOps/s $\color{#d91a1a}-1.47\%$
test_membership_nested_leaf 25.7000μs 3.5279μs 283.4548 KOps/s 281.5054 KOps/s $\color{#35bf28}+0.69\%$
test_membership_stacked_nested 71.6010μs 14.2601μs 70.1258 KOps/s 71.1296 KOps/s $\color{#d91a1a}-1.41\%$
test_membership_stacked_nested_leaf 51.6000μs 14.2674μs 70.0897 KOps/s 70.7645 KOps/s $\color{#d91a1a}-0.95\%$
test_membership_nested_last 64.1010μs 7.5723μs 132.0595 KOps/s 132.2411 KOps/s $\color{#d91a1a}-0.14\%$
test_membership_nested_leaf_last 40.2000μs 7.4535μs 134.1645 KOps/s 132.9676 KOps/s $\color{#35bf28}+0.90\%$
test_membership_stacked_nested_last 0.2539ms 0.2245ms 4.4540 KOps/s 4.3544 KOps/s $\color{#35bf28}+2.29\%$
test_membership_stacked_nested_leaf_last 45.5000μs 16.7499μs 59.7019 KOps/s 60.7078 KOps/s $\color{#d91a1a}-1.66\%$
test_nested_getleaf 45.1000μs 15.6577μs 63.8665 KOps/s 63.7428 KOps/s $\color{#35bf28}+0.19\%$
test_nested_get 42.1000μs 14.8558μs 67.3138 KOps/s 67.2367 KOps/s $\color{#35bf28}+0.11\%$
test_stacked_getleaf 1.0263ms 0.8760ms 1.1415 KOps/s 1.1551 KOps/s $\color{#d91a1a}-1.18\%$
test_stacked_get 0.8810ms 0.8364ms 1.1956 KOps/s 1.2028 KOps/s $\color{#d91a1a}-0.60\%$
test_nested_getitemleaf 44.7000μs 15.7353μs 63.5515 KOps/s 63.5436 KOps/s $\color{#35bf28}+0.01\%$
test_nested_getitem 71.2010μs 14.8488μs 67.3454 KOps/s 66.8066 KOps/s $\color{#35bf28}+0.81\%$
test_stacked_getitemleaf 1.0428ms 0.8745ms 1.1435 KOps/s 1.1479 KOps/s $\color{#d91a1a}-0.38\%$
test_stacked_getitem 0.8794ms 0.8335ms 1.1997 KOps/s 1.1969 KOps/s $\color{#35bf28}+0.24\%$
test_lock_nested 95.3401ms 1.5180ms 658.7729 Ops/s 705.4411 Ops/s $\textbf{\color{#d91a1a}-6.62\%}$
test_lock_stack_nested 0.1160s 21.5536ms 46.3959 Ops/s 50.4398 Ops/s $\textbf{\color{#d91a1a}-8.02\%}$
test_unlock_nested 92.0168ms 1.5342ms 651.8134 Ops/s 654.3776 Ops/s $\color{#d91a1a}-0.39\%$
test_unlock_stack_nested 0.1166s 22.1299ms 45.1877 Ops/s 48.5799 Ops/s $\textbf{\color{#d91a1a}-6.98\%}$
test_flatten_speed 1.0695ms 1.0216ms 978.8857 Ops/s 994.6758 Ops/s $\color{#d91a1a}-1.59\%$
test_unflatten_speed 1.8929ms 1.8450ms 542.0056 Ops/s 552.8322 Ops/s $\color{#d91a1a}-1.96\%$
test_common_ops 1.3877ms 1.1042ms 905.6108 Ops/s 907.3379 Ops/s $\color{#d91a1a}-0.19\%$
test_creation 34.0000μs 6.2189μs 160.7989 KOps/s 161.4770 KOps/s $\color{#d91a1a}-0.42\%$
test_creation_empty 45.8010μs 14.0264μs 71.2944 KOps/s 73.1127 KOps/s $\color{#d91a1a}-2.49\%$
test_creation_nested_1 61.4010μs 25.2920μs 39.5383 KOps/s 39.9138 KOps/s $\color{#d91a1a}-0.94\%$
test_creation_nested_2 63.2010μs 27.8803μs 35.8676 KOps/s 36.2566 KOps/s $\color{#d91a1a}-1.07\%$
test_clone 0.2116ms 24.5299μs 40.7666 KOps/s 40.0808 KOps/s $\color{#35bf28}+1.71\%$
test_getitem[int] 0.1318ms 27.1777μs 36.7948 KOps/s 36.3342 KOps/s $\color{#35bf28}+1.27\%$
test_getitem[slice_int] 0.1286ms 53.0185μs 18.8614 KOps/s 18.6978 KOps/s $\color{#35bf28}+0.87\%$
test_getitem[range] 0.1216ms 81.4685μs 12.2747 KOps/s 12.3332 KOps/s $\color{#d91a1a}-0.47\%$
test_getitem[tuple] 89.7010μs 44.3156μs 22.5654 KOps/s 22.1311 KOps/s $\color{#35bf28}+1.96\%$
test_getitem[list] 0.4586ms 77.7527μs 12.8613 KOps/s 13.0758 KOps/s $\color{#d91a1a}-1.64\%$
test_setitem_dim[int] 57.2010μs 32.7084μs 30.5732 KOps/s 30.8314 KOps/s $\color{#d91a1a}-0.84\%$
test_setitem_dim[slice_int] 87.8010μs 57.7233μs 17.3240 KOps/s 17.4016 KOps/s $\color{#d91a1a}-0.45\%$
test_setitem_dim[range] 0.1257ms 78.5358μs 12.7331 KOps/s 12.6851 KOps/s $\color{#35bf28}+0.38\%$
test_setitem_dim[tuple] 89.5010μs 48.1727μs 20.7586 KOps/s 20.8239 KOps/s $\color{#d91a1a}-0.31\%$
test_setitem 0.2418ms 32.8247μs 30.4649 KOps/s 31.2731 KOps/s $\color{#d91a1a}-2.58\%$
test_set 0.2017ms 31.4842μs 31.7620 KOps/s 32.4902 KOps/s $\color{#d91a1a}-2.24\%$
test_set_shared 0.3877ms 0.1799ms 5.5589 KOps/s 5.5508 KOps/s $\color{#35bf28}+0.15\%$
test_update 0.2070ms 35.5973μs 28.0920 KOps/s 28.2225 KOps/s $\color{#d91a1a}-0.46\%$
test_update_nested 0.3075ms 52.8315μs 18.9281 KOps/s 19.1800 KOps/s $\color{#d91a1a}-1.31\%$
test_set_nested 0.2068ms 34.6172μs 28.8874 KOps/s 29.2524 KOps/s $\color{#d91a1a}-1.25\%$
test_set_nested_new 0.2560ms 52.9621μs 18.8814 KOps/s 18.9136 KOps/s $\color{#d91a1a}-0.17\%$
test_select 0.2704ms 96.9350μs 10.3162 KOps/s 10.3404 KOps/s $\color{#d91a1a}-0.23\%$
test_unbind_speed 0.7058ms 0.6462ms 1.5475 KOps/s 1.5463 KOps/s $\color{#35bf28}+0.07\%$
test_unbind_speed_stack0 0.1064s 9.4912ms 105.3610 Ops/s 106.5492 Ops/s $\color{#d91a1a}-1.12\%$
test_unbind_speed_stack1 36.9000μs 1.1459μs 872.6525 KOps/s 1.0696 MOps/s $\textbf{\color{#d91a1a}-18.41\%}$
test_creation[device0] 0.5663ms 0.4588ms 2.1798 KOps/s 2.1934 KOps/s $\color{#d91a1a}-0.62\%$
test_creation_from_tensor 3.3405ms 0.5156ms 1.9396 KOps/s 1.9503 KOps/s $\color{#d91a1a}-0.55\%$
test_add_one[memmap_tensor0] 1.7185ms 32.2600μs 30.9981 KOps/s 29.9758 KOps/s $\color{#35bf28}+3.41\%$
test_contiguous[memmap_tensor0] 37.9000μs 8.5536μs 116.9098 KOps/s 113.6463 KOps/s $\color{#35bf28}+2.87\%$
test_stack[memmap_tensor0] 94.3010μs 26.8073μs 37.3032 KOps/s 36.8955 KOps/s $\color{#35bf28}+1.11\%$
test_memmaptd_index 0.4086ms 0.3105ms 3.2201 KOps/s 3.1616 KOps/s $\color{#35bf28}+1.85\%$
test_memmaptd_index_astensor 1.4785ms 1.3473ms 742.2295 Ops/s 732.8135 Ops/s $\color{#35bf28}+1.28\%$
test_memmaptd_index_op 2.8614ms 2.6146ms 382.4704 Ops/s 377.2368 Ops/s $\color{#35bf28}+1.39\%$
test_reshape_pytree 0.1050ms 37.7285μs 26.5052 KOps/s 26.0908 KOps/s $\color{#35bf28}+1.59\%$
test_reshape_td 92.0010μs 45.6639μs 21.8991 KOps/s 22.3851 KOps/s $\color{#d91a1a}-2.17\%$
test_view_pytree 99.0020μs 35.2183μs 28.3944 KOps/s 28.2752 KOps/s $\color{#35bf28}+0.42\%$
test_view_td 35.7000μs 8.8340μs 113.1994 KOps/s 114.7216 KOps/s $\color{#d91a1a}-1.33\%$
test_unbind_pytree 78.2010μs 39.0483μs 25.6093 KOps/s 24.6807 KOps/s $\color{#35bf28}+3.76\%$
test_unbind_td 0.1787ms 96.2214μs 10.3927 KOps/s 10.4074 KOps/s $\color{#d91a1a}-0.14\%$
test_split_pytree 96.2010μs 45.0152μs 22.2147 KOps/s 22.0971 KOps/s $\color{#35bf28}+0.53\%$
test_split_td 0.9764ms 0.1184ms 8.4451 KOps/s 8.6110 KOps/s $\color{#d91a1a}-1.93\%$
test_add_pytree 98.8010μs 48.1959μs 20.7487 KOps/s 20.8750 KOps/s $\color{#d91a1a}-0.61\%$
test_add_td 0.1063ms 76.4733μs 13.0765 KOps/s 13.4798 KOps/s $\color{#d91a1a}-2.99\%$
test_distributed 35.0010μs 9.0257μs 110.7949 KOps/s 108.7756 KOps/s $\color{#35bf28}+1.86\%$
test_tdmodule 0.2045ms 29.1121μs 34.3500 KOps/s 34.8120 KOps/s $\color{#d91a1a}-1.33\%$
test_tdmodule_dispatch 0.3035ms 55.8213μs 17.9143 KOps/s 17.8242 KOps/s $\color{#35bf28}+0.51\%$
test_tdseq 0.6298ms 33.3325μs 30.0008 KOps/s 29.9102 KOps/s $\color{#35bf28}+0.30\%$
test_tdseq_dispatch 0.2218ms 66.9781μs 14.9303 KOps/s 14.9125 KOps/s $\color{#35bf28}+0.12\%$
test_instantiation_functorch 2.1735ms 1.6357ms 611.3405 Ops/s 609.0728 Ops/s $\color{#35bf28}+0.37\%$
test_instantiation_td 2.1300ms 1.3599ms 735.3740 Ops/s 729.6130 Ops/s $\color{#35bf28}+0.79\%$
test_exec_functorch 0.2590ms 0.1872ms 5.3413 KOps/s 5.2366 KOps/s $\color{#35bf28}+2.00\%$
test_exec_td 0.2231ms 0.1769ms 5.6541 KOps/s 5.5350 KOps/s $\color{#35bf28}+2.15\%$
test_vmap_mlp_speed[True-True] 2.2949ms 1.2167ms 821.8927 Ops/s 753.9204 Ops/s $\textbf{\color{#35bf28}+9.02\%}$
test_vmap_mlp_speed[True-False] 1.2352ms 0.6247ms 1.6008 KOps/s 1.6338 KOps/s $\color{#d91a1a}-2.02\%$
test_vmap_mlp_speed[False-True] 2.1746ms 1.0348ms 966.3385 Ops/s 990.4335 Ops/s $\color{#d91a1a}-2.43\%$
test_vmap_mlp_speed[False-False] 7.9722ms 0.4735ms 2.1121 KOps/s 2.1827 KOps/s $\color{#d91a1a}-3.23\%$
test_vmap_transformer_speed[True-True] 15.3061ms 14.1775ms 70.5342 Ops/s 68.7471 Ops/s $\color{#35bf28}+2.60\%$
test_vmap_transformer_speed[True-False] 10.0372ms 9.1490ms 109.3021 Ops/s 102.9905 Ops/s $\textbf{\color{#35bf28}+6.13\%}$
test_vmap_transformer_speed[False-True] 15.0934ms 13.9873ms 71.4933 Ops/s 71.4979 Ops/s $-0.01\%$
test_vmap_transformer_speed[False-False] 9.8064ms 8.9806ms 111.3513 Ops/s 111.2151 Ops/s $\color{#35bf28}+0.12\%$

@vmoens marked this pull request as ready for review on July 27, 2023 08:43
@vmoens (Contributor, author) commented Jul 27, 2023

> Looks reasonable!
>
> @zou3519 for the non-monkeypatched vmap support, maybe we need some way for non-Tensor objects to say "hey, I need special batch dim handling"?

There are two things we monkey-patch in tensordict:

1. First, we need our "own" `add_batch_dims`, since a tensordict has its own shape: unlike regular dicts, we need to "hide" part of the shape when we vmap over a td. Example:

```python
fun = lambda x, td: td.set("y", x + 1)
td = TensorDict({}, torch.Size([3]))
x = torch.randn(4)
vmap(fun, (None, 0))(x, td)
```

We "hide" the `torch.Size([3])` inside the TD to make it broadcastable with the shape `torch.Size([4])` of the tensor `x`, then recompute the batch size of the output accordingly. The result is:

```
TensorDict(
    fields={
        y: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
```

2. The monkey-patching from this PR follows the first one: we don't want to decompose the tensordict, because:
   a. if we keep the batch size as metadata, it becomes hard to manipulate when we want to reconstruct the tensordict with a new batch size;
   b. it is inefficient (building nested tensordicts has some non-negligible overhead due to the shape/metadata checks we perform over the leading dimensions of the tensors -- internally we can skip those checks because we know what we are feeding to our tensordict).

Having a registration mechanism for vmap as you suggest would be awesome, but it would require looking at both of the points above: how to have a custom add_batch_dims / remove_batch_dims for a specific class, and how to skip pytree's flatten ops when a custom dim operation will follow.
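
Purely as an illustration of the kind of hook being discussed (no such API exists in torch; every name below is hypothetical), a per-class registration could look like this:

```python
# Hypothetical sketch -- not a real torch API. It only illustrates the shape of a
# per-class hook for custom batch-dim handling under vmap.
_BATCH_DIM_HANDLERS = {}

def register_batch_dim_handler(cls, add_batch_dim, remove_batch_dim):
    """Let a non-Tensor container opt into custom vmap batch-dim handling."""
    _BATCH_DIM_HANDLERS[cls] = (add_batch_dim, remove_batch_dim)

def _add_batch_dim_dispatch(obj, in_dim, vmap_level):
    handler = _BATCH_DIM_HANDLERS.get(type(obj))
    if handler is None:
        return None  # fall back to pytree flattening and per-leaf handling
    add_batch_dim, _ = handler
    # a registered class handles its own leaves and skips pytree flattening
    return add_batch_dim(obj, in_dim, vmap_level)
```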

@matteobettini (Contributor) left a comment:

LGTM

@zou3519 commented Jul 27, 2023

I'm not sure how I feel about extending custom object support in vmap. The current contract is that vmap does not support custom objects, unless you have a pytree.

One workaround I can brainstorm that doesn't involve monkeypatching is to store TensorDict's batch size as a Meta tensor. For example, if we construct a TensorDict({}, batch_size=[3, 4]), then we store a Meta tensor of size (3, 4) somewhere in the TensorDict. Whenever TensorDict needs to know its batch size, it can query the shape of the Meta tensor.

When one vmaps over that, assuming TensorDict is a pytree, then we should end up with the correct semantics (under one vmap, that Meta tensor's shape would be [4], which is what we want).

Now, this doesn't quite solve the performance problem. We've been moving towards "use torch.compile to eliminate overhead for performance" though, so in a long-term state we'd just use torch.compile to make things faster.
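
A minimal, runnable sketch of that idea (the class and attribute names are made up for illustration):

```python
# Sketch of the suggestion above: keep a meta tensor whose shape *is* the batch size.
# Under vmap (with the container treated as a plain pytree), the meta tensor's shape
# loses the vmapped dimension, so the "batch size" updates for free.
import torch

class BatchSizeViaMeta:
    def __init__(self, *batch_size):
        # zero-storage placeholder; only its shape is ever used
        self._bs_marker = torch.empty(batch_size, device="meta")

    @property
    def batch_size(self):
        return self._bs_marker.shape

bs = BatchSizeViaMeta(3, 4)
print(bs.batch_size)  # torch.Size([3, 4])
```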

@vmoens (Contributor, author) commented Jul 27, 2023

Thanks for the suggestion!
I just tried it; the idea of a meta tensor carrying the tensordict shape is a good one (though torch.empty(shape, device="meta") is much slower than simply storing the shape).

There is another issue: I would need to provide the in_dim / out_dim for each leaf of the tensordict, whereas now I can do

```python
vmap(fun, (0, 0))(tensor, tensordict)
```

With a deeply nested tensordict containing many parameters, it is much harder to provide each in_dim individually (especially when they all match).
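
For a torch-only illustration of that point (plain dicts are already pytrees, so they exhibit the same per-leaf in_dims requirement a flattened TensorDict would have):

```python
# With pytree-based flattening, in_dims has to mirror the nested structure leaf by leaf.
import torch
from torch.func import vmap

nested = {"a": {"b": torch.randn(4, 3), "d": torch.randn(4)}, "e": torch.randn(4, 2)}

def fun(x):
    return x["a"]["b"].sum(-1) + x["a"]["d"] + x["e"].sum(-1)

# every leaf needs its own in_dim, arranged to match the input structure:
out = vmap(fun, in_dims=({"a": {"b": 0, "d": 0}, "e": 0},))(nested)
print(out.shape)  # torch.Size([4])
```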

@zou3519 commented Jul 27, 2023

Do all the Tensors in the TensorDict have the same batched-ness? E.g. if batch_size is (3, 4), do all the Tensors have shape (3, 4, *)?

@vmoens (Contributor, author) commented Jul 27, 2023

> Do all the Tensors in the TensorDict have the same batched-ness? E.g. if batch_size is (3, 4), do all the Tensors have shape (3, 4, *)?

Yes, but that does not mean that each level has the same batch size. You could have:

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {
        "a": TensorDict(
            {
                "b": TensorDict({"c": torch.ones(2, 3, 4)}, batch_size=[2, 3]),
                "d": torch.ones(2),
            },
            batch_size=[2],
        ),
        "e": 1,
    },
    batch_size=[],
)
```

where each level of the structure has one more batch dimension.
For instance, you could have a dataset made of batches of data, each containing a stack of tensordicts...

@vmoens merged commit 80225a5 into main on Jul 31, 2023
@vmoens deleted the pytree branch on July 31, 2023 at 17:36
Labels: CLA Signed, enhancement
5 participants