[BugFix] torch.where: Expand mask on the right in lazy stack tds #542

Merged: 10 commits merged on Oct 10, 2023

vmoens (Contributor) commented Oct 10, 2023

When torch.where is called on a lazy stack, the mask is unbound along the stack dim, so it needs to have as many dims as the tensordict. This PR makes sure that is the case by expanding the mask on the right.
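
To illustrate the shape requirement described above, here is a minimal, torch-free sketch of the shape arithmetic involved. It is not the code from this PR: `expand_right_shapes` and `shape_after_unbind` are hypothetical helpers introduced only for this example, and the batch sizes are made up.

```python
from typing import Sequence, Tuple


def expand_right_shapes(
    mask_shape: Sequence[int], batch_size: Sequence[int]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Return (view_shape, expanded_shape) for right-expanding a mask.

    Hypothetical helper, not part of tensordict: the mask must match the
    leading dims of ``batch_size`` (or be 1 there). Trailing singleton dims
    are appended so the mask can be broadcast to the full batch size.
    """
    mask_shape, batch_size = tuple(mask_shape), tuple(batch_size)
    if len(mask_shape) > len(batch_size):
        raise ValueError("mask has more dims than the tensordict")
    for m, b in zip(mask_shape, batch_size):
        if m not in (1, b):
            raise ValueError(f"mask dim {m} incompatible with batch dim {b}")
    missing = len(batch_size) - len(mask_shape)
    view_shape = mask_shape + (1,) * missing
    return view_shape, batch_size


def shape_after_unbind(shape: Sequence[int], stack_dim: int) -> Tuple[int, ...]:
    """Shape of each piece once ``shape`` is unbound along ``stack_dim``."""
    shape = tuple(shape)
    stack_dim = stack_dim % len(shape)  # support negative dims
    return shape[:stack_dim] + shape[stack_dim + 1:]


# Assumed scenario: 4 tensordicts with batch size (3,), lazily stacked along
# dim 1, give a batch size of (3, 4). A mask of shape (3,) has no dim 1 to
# unbind, so it is first expanded on the right.
view_shape, expanded_shape = expand_right_shapes((3,), (3, 4))
per_item_shape = shape_after_unbind(expanded_shape, stack_dim=1)
print(view_shape, expanded_shape, per_item_shape)
# With torch tensors this would correspond to something like:
#   mask.reshape(view_shape).expand(expanded_shape).unbind(1)
```

After the expansion, each unbound piece of the mask has the same shape as the batch size of one stacked tensordict, which is what a per-item `torch.where` call needs.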

facebook-github-bot added the "CLA Signed" label Oct 10, 2023
vmoens changed the title from "[BugFix] Expand mask on the right in lazy stack tds" to "[BugFix] torch.where: Expand mask on the right in lazy stack tds" Oct 10, 2023
vmoens added the "bug" label Oct 10, 2023
github-actions bot commented Oct 10, 2023

Warning: Result of CPU Benchmark Tests

Total Benchmarks: 105. Improved: 4. Worsened: 1.

| Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
| --- | --- | --- | --- | --- | --- |
| test_plain_set_nested | 39.3010μs | 20.0317μs | 49.9208 KOps/s | 49.7603 KOps/s | +0.32% |
| test_plain_set_stack_nested | 0.2249ms | 0.1847ms | 5.4134 KOps/s | 5.3740 KOps/s | +0.73% |
| test_plain_set_nested_inplace | 41.3000μs | 23.6374μs | 42.3059 KOps/s | 42.3696 KOps/s | -0.15% |
| test_plain_set_stack_nested_inplace | 0.3069ms | 0.2189ms | 4.5680 KOps/s | 4.5492 KOps/s | +0.41% |
| test_items | 20.8010μs | 3.5134μs | 284.6217 KOps/s | 283.0765 KOps/s | +0.55% |
| test_items_nested | 0.4697ms | 0.3701ms | 2.7018 KOps/s | 2.7397 KOps/s | -1.38% |
| test_items_nested_locked | 0.4584ms | 0.3703ms | 2.7006 KOps/s | 2.7456 KOps/s | -1.64% |
| test_items_nested_leaf | 0.3074ms | 0.2233ms | 4.4785 KOps/s | 4.4966 KOps/s | -0.40% |
| test_items_stack_nested | 2.1321ms | 1.9754ms | 506.2388 Ops/s | 505.3010 Ops/s | +0.19% |
| test_items_stack_nested_leaf | 1.9100ms | 1.7965ms | 556.6445 Ops/s | 548.9681 Ops/s | +1.40% |
| test_items_stack_nested_locked | 1.0932ms | 0.9712ms | 1.0297 KOps/s | 1.0137 KOps/s | +1.58% |
| test_keys | 27.0010μs | 5.0654μs | 197.4175 KOps/s | 196.1921 KOps/s | +0.62% |
| test_keys_nested | 0.8365ms | 0.1817ms | 5.5042 KOps/s | 4.9852 KOps/s | **+10.41%** |
| test_keys_nested_locked | 0.2638ms | 0.1800ms | 5.5569 KOps/s | 5.5000 KOps/s | +1.03% |
| test_keys_nested_leaf | 0.9100ms | 0.1766ms | 5.6616 KOps/s | 5.7291 KOps/s | -1.18% |
| test_keys_stack_nested | 1.9717ms | 1.8201ms | 549.4156 Ops/s | 541.6559 Ops/s | +1.43% |
| test_keys_stack_nested_leaf | 1.9799ms | 1.8221ms | 548.8197 Ops/s | 545.0268 Ops/s | +0.70% |
| test_keys_stack_nested_locked | 1.0184ms | 0.8178ms | 1.2228 KOps/s | 1.2032 KOps/s | +1.63% |
| test_values | 17.5003μs | 1.3895μs | 719.6575 KOps/s | 621.6983 KOps/s | **+15.76%** |
| test_values_nested | 0.1501ms | 66.9743μs | 14.9311 KOps/s | 14.7533 KOps/s | +1.21% |
| test_values_nested_locked | 0.1585ms | 67.1261μs | 14.8973 KOps/s | 14.8513 KOps/s | +0.31% |
| test_values_nested_leaf | 0.1550ms | 58.3605μs | 17.1349 KOps/s | 16.9566 KOps/s | +1.05% |
| test_values_stack_nested | 1.6642ms | 1.5867ms | 630.2372 Ops/s | 628.8637 Ops/s | +0.22% |
| test_values_stack_nested_leaf | 1.7215ms | 1.5801ms | 632.8844 Ops/s | 633.3336 Ops/s | -0.07% |
| test_values_stack_nested_locked | 0.8040ms | 0.6396ms | 1.5635 KOps/s | 1.5319 KOps/s | +2.06% |
| test_membership | 67.0010μs | 1.8863μs | 530.1284 KOps/s | 525.6228 KOps/s | +0.86% |
| test_membership_nested | 40.4010μs | 3.6515μs | 273.8566 KOps/s | 272.7630 KOps/s | +0.40% |
| test_membership_nested_leaf | 25.6010μs | 3.7439μs | 267.1037 KOps/s | 271.3841 KOps/s | -1.58% |
| test_membership_stacked_nested | 0.1007ms | 14.4092μs | 69.4000 KOps/s | 69.3092 KOps/s | +0.13% |
| test_membership_stacked_nested_leaf | 70.8010μs | 14.3789μs | 69.5462 KOps/s | 69.4387 KOps/s | +0.15% |
| test_membership_nested_last | 27.1000μs | 7.5400μs | 132.6254 KOps/s | 132.8320 KOps/s | -0.16% |
| test_membership_nested_leaf_last | 93.2020μs | 7.6358μs | 130.9628 KOps/s | 133.4512 KOps/s | -1.86% |
| test_membership_stacked_nested_last | 0.3581ms | 0.2280ms | 4.3853 KOps/s | 4.4272 KOps/s | -0.94% |
| test_membership_stacked_nested_leaf_last | 54.2010μs | 16.8313μs | 59.4131 KOps/s | 59.1548 KOps/s | +0.44% |
| test_nested_getleaf | 0.1027ms | 15.6424μs | 63.9290 KOps/s | 64.1580 KOps/s | -0.36% |
| test_nested_get | 0.1067ms | 14.8766μs | 67.2195 KOps/s | 67.7628 KOps/s | -0.80% |
| test_stacked_getleaf | 1.0582ms | 0.8818ms | 1.1340 KOps/s | 1.1436 KOps/s | -0.84% |
| test_stacked_get | 0.9685ms | 0.8388ms | 1.1922 KOps/s | 1.1879 KOps/s | +0.36% |
| test_nested_getitemleaf | 96.5010μs | 15.6292μs | 63.9828 KOps/s | 64.4908 KOps/s | -0.79% |
| test_nested_getitem | 0.1062ms | 14.8617μs | 67.2872 KOps/s | 67.7218 KOps/s | -0.64% |
| test_stacked_getitemleaf | 1.0689ms | 0.8843ms | 1.1308 KOps/s | 1.1444 KOps/s | -1.19% |
| test_stacked_getitem | 0.9908ms | 0.8388ms | 1.1921 KOps/s | 1.1932 KOps/s | -0.09% |
| test_lock_nested | 96.0545ms | 1.5995ms | 625.2008 Ops/s | 678.5814 Ops/s | **-7.87%** |
| test_lock_stack_nested | 0.1159s | 21.8540ms | 45.7582 Ops/s | 44.5582 Ops/s | +2.69% |
| test_unlock_nested | 75.3770ms | 1.5782ms | 633.6456 Ops/s | 634.6044 Ops/s | -0.15% |
| test_unlock_stack_nested | 0.1159s | 22.4832ms | 44.4776 Ops/s | 46.7330 Ops/s | -4.83% |
| test_flatten_speed | 1.0456ms | 1.0003ms | 999.6978 Ops/s | 967.6769 Ops/s | +3.31% |
| test_unflatten_speed | 1.8772ms | 1.7925ms | 557.8955 Ops/s | 547.2071 Ops/s | +1.95% |
| test_common_ops | 7.0275ms | 1.1165ms | 895.6822 Ops/s | 888.0178 Ops/s | +0.86% |
| test_creation | 0.2873ms | 6.1170μs | 163.4780 KOps/s | 163.2646 KOps/s | +0.13% |
| test_creation_empty | 84.9020μs | 13.4551μs | 74.3214 KOps/s | 73.2266 KOps/s | +1.50% |
| test_creation_nested_1 | 56.1010μs | 24.4363μs | 40.9227 KOps/s | 40.1688 KOps/s | +1.88% |
| test_creation_nested_2 | 74.1010μs | 26.6349μs | 37.5447 KOps/s | 37.0601 KOps/s | +1.31% |
| test_clone | 0.2100ms | 24.8616μs | 40.2227 KOps/s | 41.1436 KOps/s | -2.24% |
| test_getitem[int] | 57.3010μs | 28.2781μs | 35.3631 KOps/s | 35.2494 KOps/s | +0.32% |
| test_getitem[slice_int] | 98.3020μs | 55.2106μs | 18.1125 KOps/s | 18.1014 KOps/s | +0.06% |
| test_getitem[range] | 0.1770ms | 81.0748μs | 12.3343 KOps/s | 12.3035 KOps/s | +0.25% |
| test_getitem[tuple] | 70.1010μs | 45.7862μs | 21.8406 KOps/s | 21.4105 KOps/s | +2.01% |
| test_getitem[list] | 0.4040ms | 76.5691μs | 13.0601 KOps/s | 12.9804 KOps/s | +0.61% |
| test_setitem_dim[int] | 69.9010μs | 34.5950μs | 28.9059 KOps/s | 29.1416 KOps/s | -0.81% |
| test_setitem_dim[slice_int] | 92.7020μs | 61.3029μs | 16.3125 KOps/s | 16.5435 KOps/s | -1.40% |
| test_setitem_dim[range] | 0.1189ms | 80.3825μs | 12.4405 KOps/s | 12.2250 KOps/s | +1.76% |
| test_setitem_dim[tuple] | 76.9010μs | 50.2961μs | 19.8823 KOps/s | 19.8762 KOps/s | +0.03% |
| test_setitem | 0.2335ms | 32.0979μs | 31.1547 KOps/s | 31.1057 KOps/s | +0.16% |
| test_set | 0.2394ms | 30.6950μs | 32.5786 KOps/s | 32.1935 KOps/s | +1.20% |
| test_set_shared | 6.4905ms | 0.2081ms | 4.8046 KOps/s | 4.9344 KOps/s | -2.63% |
| test_update | 0.2436ms | 35.0241μs | 28.5518 KOps/s | 28.1346 KOps/s | +1.48% |
| test_update_nested | 0.3040ms | 51.7060μs | 19.3401 KOps/s | 19.1310 KOps/s | +1.09% |
| test_set_nested | 0.2045ms | 34.0193μs | 29.3950 KOps/s | 28.9959 KOps/s | +1.38% |
| test_set_nested_new | 0.2724ms | 53.5183μs | 18.6852 KOps/s | 18.8586 KOps/s | -0.92% |
| test_select | 0.2803ms | 99.0671μs | 10.0942 KOps/s | 10.3495 KOps/s | -2.47% |
| test_unbind_speed | 0.6852ms | 0.6453ms | 1.5496 KOps/s | 1.5544 KOps/s | -0.31% |
| test_unbind_speed_stack0 | 93.6423ms | 8.8237ms | 113.3308 Ops/s | 108.8995 Ops/s | +4.07% |
| test_unbind_speed_stack1 | 11.4002μs | 0.9463μs | 1.0568 MOps/s | 863.9092 KOps/s | **+22.32%** |
| test_creation[device0] | 0.5963ms | 0.4516ms | 2.2146 KOps/s | 2.2132 KOps/s | +0.06% |
| test_creation_from_tensor | 4.5838ms | 0.5104ms | 1.9593 KOps/s | 1.9502 KOps/s | +0.47% |
| test_add_one[memmap_tensor0] | 2.2176ms | 32.8085μs | 30.4799 KOps/s | 30.0315 KOps/s | +1.49% |
| test_contiguous[memmap_tensor0] | 37.0000μs | 8.7577μs | 114.1855 KOps/s | 109.3695 KOps/s | +4.40% |
| test_stack[memmap_tensor0] | 89.5020μs | 26.8878μs | 37.1916 KOps/s | 37.3791 KOps/s | -0.50% |
| test_memmaptd_index | 0.3847ms | 0.3154ms | 3.1702 KOps/s | 3.1236 KOps/s | +1.49% |
| test_memmaptd_index_astensor | 1.2771ms | 1.2207ms | 819.2231 Ops/s | 810.2075 Ops/s | +1.11% |
| test_memmaptd_index_op | 2.7302ms | 2.6700ms | 374.5280 Ops/s | 377.4644 Ops/s | -0.78% |
| test_reshape_pytree | 93.2020μs | 32.4585μs | 30.8086 KOps/s | 30.2262 KOps/s | +1.93% |
| test_reshape_td | 0.3555ms | 40.0688μs | 24.9570 KOps/s | 24.5804 KOps/s | +1.53% |
| test_view_pytree | 0.1399ms | 32.1996μs | 31.0563 KOps/s | 30.6202 KOps/s | +1.42% |
| test_view_td | 70.0020μs | 8.8803μs | 112.6091 KOps/s | 112.4053 KOps/s | +0.18% |
| test_unbind_pytree | 79.7010μs | 37.5351μs | 26.6417 KOps/s | 26.2873 KOps/s | +1.35% |
| test_unbind_td | 0.2090ms | 95.4766μs | 10.4738 KOps/s | 10.3124 KOps/s | +1.57% |
| test_split_pytree | 81.8010μs | 37.0200μs | 27.0124 KOps/s | 26.6062 KOps/s | +1.53% |
| test_split_td | 1.0116ms | 0.1069ms | 9.3533 KOps/s | 9.1746 KOps/s | +1.95% |
| test_add_pytree | 97.7020μs | 46.7915μs | 21.3714 KOps/s | 21.2650 KOps/s | +0.50% |
| test_add_td | 0.1559ms | 76.8822μs | 13.0069 KOps/s | 12.9811 KOps/s | +0.20% |
| test_distributed | 24.5000μs | 8.7714μs | 114.0063 KOps/s | 111.5959 KOps/s | +2.16% |
| test_tdmodule | 1.7473ms | 29.0542μs | 34.4184 KOps/s | 35.7583 KOps/s | -3.75% |
| test_tdmodule_dispatch | 0.2819ms | 52.8669μs | 18.9154 KOps/s | 18.5245 KOps/s | +2.11% |
| test_tdseq | 53.9000μs | 32.5450μs | 30.7267 KOps/s | 30.3281 KOps/s | +1.31% |
| test_tdseq_dispatch | 0.1878ms | 65.0498μs | 15.3728 KOps/s | 15.0229 KOps/s | +2.33% |
| test_instantiation_functorch | 2.2803ms | 1.6522ms | 605.2390 Ops/s | 603.6009 Ops/s | +0.27% |
| test_instantiation_td | 2.0625ms | 1.3624ms | 733.9987 Ops/s | 649.8956 Ops/s | **+12.94%** |
| test_exec_functorch | 0.2443ms | 0.1958ms | 5.1071 KOps/s | 5.0981 KOps/s | +0.18% |
| test_exec_td | 0.2893ms | 0.1861ms | 5.3748 KOps/s | 5.2983 KOps/s | +1.44% |
| test_vmap_mlp_speed[True-True] | 10.6313ms | 1.2330ms | 811.0055 Ops/s | 819.1887 Ops/s | -1.00% |
| test_vmap_mlp_speed[True-False] | 8.5331ms | 0.6557ms | 1.5252 KOps/s | 1.5429 KOps/s | -1.15% |
| test_vmap_mlp_speed[False-True] | 10.7645ms | 1.0803ms | 925.6845 Ops/s | 940.7586 Ops/s | -1.60% |
| test_vmap_mlp_speed[False-False] | 3.5379ms | 0.4837ms | 2.0674 KOps/s | 2.0105 KOps/s | +2.83% |

matteobettini (Contributor) left a comment

LGTM

vmoens merged commit d8daea7 into main on Oct 10, 2023
38 of 41 checks passed
vmoens deleted the bug_fix_nested_where branch on Oct 10, 2023 at 16:11
Labels: bug, CLA Signed
3 participants