-
Notifications
You must be signed in to change notification settings - Fork 76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] torch.where
: Expand mask on the right in lazy stack tds
#542
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
facebook-github-bot
added
the
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
label
Oct 10, 2023
vmoens
changed the title
[BugFix] Expand mask on the right in lazy stack tds
[BugFix] Oct 10, 2023
torch.where
: Expand mask on the right in lazy stack tds
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_plain_set_nested | 39.3010μs | 20.0317μs | 49.9208 KOps/s | 49.7603 KOps/s | |
test_plain_set_stack_nested | 0.2249ms | 0.1847ms | 5.4134 KOps/s | 5.3740 KOps/s | |
test_plain_set_nested_inplace | 41.3000μs | 23.6374μs | 42.3059 KOps/s | 42.3696 KOps/s | |
test_plain_set_stack_nested_inplace | 0.3069ms | 0.2189ms | 4.5680 KOps/s | 4.5492 KOps/s | |
test_items | 20.8010μs | 3.5134μs | 284.6217 KOps/s | 283.0765 KOps/s | |
test_items_nested | 0.4697ms | 0.3701ms | 2.7018 KOps/s | 2.7397 KOps/s | |
test_items_nested_locked | 0.4584ms | 0.3703ms | 2.7006 KOps/s | 2.7456 KOps/s | |
test_items_nested_leaf | 0.3074ms | 0.2233ms | 4.4785 KOps/s | 4.4966 KOps/s | |
test_items_stack_nested | 2.1321ms | 1.9754ms | 506.2388 Ops/s | 505.3010 Ops/s | |
test_items_stack_nested_leaf | 1.9100ms | 1.7965ms | 556.6445 Ops/s | 548.9681 Ops/s | |
test_items_stack_nested_locked | 1.0932ms | 0.9712ms | 1.0297 KOps/s | 1.0137 KOps/s | |
test_keys | 27.0010μs | 5.0654μs | 197.4175 KOps/s | 196.1921 KOps/s | |
test_keys_nested | 0.8365ms | 0.1817ms | 5.5042 KOps/s | 4.9852 KOps/s | |
test_keys_nested_locked | 0.2638ms | 0.1800ms | 5.5569 KOps/s | 5.5000 KOps/s | |
test_keys_nested_leaf | 0.9100ms | 0.1766ms | 5.6616 KOps/s | 5.7291 KOps/s | |
test_keys_stack_nested | 1.9717ms | 1.8201ms | 549.4156 Ops/s | 541.6559 Ops/s | |
test_keys_stack_nested_leaf | 1.9799ms | 1.8221ms | 548.8197 Ops/s | 545.0268 Ops/s | |
test_keys_stack_nested_locked | 1.0184ms | 0.8178ms | 1.2228 KOps/s | 1.2032 KOps/s | |
test_values | 17.5003μs | 1.3895μs | 719.6575 KOps/s | 621.6983 KOps/s | |
test_values_nested | 0.1501ms | 66.9743μs | 14.9311 KOps/s | 14.7533 KOps/s | |
test_values_nested_locked | 0.1585ms | 67.1261μs | 14.8973 KOps/s | 14.8513 KOps/s | |
test_values_nested_leaf | 0.1550ms | 58.3605μs | 17.1349 KOps/s | 16.9566 KOps/s | |
test_values_stack_nested | 1.6642ms | 1.5867ms | 630.2372 Ops/s | 628.8637 Ops/s | |
test_values_stack_nested_leaf | 1.7215ms | 1.5801ms | 632.8844 Ops/s | 633.3336 Ops/s | |
test_values_stack_nested_locked | 0.8040ms | 0.6396ms | 1.5635 KOps/s | 1.5319 KOps/s | |
test_membership | 67.0010μs | 1.8863μs | 530.1284 KOps/s | 525.6228 KOps/s | |
test_membership_nested | 40.4010μs | 3.6515μs | 273.8566 KOps/s | 272.7630 KOps/s | |
test_membership_nested_leaf | 25.6010μs | 3.7439μs | 267.1037 KOps/s | 271.3841 KOps/s | |
test_membership_stacked_nested | 0.1007ms | 14.4092μs | 69.4000 KOps/s | 69.3092 KOps/s | |
test_membership_stacked_nested_leaf | 70.8010μs | 14.3789μs | 69.5462 KOps/s | 69.4387 KOps/s | |
test_membership_nested_last | 27.1000μs | 7.5400μs | 132.6254 KOps/s | 132.8320 KOps/s | |
test_membership_nested_leaf_last | 93.2020μs | 7.6358μs | 130.9628 KOps/s | 133.4512 KOps/s | |
test_membership_stacked_nested_last | 0.3581ms | 0.2280ms | 4.3853 KOps/s | 4.4272 KOps/s | |
test_membership_stacked_nested_leaf_last | 54.2010μs | 16.8313μs | 59.4131 KOps/s | 59.1548 KOps/s | |
test_nested_getleaf | 0.1027ms | 15.6424μs | 63.9290 KOps/s | 64.1580 KOps/s | |
test_nested_get | 0.1067ms | 14.8766μs | 67.2195 KOps/s | 67.7628 KOps/s | |
test_stacked_getleaf | 1.0582ms | 0.8818ms | 1.1340 KOps/s | 1.1436 KOps/s | |
test_stacked_get | 0.9685ms | 0.8388ms | 1.1922 KOps/s | 1.1879 KOps/s | |
test_nested_getitemleaf | 96.5010μs | 15.6292μs | 63.9828 KOps/s | 64.4908 KOps/s | |
test_nested_getitem | 0.1062ms | 14.8617μs | 67.2872 KOps/s | 67.7218 KOps/s | |
test_stacked_getitemleaf | 1.0689ms | 0.8843ms | 1.1308 KOps/s | 1.1444 KOps/s | |
test_stacked_getitem | 0.9908ms | 0.8388ms | 1.1921 KOps/s | 1.1932 KOps/s | |
test_lock_nested | 96.0545ms | 1.5995ms | 625.2008 Ops/s | 678.5814 Ops/s | |
test_lock_stack_nested | 0.1159s | 21.8540ms | 45.7582 Ops/s | 44.5582 Ops/s | |
test_unlock_nested | 75.3770ms | 1.5782ms | 633.6456 Ops/s | 634.6044 Ops/s | |
test_unlock_stack_nested | 0.1159s | 22.4832ms | 44.4776 Ops/s | 46.7330 Ops/s | |
test_flatten_speed | 1.0456ms | 1.0003ms | 999.6978 Ops/s | 967.6769 Ops/s | |
test_unflatten_speed | 1.8772ms | 1.7925ms | 557.8955 Ops/s | 547.2071 Ops/s | |
test_common_ops | 7.0275ms | 1.1165ms | 895.6822 Ops/s | 888.0178 Ops/s | |
test_creation | 0.2873ms | 6.1170μs | 163.4780 KOps/s | 163.2646 KOps/s | |
test_creation_empty | 84.9020μs | 13.4551μs | 74.3214 KOps/s | 73.2266 KOps/s | |
test_creation_nested_1 | 56.1010μs | 24.4363μs | 40.9227 KOps/s | 40.1688 KOps/s | |
test_creation_nested_2 | 74.1010μs | 26.6349μs | 37.5447 KOps/s | 37.0601 KOps/s | |
test_clone | 0.2100ms | 24.8616μs | 40.2227 KOps/s | 41.1436 KOps/s | |
test_getitem[int] | 57.3010μs | 28.2781μs | 35.3631 KOps/s | 35.2494 KOps/s | |
test_getitem[slice_int] | 98.3020μs | 55.2106μs | 18.1125 KOps/s | 18.1014 KOps/s | |
test_getitem[range] | 0.1770ms | 81.0748μs | 12.3343 KOps/s | 12.3035 KOps/s | |
test_getitem[tuple] | 70.1010μs | 45.7862μs | 21.8406 KOps/s | 21.4105 KOps/s | |
test_getitem[list] | 0.4040ms | 76.5691μs | 13.0601 KOps/s | 12.9804 KOps/s | |
test_setitem_dim[int] | 69.9010μs | 34.5950μs | 28.9059 KOps/s | 29.1416 KOps/s | |
test_setitem_dim[slice_int] | 92.7020μs | 61.3029μs | 16.3125 KOps/s | 16.5435 KOps/s | |
test_setitem_dim[range] | 0.1189ms | 80.3825μs | 12.4405 KOps/s | 12.2250 KOps/s | |
test_setitem_dim[tuple] | 76.9010μs | 50.2961μs | 19.8823 KOps/s | 19.8762 KOps/s | |
test_setitem | 0.2335ms | 32.0979μs | 31.1547 KOps/s | 31.1057 KOps/s | |
test_set | 0.2394ms | 30.6950μs | 32.5786 KOps/s | 32.1935 KOps/s | |
test_set_shared | 6.4905ms | 0.2081ms | 4.8046 KOps/s | 4.9344 KOps/s | |
test_update | 0.2436ms | 35.0241μs | 28.5518 KOps/s | 28.1346 KOps/s | |
test_update_nested | 0.3040ms | 51.7060μs | 19.3401 KOps/s | 19.1310 KOps/s | |
test_set_nested | 0.2045ms | 34.0193μs | 29.3950 KOps/s | 28.9959 KOps/s | |
test_set_nested_new | 0.2724ms | 53.5183μs | 18.6852 KOps/s | 18.8586 KOps/s | |
test_select | 0.2803ms | 99.0671μs | 10.0942 KOps/s | 10.3495 KOps/s | |
test_unbind_speed | 0.6852ms | 0.6453ms | 1.5496 KOps/s | 1.5544 KOps/s | |
test_unbind_speed_stack0 | 93.6423ms | 8.8237ms | 113.3308 Ops/s | 108.8995 Ops/s | |
test_unbind_speed_stack1 | 11.4002μs | 0.9463μs | 1.0568 MOps/s | 863.9092 KOps/s | |
test_creation[device0] | 0.5963ms | 0.4516ms | 2.2146 KOps/s | 2.2132 KOps/s | |
test_creation_from_tensor | 4.5838ms | 0.5104ms | 1.9593 KOps/s | 1.9502 KOps/s | |
test_add_one[memmap_tensor0] | 2.2176ms | 32.8085μs | 30.4799 KOps/s | 30.0315 KOps/s | |
test_contiguous[memmap_tensor0] | 37.0000μs | 8.7577μs | 114.1855 KOps/s | 109.3695 KOps/s | |
test_stack[memmap_tensor0] | 89.5020μs | 26.8878μs | 37.1916 KOps/s | 37.3791 KOps/s | |
test_memmaptd_index | 0.3847ms | 0.3154ms | 3.1702 KOps/s | 3.1236 KOps/s | |
test_memmaptd_index_astensor | 1.2771ms | 1.2207ms | 819.2231 Ops/s | 810.2075 Ops/s | |
test_memmaptd_index_op | 2.7302ms | 2.6700ms | 374.5280 Ops/s | 377.4644 Ops/s | |
test_reshape_pytree | 93.2020μs | 32.4585μs | 30.8086 KOps/s | 30.2262 KOps/s | |
test_reshape_td | 0.3555ms | 40.0688μs | 24.9570 KOps/s | 24.5804 KOps/s | |
test_view_pytree | 0.1399ms | 32.1996μs | 31.0563 KOps/s | 30.6202 KOps/s | |
test_view_td | 70.0020μs | 8.8803μs | 112.6091 KOps/s | 112.4053 KOps/s | |
test_unbind_pytree | 79.7010μs | 37.5351μs | 26.6417 KOps/s | 26.2873 KOps/s | |
test_unbind_td | 0.2090ms | 95.4766μs | 10.4738 KOps/s | 10.3124 KOps/s | |
test_split_pytree | 81.8010μs | 37.0200μs | 27.0124 KOps/s | 26.6062 KOps/s | |
test_split_td | 1.0116ms | 0.1069ms | 9.3533 KOps/s | 9.1746 KOps/s | |
test_add_pytree | 97.7020μs | 46.7915μs | 21.3714 KOps/s | 21.2650 KOps/s | |
test_add_td | 0.1559ms | 76.8822μs | 13.0069 KOps/s | 12.9811 KOps/s | |
test_distributed | 24.5000μs | 8.7714μs | 114.0063 KOps/s | 111.5959 KOps/s | |
test_tdmodule | 1.7473ms | 29.0542μs | 34.4184 KOps/s | 35.7583 KOps/s | |
test_tdmodule_dispatch | 0.2819ms | 52.8669μs | 18.9154 KOps/s | 18.5245 KOps/s | |
test_tdseq | 53.9000μs | 32.5450μs | 30.7267 KOps/s | 30.3281 KOps/s | |
test_tdseq_dispatch | 0.1878ms | 65.0498μs | 15.3728 KOps/s | 15.0229 KOps/s | |
test_instantiation_functorch | 2.2803ms | 1.6522ms | 605.2390 Ops/s | 603.6009 Ops/s | |
test_instantiation_td | 2.0625ms | 1.3624ms | 733.9987 Ops/s | 649.8956 Ops/s | |
test_exec_functorch | 0.2443ms | 0.1958ms | 5.1071 KOps/s | 5.0981 KOps/s | |
test_exec_td | 0.2893ms | 0.1861ms | 5.3748 KOps/s | 5.2983 KOps/s | |
test_vmap_mlp_speed[True-True] | 10.6313ms | 1.2330ms | 811.0055 Ops/s | 819.1887 Ops/s | |
test_vmap_mlp_speed[True-False] | 8.5331ms | 0.6557ms | 1.5252 KOps/s | 1.5429 KOps/s | |
test_vmap_mlp_speed[False-True] | 10.7645ms | 1.0803ms | 925.6845 Ops/s | 940.7586 Ops/s | |
test_vmap_mlp_speed[False-False] | 3.5379ms | 0.4837ms | 2.0674 KOps/s | 2.0105 KOps/s |
matteobettini
approved these changes
Oct 10, 2023
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
bug
Something isn't working
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Masks in
torch.where
for lazy stacks need to have as many dims as the tensordict to be unbound along the stack dim. This PR makes sure that it's the case.