[Feature] Hooks and Buffers for TensorDictParams #502
Merged
facebook-github-bot added the CLA Signed label Jul 28, 2023
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_plain_set_nested | 37.3000μs | 20.0877μs | 49.7817 KOps/s | 49.0593 KOps/s | |
test_plain_set_stack_nested | 0.2272ms | 0.1863ms | 5.3687 KOps/s | 5.3535 KOps/s | |
test_plain_set_nested_inplace | 44.4000μs | 23.6096μs | 42.3556 KOps/s | 42.1409 KOps/s | |
test_plain_set_stack_nested_inplace | 0.2523ms | 0.2202ms | 4.5404 KOps/s | 4.4918 KOps/s | |
test_items | 30.0000μs | 3.4059μs | 293.6063 KOps/s | 293.3663 KOps/s | |
test_items_nested | 0.3996ms | 0.3632ms | 2.7534 KOps/s | 2.6986 KOps/s | |
test_items_nested_locked | 1.9584ms | 0.3665ms | 2.7286 KOps/s | 2.7065 KOps/s | |
test_items_nested_leaf | 0.2537ms | 0.2216ms | 4.5117 KOps/s | 4.4110 KOps/s | |
test_items_stack_nested | 2.0862ms | 1.9841ms | 504.0084 Ops/s | 498.0062 Ops/s | |
test_items_stack_nested_leaf | 1.8571ms | 1.8087ms | 552.8707 Ops/s | 547.4835 Ops/s | |
test_items_stack_nested_locked | 1.2862ms | 0.9910ms | 1.0091 KOps/s | 996.9644 Ops/s | |
test_keys | 28.9010μs | 5.0839μs | 196.6995 KOps/s | 196.4193 KOps/s | |
test_keys_nested | 0.9235ms | 0.1820ms | 5.4953 KOps/s | 5.4251 KOps/s | |
test_keys_nested_locked | 0.2614ms | 0.1807ms | 5.5341 KOps/s | 5.5003 KOps/s | |
test_keys_nested_leaf | 0.3652ms | 0.1744ms | 5.7355 KOps/s | 5.2889 KOps/s | |
test_keys_stack_nested | 1.8203ms | 1.7587ms | 568.6111 Ops/s | 560.1492 Ops/s | |
test_keys_stack_nested_leaf | 1.8118ms | 1.7469ms | 572.4587 Ops/s | 561.5309 Ops/s | |
test_keys_stack_nested_locked | 0.7897ms | 0.7383ms | 1.3545 KOps/s | 1.2976 KOps/s | |
test_values | 12.3010μs | 1.5490μs | 645.5724 KOps/s | 649.4471 KOps/s | |
test_values_nested | 0.1389ms | 67.2445μs | 14.8711 KOps/s | 15.0243 KOps/s | |
test_values_nested_locked | 0.1011ms | 66.7508μs | 14.9811 KOps/s | 15.0571 KOps/s | |
test_values_nested_leaf | 0.1195ms | 59.5160μs | 16.8022 KOps/s | 17.0216 KOps/s | |
test_values_stack_nested | 1.6559ms | 1.5932ms | 627.6527 Ops/s | 618.9004 Ops/s | |
test_values_stack_nested_leaf | 1.6461ms | 1.5875ms | 629.9218 Ops/s | 621.2346 Ops/s | |
test_values_stack_nested_locked | 0.8807ms | 0.6389ms | 1.5653 KOps/s | 1.5133 KOps/s | |
test_membership | 22.7010μs | 1.8638μs | 536.5491 KOps/s | 554.8687 KOps/s | |
test_membership_nested | 33.1000μs | 3.5813μs | 279.2244 KOps/s | 276.4048 KOps/s | |
test_membership_nested_leaf | 17.2000μs | 3.6101μs | 277.0007 KOps/s | 280.9197 KOps/s | |
test_membership_stacked_nested | 39.2010μs | 14.3615μs | 69.6306 KOps/s | 69.8071 KOps/s | |
test_membership_stacked_nested_leaf | 41.8000μs | 14.3555μs | 69.6597 KOps/s | 69.7960 KOps/s | |
test_membership_nested_last | 31.7000μs | 7.5983μs | 131.6081 KOps/s | 133.8356 KOps/s | |
test_membership_nested_leaf_last | 36.9000μs | 7.5681μs | 132.1342 KOps/s | 133.2253 KOps/s | |
test_membership_stacked_nested_last | 0.2912ms | 0.2239ms | 4.4670 KOps/s | 4.4253 KOps/s | |
test_membership_stacked_nested_leaf_last | 57.5000μs | 16.8104μs | 59.4870 KOps/s | 59.4568 KOps/s | |
test_nested_getleaf | 0.2201ms | 15.5495μs | 64.3109 KOps/s | 63.7981 KOps/s | |
test_nested_get | 74.5010μs | 14.7846μs | 67.6381 KOps/s | 66.8767 KOps/s | |
test_stacked_getleaf | 0.9727ms | 0.8777ms | 1.1393 KOps/s | 1.1276 KOps/s | |
test_stacked_get | 0.8796ms | 0.8415ms | 1.1883 KOps/s | 1.1888 KOps/s | |
test_nested_getitemleaf | 45.4000μs | 15.5570μs | 64.2799 KOps/s | 63.7677 KOps/s | |
test_nested_getitem | 40.7000μs | 14.7200μs | 67.9348 KOps/s | 66.6994 KOps/s | |
test_stacked_getitemleaf | 0.9961ms | 0.8827ms | 1.1329 KOps/s | 1.1308 KOps/s | |
test_stacked_getitem | 0.9437ms | 0.8447ms | 1.1838 KOps/s | 1.1834 KOps/s | |
test_lock_nested | 77.5493ms | 1.5053ms | 664.3043 Ops/s | 702.0089 Ops/s | |
test_lock_stack_nested | 96.9980ms | 20.2820ms | 49.3048 Ops/s | 52.7186 Ops/s | |
test_unlock_nested | 74.1501ms | 1.5149ms | 660.1081 Ops/s | 656.5689 Ops/s | |
test_unlock_stack_nested | 97.6116ms | 20.9317ms | 47.7744 Ops/s | 51.4332 Ops/s | |
test_flatten_speed | 1.1168ms | 1.0117ms | 988.4231 Ops/s | 986.1066 Ops/s | |
test_unflatten_speed | 1.8874ms | 1.8372ms | 544.3044 Ops/s | 547.6928 Ops/s | |
test_common_ops | 1.3501ms | 1.0842ms | 922.3169 Ops/s | 897.8341 Ops/s | |
test_creation | 41.8000μs | 6.1739μs | 161.9721 KOps/s | 165.7203 KOps/s | |
test_creation_empty | 45.9000μs | 13.7512μs | 72.7210 KOps/s | 73.6298 KOps/s | |
test_creation_nested_1 | 57.7000μs | 24.9012μs | 40.1587 KOps/s | 40.0712 KOps/s | |
test_creation_nested_2 | 58.1010μs | 27.4795μs | 36.3908 KOps/s | 36.9204 KOps/s | |
test_clone | 0.1757ms | 24.4425μs | 40.9123 KOps/s | 39.7165 KOps/s | |
test_getitem[int] | 59.7010μs | 26.9599μs | 37.0921 KOps/s | 35.9662 KOps/s | |
test_getitem[slice_int] | 0.1059ms | 53.4461μs | 18.7104 KOps/s | 18.3536 KOps/s | |
test_getitem[range] | 0.1146ms | 81.5689μs | 12.2596 KOps/s | 11.9595 KOps/s | |
test_getitem[tuple] | 83.5010μs | 44.2660μs | 22.5907 KOps/s | 22.0527 KOps/s | |
test_getitem[list] | 0.3669ms | 77.3171μs | 12.9338 KOps/s | 12.6362 KOps/s | |
test_setitem_dim[int] | 52.4000μs | 32.4714μs | 30.7963 KOps/s | 29.9059 KOps/s | |
test_setitem_dim[slice_int] | 98.4010μs | 57.6976μs | 17.3317 KOps/s | 16.8096 KOps/s | |
test_setitem_dim[range] | 0.1142ms | 78.1131μs | 12.8019 KOps/s | 12.3529 KOps/s | |
test_setitem_dim[tuple] | 69.9010μs | 47.7203μs | 20.9554 KOps/s | 20.2280 KOps/s | |
test_setitem | 0.1845ms | 32.6692μs | 30.6099 KOps/s | 30.5102 KOps/s | |
test_set | 0.2031ms | 31.0232μs | 32.2340 KOps/s | 31.4549 KOps/s | |
test_set_shared | 0.3696ms | 0.1763ms | 5.6718 KOps/s | 5.6319 KOps/s | |
test_update | 0.2160ms | 34.7961μs | 28.7388 KOps/s | 27.6270 KOps/s | |
test_update_nested | 0.2255ms | 52.3778μs | 19.0921 KOps/s | 18.8388 KOps/s | |
test_set_nested | 0.1816ms | 34.1416μs | 29.2897 KOps/s | 28.4112 KOps/s | |
test_set_nested_new | 0.2229ms | 52.3669μs | 19.0960 KOps/s | 18.4963 KOps/s | |
test_select | 0.2751ms | 96.0850μs | 10.4074 KOps/s | 10.1363 KOps/s | |
test_unbind_speed | 0.7404ms | 0.6418ms | 1.5582 KOps/s | 1.5542 KOps/s | |
test_unbind_speed_stack0 | 86.8046ms | 8.9222ms | 112.0799 Ops/s | 112.5744 Ops/s | |
test_unbind_speed_stack1 | 25.3000μs | 1.1691μs | 855.3945 KOps/s | 851.5972 KOps/s | |
test_creation[device0] | 0.5399ms | 0.4498ms | 2.2233 KOps/s | 2.1849 KOps/s | |
test_creation_from_tensor | 0.6409ms | 0.5040ms | 1.9841 KOps/s | 1.9519 KOps/s | |
test_add_one[memmap_tensor0] | 2.1281ms | 32.4344μs | 30.8314 KOps/s | 29.5532 KOps/s | |
test_contiguous[memmap_tensor0] | 42.1000μs | 8.7049μs | 114.8773 KOps/s | 110.8945 KOps/s | |
test_stack[memmap_tensor0] | 89.2000μs | 26.6176μs | 37.5691 KOps/s | 36.4943 KOps/s | |
test_memmaptd_index | 0.3957ms | 0.3121ms | 3.2036 KOps/s | 3.1528 KOps/s | |
test_memmaptd_index_astensor | 1.4643ms | 1.3521ms | 739.5740 Ops/s | 728.0736 Ops/s | |
test_memmaptd_index_op | 2.6946ms | 2.6092ms | 383.2604 Ops/s | 374.2561 Ops/s | |
test_reshape_pytree | 0.1010ms | 37.7305μs | 26.5037 KOps/s | 26.2430 KOps/s | |
test_reshape_td | 78.1010μs | 44.7172μs | 22.3628 KOps/s | 22.2213 KOps/s | |
test_view_pytree | 96.5010μs | 35.3602μs | 28.2804 KOps/s | 28.2607 KOps/s | |
test_view_td | 37.6000μs | 8.8303μs | 113.2470 KOps/s | 111.6552 KOps/s | |
test_unbind_pytree | 0.1154ms | 38.9527μs | 25.6722 KOps/s | 25.5069 KOps/s | |
test_unbind_td | 0.1863ms | 94.6874μs | 10.5611 KOps/s | 10.4798 KOps/s | |
test_split_pytree | 89.5010μs | 45.2758μs | 22.0869 KOps/s | 22.1688 KOps/s | |
test_split_td | 0.9035ms | 0.1146ms | 8.7290 KOps/s | 8.6914 KOps/s | |
test_add_pytree | 89.1000μs | 47.6210μs | 20.9991 KOps/s | 20.2246 KOps/s | |
test_add_td | 0.1225ms | 75.1450μs | 13.3076 KOps/s | 13.1754 KOps/s | |
test_distributed | 74.5010μs | 9.1042μs | 109.8388 KOps/s | 110.7211 KOps/s | |
test_tdmodule | 0.1983ms | 28.5234μs | 35.0589 KOps/s | 34.4520 KOps/s | |
test_tdmodule_dispatch | 0.3004ms | 55.3223μs | 18.0759 KOps/s | 17.6302 KOps/s | |
test_tdseq | 0.5922ms | 32.9233μs | 30.3736 KOps/s | 29.5422 KOps/s | |
test_tdseq_dispatch | 0.2181ms | 67.1361μs | 14.8951 KOps/s | 14.5211 KOps/s | |
test_instantiation_functorch | 2.1464ms | 1.6352ms | 611.5438 Ops/s | 602.2004 Ops/s | |
test_instantiation_td | 2.1263ms | 1.3710ms | 729.4043 Ops/s | 729.2378 Ops/s | |
test_exec_functorch | 0.2794ms | 0.1863ms | 5.3673 KOps/s | 5.3162 KOps/s | |
test_exec_td | 0.3404ms | 0.1788ms | 5.5932 KOps/s | 5.5224 KOps/s | |
test_vmap_mlp_speed[True-True] | 2.8075ms | 1.1831ms | 845.2251 Ops/s | 835.9186 Ops/s | |
test_vmap_mlp_speed[True-False] | 2.5695ms | 0.6081ms | 1.6445 KOps/s | 1.6408 KOps/s | |
test_vmap_mlp_speed[False-True] | 3.4308ms | 1.0176ms | 982.6875 Ops/s | 1.0003 KOps/s | |
test_vmap_mlp_speed[False-False] | 1.1651ms | 0.4522ms | 2.2112 KOps/s | 2.2306 KOps/s | |
test_vmap_transformer_speed[True-True] | 15.5427ms | 14.4471ms | 69.2179 Ops/s | 71.6106 Ops/s | |
test_vmap_transformer_speed[True-False] | 10.5846ms | 9.5617ms | 104.5839 Ops/s | 108.3936 Ops/s | |
test_vmap_transformer_speed[False-True] | 16.1456ms | 14.4345ms | 69.2787 Ops/s | 70.8686 Ops/s | |
test_vmap_transformer_speed[False-False] | 10.5167ms | 9.3752ms | 106.6644 Ops/s | 111.9075 Ops/s |
matteobettini approved these changes Jul 28, 2023
LGTM, some nits
Description
Introduces 2 features for TensorDictParams:
- `pre_get_hooks` allow users to register operations to be executed every time a `get` operation is executed (incl. `values`, `items`, `__getitem__`, `get`, `_get_str`, `_get_tuple`). For instance, this makes it possible to create an instance of TensorDictParams that gets the `.data` attribute of another one while keeping the content similar after casting to device or dtype.
- `Buffer` class, similar to `nn.Parameter` but without gradient by default. If a tensor is not a parameter in TensorDictParams, it will be a `Buffer`. This change allows us to call `TensorDictParams.to(...)` while being sure that every leaf of the structure has the same identity before and after the operation, as well as the buffers in `_buffers`. Not using this class would result in the content of `_buffers` being updated but not the content of the tensordict, thereby breaking the sync between the two.

cc @matteobettini