Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Hooks and Buffers for TensorDictParams #502

Merged
merged 4 commits into from
Jul 28, 2023
Merged

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Jul 28, 2023

Description

Introduces 2 features for TensorDictParams:

  • pre_get_hooks allow users to register operations to be executed every time a get operation is executed (incl. values, items, __getitem__, get, _get_str, _get_tuple). For instance, this allows to create an instance of TensorDictParams that gets the .data attribute of another one while keeping the content similar after casting to device or dtype.
  • Introduces a Buffer class, similar to nn.Parameter but without gradient by default. If a tensor is not a parameter in TensorDictParam, it will be a Buffer. This change allows us to call TensorDictParams.to(...) while being sure that every leaf of the structure has the same identity before and after the operation, as well as the buffers in _buffers. Not using this class would result in the content of _buffers being updated but not the content of the tensordict, thereby breaking the sync between the two.

cc @matteobettini

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 28, 2023
@github-actions
Copy link

github-actions bot commented Jul 28, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 109. Improved: $\large\color{#35bf28}1$. Worsened: $\large\color{#d91a1a}3$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_plain_set_nested 37.3000μs 20.0877μs 49.7817 KOps/s 49.0593 KOps/s $\color{#35bf28}+1.47\%$
test_plain_set_stack_nested 0.2272ms 0.1863ms 5.3687 KOps/s 5.3535 KOps/s $\color{#35bf28}+0.28\%$
test_plain_set_nested_inplace 44.4000μs 23.6096μs 42.3556 KOps/s 42.1409 KOps/s $\color{#35bf28}+0.51\%$
test_plain_set_stack_nested_inplace 0.2523ms 0.2202ms 4.5404 KOps/s 4.4918 KOps/s $\color{#35bf28}+1.08\%$
test_items 30.0000μs 3.4059μs 293.6063 KOps/s 293.3663 KOps/s $\color{#35bf28}+0.08\%$
test_items_nested 0.3996ms 0.3632ms 2.7534 KOps/s 2.6986 KOps/s $\color{#35bf28}+2.03\%$
test_items_nested_locked 1.9584ms 0.3665ms 2.7286 KOps/s 2.7065 KOps/s $\color{#35bf28}+0.82\%$
test_items_nested_leaf 0.2537ms 0.2216ms 4.5117 KOps/s 4.4110 KOps/s $\color{#35bf28}+2.28\%$
test_items_stack_nested 2.0862ms 1.9841ms 504.0084 Ops/s 498.0062 Ops/s $\color{#35bf28}+1.21\%$
test_items_stack_nested_leaf 1.8571ms 1.8087ms 552.8707 Ops/s 547.4835 Ops/s $\color{#35bf28}+0.98\%$
test_items_stack_nested_locked 1.2862ms 0.9910ms 1.0091 KOps/s 996.9644 Ops/s $\color{#35bf28}+1.22\%$
test_keys 28.9010μs 5.0839μs 196.6995 KOps/s 196.4193 KOps/s $\color{#35bf28}+0.14\%$
test_keys_nested 0.9235ms 0.1820ms 5.4953 KOps/s 5.4251 KOps/s $\color{#35bf28}+1.29\%$
test_keys_nested_locked 0.2614ms 0.1807ms 5.5341 KOps/s 5.5003 KOps/s $\color{#35bf28}+0.61\%$
test_keys_nested_leaf 0.3652ms 0.1744ms 5.7355 KOps/s 5.2889 KOps/s $\textbf{\color{#35bf28}+8.44\%}$
test_keys_stack_nested 1.8203ms 1.7587ms 568.6111 Ops/s 560.1492 Ops/s $\color{#35bf28}+1.51\%$
test_keys_stack_nested_leaf 1.8118ms 1.7469ms 572.4587 Ops/s 561.5309 Ops/s $\color{#35bf28}+1.95\%$
test_keys_stack_nested_locked 0.7897ms 0.7383ms 1.3545 KOps/s 1.2976 KOps/s $\color{#35bf28}+4.38\%$
test_values 12.3010μs 1.5490μs 645.5724 KOps/s 649.4471 KOps/s $\color{#d91a1a}-0.60\%$
test_values_nested 0.1389ms 67.2445μs 14.8711 KOps/s 15.0243 KOps/s $\color{#d91a1a}-1.02\%$
test_values_nested_locked 0.1011ms 66.7508μs 14.9811 KOps/s 15.0571 KOps/s $\color{#d91a1a}-0.50\%$
test_values_nested_leaf 0.1195ms 59.5160μs 16.8022 KOps/s 17.0216 KOps/s $\color{#d91a1a}-1.29\%$
test_values_stack_nested 1.6559ms 1.5932ms 627.6527 Ops/s 618.9004 Ops/s $\color{#35bf28}+1.41\%$
test_values_stack_nested_leaf 1.6461ms 1.5875ms 629.9218 Ops/s 621.2346 Ops/s $\color{#35bf28}+1.40\%$
test_values_stack_nested_locked 0.8807ms 0.6389ms 1.5653 KOps/s 1.5133 KOps/s $\color{#35bf28}+3.44\%$
test_membership 22.7010μs 1.8638μs 536.5491 KOps/s 554.8687 KOps/s $\color{#d91a1a}-3.30\%$
test_membership_nested 33.1000μs 3.5813μs 279.2244 KOps/s 276.4048 KOps/s $\color{#35bf28}+1.02\%$
test_membership_nested_leaf 17.2000μs 3.6101μs 277.0007 KOps/s 280.9197 KOps/s $\color{#d91a1a}-1.40\%$
test_membership_stacked_nested 39.2010μs 14.3615μs 69.6306 KOps/s 69.8071 KOps/s $\color{#d91a1a}-0.25\%$
test_membership_stacked_nested_leaf 41.8000μs 14.3555μs 69.6597 KOps/s 69.7960 KOps/s $\color{#d91a1a}-0.20\%$
test_membership_nested_last 31.7000μs 7.5983μs 131.6081 KOps/s 133.8356 KOps/s $\color{#d91a1a}-1.66\%$
test_membership_nested_leaf_last 36.9000μs 7.5681μs 132.1342 KOps/s 133.2253 KOps/s $\color{#d91a1a}-0.82\%$
test_membership_stacked_nested_last 0.2912ms 0.2239ms 4.4670 KOps/s 4.4253 KOps/s $\color{#35bf28}+0.94\%$
test_membership_stacked_nested_leaf_last 57.5000μs 16.8104μs 59.4870 KOps/s 59.4568 KOps/s $\color{#35bf28}+0.05\%$
test_nested_getleaf 0.2201ms 15.5495μs 64.3109 KOps/s 63.7981 KOps/s $\color{#35bf28}+0.80\%$
test_nested_get 74.5010μs 14.7846μs 67.6381 KOps/s 66.8767 KOps/s $\color{#35bf28}+1.14\%$
test_stacked_getleaf 0.9727ms 0.8777ms 1.1393 KOps/s 1.1276 KOps/s $\color{#35bf28}+1.04\%$
test_stacked_get 0.8796ms 0.8415ms 1.1883 KOps/s 1.1888 KOps/s $\color{#d91a1a}-0.04\%$
test_nested_getitemleaf 45.4000μs 15.5570μs 64.2799 KOps/s 63.7677 KOps/s $\color{#35bf28}+0.80\%$
test_nested_getitem 40.7000μs 14.7200μs 67.9348 KOps/s 66.6994 KOps/s $\color{#35bf28}+1.85\%$
test_stacked_getitemleaf 0.9961ms 0.8827ms 1.1329 KOps/s 1.1308 KOps/s $\color{#35bf28}+0.19\%$
test_stacked_getitem 0.9437ms 0.8447ms 1.1838 KOps/s 1.1834 KOps/s $\color{#35bf28}+0.03\%$
test_lock_nested 77.5493ms 1.5053ms 664.3043 Ops/s 702.0089 Ops/s $\textbf{\color{#d91a1a}-5.37\%}$
test_lock_stack_nested 96.9980ms 20.2820ms 49.3048 Ops/s 52.7186 Ops/s $\textbf{\color{#d91a1a}-6.48\%}$
test_unlock_nested 74.1501ms 1.5149ms 660.1081 Ops/s 656.5689 Ops/s $\color{#35bf28}+0.54\%$
test_unlock_stack_nested 97.6116ms 20.9317ms 47.7744 Ops/s 51.4332 Ops/s $\textbf{\color{#d91a1a}-7.11\%}$
test_flatten_speed 1.1168ms 1.0117ms 988.4231 Ops/s 986.1066 Ops/s $\color{#35bf28}+0.23\%$
test_unflatten_speed 1.8874ms 1.8372ms 544.3044 Ops/s 547.6928 Ops/s $\color{#d91a1a}-0.62\%$
test_common_ops 1.3501ms 1.0842ms 922.3169 Ops/s 897.8341 Ops/s $\color{#35bf28}+2.73\%$
test_creation 41.8000μs 6.1739μs 161.9721 KOps/s 165.7203 KOps/s $\color{#d91a1a}-2.26\%$
test_creation_empty 45.9000μs 13.7512μs 72.7210 KOps/s 73.6298 KOps/s $\color{#d91a1a}-1.23\%$
test_creation_nested_1 57.7000μs 24.9012μs 40.1587 KOps/s 40.0712 KOps/s $\color{#35bf28}+0.22\%$
test_creation_nested_2 58.1010μs 27.4795μs 36.3908 KOps/s 36.9204 KOps/s $\color{#d91a1a}-1.43\%$
test_clone 0.1757ms 24.4425μs 40.9123 KOps/s 39.7165 KOps/s $\color{#35bf28}+3.01\%$
test_getitem[int] 59.7010μs 26.9599μs 37.0921 KOps/s 35.9662 KOps/s $\color{#35bf28}+3.13\%$
test_getitem[slice_int] 0.1059ms 53.4461μs 18.7104 KOps/s 18.3536 KOps/s $\color{#35bf28}+1.94\%$
test_getitem[range] 0.1146ms 81.5689μs 12.2596 KOps/s 11.9595 KOps/s $\color{#35bf28}+2.51\%$
test_getitem[tuple] 83.5010μs 44.2660μs 22.5907 KOps/s 22.0527 KOps/s $\color{#35bf28}+2.44\%$
test_getitem[list] 0.3669ms 77.3171μs 12.9338 KOps/s 12.6362 KOps/s $\color{#35bf28}+2.35\%$
test_setitem_dim[int] 52.4000μs 32.4714μs 30.7963 KOps/s 29.9059 KOps/s $\color{#35bf28}+2.98\%$
test_setitem_dim[slice_int] 98.4010μs 57.6976μs 17.3317 KOps/s 16.8096 KOps/s $\color{#35bf28}+3.11\%$
test_setitem_dim[range] 0.1142ms 78.1131μs 12.8019 KOps/s 12.3529 KOps/s $\color{#35bf28}+3.64\%$
test_setitem_dim[tuple] 69.9010μs 47.7203μs 20.9554 KOps/s 20.2280 KOps/s $\color{#35bf28}+3.60\%$
test_setitem 0.1845ms 32.6692μs 30.6099 KOps/s 30.5102 KOps/s $\color{#35bf28}+0.33\%$
test_set 0.2031ms 31.0232μs 32.2340 KOps/s 31.4549 KOps/s $\color{#35bf28}+2.48\%$
test_set_shared 0.3696ms 0.1763ms 5.6718 KOps/s 5.6319 KOps/s $\color{#35bf28}+0.71\%$
test_update 0.2160ms 34.7961μs 28.7388 KOps/s 27.6270 KOps/s $\color{#35bf28}+4.02\%$
test_update_nested 0.2255ms 52.3778μs 19.0921 KOps/s 18.8388 KOps/s $\color{#35bf28}+1.34\%$
test_set_nested 0.1816ms 34.1416μs 29.2897 KOps/s 28.4112 KOps/s $\color{#35bf28}+3.09\%$
test_set_nested_new 0.2229ms 52.3669μs 19.0960 KOps/s 18.4963 KOps/s $\color{#35bf28}+3.24\%$
test_select 0.2751ms 96.0850μs 10.4074 KOps/s 10.1363 KOps/s $\color{#35bf28}+2.68\%$
test_unbind_speed 0.7404ms 0.6418ms 1.5582 KOps/s 1.5542 KOps/s $\color{#35bf28}+0.26\%$
test_unbind_speed_stack0 86.8046ms 8.9222ms 112.0799 Ops/s 112.5744 Ops/s $\color{#d91a1a}-0.44\%$
test_unbind_speed_stack1 25.3000μs 1.1691μs 855.3945 KOps/s 851.5972 KOps/s $\color{#35bf28}+0.45\%$
test_creation[device0] 0.5399ms 0.4498ms 2.2233 KOps/s 2.1849 KOps/s $\color{#35bf28}+1.76\%$
test_creation_from_tensor 0.6409ms 0.5040ms 1.9841 KOps/s 1.9519 KOps/s $\color{#35bf28}+1.65\%$
test_add_one[memmap_tensor0] 2.1281ms 32.4344μs 30.8314 KOps/s 29.5532 KOps/s $\color{#35bf28}+4.33\%$
test_contiguous[memmap_tensor0] 42.1000μs 8.7049μs 114.8773 KOps/s 110.8945 KOps/s $\color{#35bf28}+3.59\%$
test_stack[memmap_tensor0] 89.2000μs 26.6176μs 37.5691 KOps/s 36.4943 KOps/s $\color{#35bf28}+2.95\%$
test_memmaptd_index 0.3957ms 0.3121ms 3.2036 KOps/s 3.1528 KOps/s $\color{#35bf28}+1.61\%$
test_memmaptd_index_astensor 1.4643ms 1.3521ms 739.5740 Ops/s 728.0736 Ops/s $\color{#35bf28}+1.58\%$
test_memmaptd_index_op 2.6946ms 2.6092ms 383.2604 Ops/s 374.2561 Ops/s $\color{#35bf28}+2.41\%$
test_reshape_pytree 0.1010ms 37.7305μs 26.5037 KOps/s 26.2430 KOps/s $\color{#35bf28}+0.99\%$
test_reshape_td 78.1010μs 44.7172μs 22.3628 KOps/s 22.2213 KOps/s $\color{#35bf28}+0.64\%$
test_view_pytree 96.5010μs 35.3602μs 28.2804 KOps/s 28.2607 KOps/s $\color{#35bf28}+0.07\%$
test_view_td 37.6000μs 8.8303μs 113.2470 KOps/s 111.6552 KOps/s $\color{#35bf28}+1.43\%$
test_unbind_pytree 0.1154ms 38.9527μs 25.6722 KOps/s 25.5069 KOps/s $\color{#35bf28}+0.65\%$
test_unbind_td 0.1863ms 94.6874μs 10.5611 KOps/s 10.4798 KOps/s $\color{#35bf28}+0.78\%$
test_split_pytree 89.5010μs 45.2758μs 22.0869 KOps/s 22.1688 KOps/s $\color{#d91a1a}-0.37\%$
test_split_td 0.9035ms 0.1146ms 8.7290 KOps/s 8.6914 KOps/s $\color{#35bf28}+0.43\%$
test_add_pytree 89.1000μs 47.6210μs 20.9991 KOps/s 20.2246 KOps/s $\color{#35bf28}+3.83\%$
test_add_td 0.1225ms 75.1450μs 13.3076 KOps/s 13.1754 KOps/s $\color{#35bf28}+1.00\%$
test_distributed 74.5010μs 9.1042μs 109.8388 KOps/s 110.7211 KOps/s $\color{#d91a1a}-0.80\%$
test_tdmodule 0.1983ms 28.5234μs 35.0589 KOps/s 34.4520 KOps/s $\color{#35bf28}+1.76\%$
test_tdmodule_dispatch 0.3004ms 55.3223μs 18.0759 KOps/s 17.6302 KOps/s $\color{#35bf28}+2.53\%$
test_tdseq 0.5922ms 32.9233μs 30.3736 KOps/s 29.5422 KOps/s $\color{#35bf28}+2.81\%$
test_tdseq_dispatch 0.2181ms 67.1361μs 14.8951 KOps/s 14.5211 KOps/s $\color{#35bf28}+2.58\%$
test_instantiation_functorch 2.1464ms 1.6352ms 611.5438 Ops/s 602.2004 Ops/s $\color{#35bf28}+1.55\%$
test_instantiation_td 2.1263ms 1.3710ms 729.4043 Ops/s 729.2378 Ops/s $\color{#35bf28}+0.02\%$
test_exec_functorch 0.2794ms 0.1863ms 5.3673 KOps/s 5.3162 KOps/s $\color{#35bf28}+0.96\%$
test_exec_td 0.3404ms 0.1788ms 5.5932 KOps/s 5.5224 KOps/s $\color{#35bf28}+1.28\%$
test_vmap_mlp_speed[True-True] 2.8075ms 1.1831ms 845.2251 Ops/s 835.9186 Ops/s $\color{#35bf28}+1.11\%$
test_vmap_mlp_speed[True-False] 2.5695ms 0.6081ms 1.6445 KOps/s 1.6408 KOps/s $\color{#35bf28}+0.22\%$
test_vmap_mlp_speed[False-True] 3.4308ms 1.0176ms 982.6875 Ops/s 1.0003 KOps/s $\color{#d91a1a}-1.76\%$
test_vmap_mlp_speed[False-False] 1.1651ms 0.4522ms 2.2112 KOps/s 2.2306 KOps/s $\color{#d91a1a}-0.87\%$
test_vmap_transformer_speed[True-True] 15.5427ms 14.4471ms 69.2179 Ops/s 71.6106 Ops/s $\color{#d91a1a}-3.34\%$
test_vmap_transformer_speed[True-False] 10.5846ms 9.5617ms 104.5839 Ops/s 108.3936 Ops/s $\color{#d91a1a}-3.51\%$
test_vmap_transformer_speed[False-True] 16.1456ms 14.4345ms 69.2787 Ops/s 70.8686 Ops/s $\color{#d91a1a}-2.24\%$
test_vmap_transformer_speed[False-False] 10.5167ms 9.3752ms 106.6644 Ops/s 111.9075 Ops/s $\color{#d91a1a}-4.69\%$

Copy link
Contributor

@matteobettini matteobettini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, some nits

tensordict/nn/params.py Outdated Show resolved Hide resolved
tensordict/nn/params.py Outdated Show resolved Hide resolved
tensordict/nn/params.py Outdated Show resolved Hide resolved
@vmoens vmoens merged commit 5270f76 into main Jul 28, 2023
@vmoens vmoens deleted the hook_tdparams branch July 28, 2023 15:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants