
[Feature] Pointwise arithmetic operations using foreach #722

Merged (16 commits from `foreach` into `main`, Apr 9, 2024)

Conversation

vmoens (Contributor) commented Mar 27, 2024

Test the feature:

```python
from tensordict import TensorDict
from torch.nn import Transformer
import torch
from torch.utils.benchmark import Timer

t = Transformer(device="cuda" if torch.cuda.is_available() else "cpu")

params = TensorDict.from_module(t).apply(lambda t: t.data.clone().fill_(0.0))
other_params = TensorDict.from_module(t).apply(lambda t: t.data.clone().fill_(1.0))

assert (params.add(other_params) == 1).all()

print("\n\nadd")
params.add(other_params)  # warmup
print("foreach", Timer("params.add(other_params)", globals=globals()).adaptive_autorange())
params.lock_()
other_params.lock_()
params.add(other_params)  # warmup
print("foreach, locked", Timer("params.add(other_params)", globals=globals()).adaptive_autorange())
params.apply(lambda x, y: x + y, other_params)  # warmup
print("apply", Timer("params.apply(lambda x, y: x + y, other_params)", globals=globals()).adaptive_autorange())
listval0 = params._values_list(True, True)
listval1 = other_params._values_list(True, True)
torch._foreach_add(listval0, listval1)  # warmup
print("plain foreach", Timer("torch._foreach_add(listval0, listval1)", globals=globals()).adaptive_autorange())

params.unlock_()
other_params.unlock_()

print("\n\nadd_")
params.add_(other_params)  # warmup
print("foreach", Timer("params.add_(other_params)", globals=globals()).adaptive_autorange())
params.lock_()
other_params.lock_()
params.add_(other_params)  # warmup
print("foreach, locked", Timer("params.add_(other_params)", globals=globals()).adaptive_autorange())
params.apply(lambda x, y: x.add_(y), other_params)  # warmup
print("apply", Timer("params.apply(lambda x, y: x.add_(y), other_params)", globals=globals()).adaptive_autorange())
listval0 = params._values_list(True, True)
listval1 = other_params._values_list(True, True)
torch._foreach_add_(listval0, listval1)  # warmup
print("plain foreach", Timer("torch._foreach_add_(listval0, listval1)", globals=globals()).adaptive_autorange())
```

Gives these results:

```
add
foreach <torch.utils.benchmark.utils.common.Measurement object at 0x7f51c7f5ab60>
params.add(other_params)
  Median: 3.62 ms
  IQR:    0.07 ms (3.59 to 3.66)
  4 measurements, 100 runs per measurement, 1 thread
foreach, locked <torch.utils.benchmark.utils.common.Measurement object at 0x7f5169b13e20>
params.add(other_params)
  Median: 1.94 ms
  IQR:    0.12 ms (1.93 to 2.05)
  4 measurements, 100 runs per measurement, 1 thread
apply <torch.utils.benchmark.utils.common.Measurement object at 0x7f5169b11840>
params.apply(lambda x, y: x+y, other_params)
  Median: 4.18 ms
  IQR:    0.02 ms (4.17 to 4.19)
  4 measurements, 100 runs per measurement, 1 thread
plain foreach <torch.utils.benchmark.utils.common.Measurement object at 0x7f5169b12d40>
torch._foreach_add(listval0, listval1)
  Median: 567.85 us
  IQR:    8.00 us (563.76 to 571.76)
  4 measurements, 100 runs per measurement, 1 thread


add_
foreach <torch.utils.benchmark.utils.common.Measurement object at 0x7f51c7f59cc0>
params.add_(other_params)
  Median: 929.06 us
  IQR:    9.48 us (925.40 to 934.88)
  4 measurements, 100 runs per measurement, 1 thread
foreach, locked <torch.utils.benchmark.utils.common.Measurement object at 0x7f5169b11540>
params.add_(other_params)
  Median: 353.06 us
  IQR:    0.10 us (353.00 to 353.10)
  4 measurements, 1000 runs per measurement, 1 thread
apply <torch.utils.benchmark.utils.common.Measurement object at 0x7f5169b13f10>
params.apply(lambda x, y: x.add_(y), other_params)
  Median: 3.34 ms
  IQR:    0.02 ms (3.32 to 3.34)
  4 measurements, 100 runs per measurement, 1 thread
plain foreach <torch.utils.benchmark.utils.common.Measurement object at 0x7f5169b107f0>
torch._foreach_add_(listval0, listval1)
  Median: 353.14 us
  IQR:    0.07 us (353.10 to 353.17)
  4 measurements, 1000 runs per measurement, 1 thread
```

So there is some overhead for `add` because we need to rebuild the tensordict, but if we don't (`add_`) it's as fast as plain foreach, and you keep the nice structure of your params.
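The mechanics behind these numbers can be sketched without torch: the nested structure is flattened once into plain lists of leaves, and a single batched call then runs over those lists, instead of dispatching one op per leaf while traversing the tree (as `apply` does). A minimal pure-Python sketch, with hypothetical helper names standing in for tensordict's internal `_values_list` and for `torch._foreach_add_`:

```python
def values_list(nested: dict) -> list:
    """Depth-first flatten of a nested dict's leaf values into a flat list
    (a stand-in for tensordict's internal _values_list)."""
    out = []
    for v in nested.values():
        if isinstance(v, dict):
            out.extend(values_list(v))
        else:
            out.append(v)
    return out


def foreach_add_(xs: list, ys: list) -> None:
    """Stand-in for torch._foreach_add_: a single pass over pre-flattened
    lists (with tensors this would dispatch a handful of fused kernels)."""
    for i in range(len(xs)):
        xs[i] = xs[i] + ys[i]


params = {"encoder": {"weight": 0.0, "bias": 0.0}, "head": {"weight": 0.0}}
others = {"encoder": {"weight": 1.0, "bias": 1.0}, "head": {"weight": 1.0}}

flat_p = values_list(params)
flat_o = values_list(others)
foreach_add_(flat_p, flat_o)
print(flat_p)  # [1.0, 1.0, 1.0]
```

Locking the tensordicts lets the flattened leaf lists be cached across calls, which is consistent with the locked variants above being roughly twice as fast. Note that with real tensors the in-place foreach call also updates the nested structure, since the flat list holds the same tensor objects as the tree.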

cc @fmassa @ahmed-touati @teopir @MateuszGuzek @zou3519 @ezyang @albanD

Labels added Mar 27, 2024: CLA Signed, enhancement (New feature or request)

github-actions bot commented Mar 27, 2024

⚠️ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 127. Improved: 3. Worsened: 25.

Detailed results (changes above 5% in bold):

| Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
|---|---|---|---|---|---|
| test_plain_set_nested | 31.8990μs | 16.3345μs | 61.2201 KOps/s | 64.9443 KOps/s | **-5.73%** |
| test_plain_set_stack_nested | 37.6510μs | 16.7309μs | 59.7698 KOps/s | 63.8743 KOps/s | **-6.43%** |
| test_plain_set_nested_inplace | 95.0570μs | 18.6417μs | 53.6431 KOps/s | 56.2384 KOps/s | -4.61% |
| test_plain_set_stack_nested_inplace | 69.5780μs | 18.3914μs | 54.3732 KOps/s | 56.5358 KOps/s | -3.83% |
| test_items | 27.0410μs | 2.5584μs | 390.8669 KOps/s | 394.5466 KOps/s | -0.93% |
| test_items_nested | 0.4198ms | 0.2699ms | 3.7048 KOps/s | 3.6495 KOps/s | +1.51% |
| test_items_nested_locked | 1.1383ms | 0.2703ms | 3.6992 KOps/s | 3.6968 KOps/s | +0.07% |
| test_items_nested_leaf | 0.6037ms | 0.1652ms | 6.0525 KOps/s | 5.9241 KOps/s | +2.17% |
| test_items_stack_nested | 0.4865ms | 0.2709ms | 3.6921 KOps/s | 3.6581 KOps/s | +0.93% |
| test_items_stack_nested_leaf | 0.7317ms | 0.1662ms | 6.0158 KOps/s | 5.9779 KOps/s | +0.63% |
| test_items_stack_nested_locked | 0.6052ms | 0.2735ms | 3.6562 KOps/s | 3.6377 KOps/s | +0.51% |
| test_keys | 21.1090μs | 3.9081μs | 255.8790 KOps/s | 255.4344 KOps/s | +0.17% |
| test_keys_nested | 0.7543ms | 0.1451ms | 6.8900 KOps/s | 6.9657 KOps/s | -1.09% |
| test_keys_nested_locked | 0.2914ms | 0.1481ms | 6.7521 KOps/s | 6.6836 KOps/s | +1.03% |
| test_keys_nested_leaf | 39.7538ms | 0.1332ms | 7.5096 KOps/s | 8.0266 KOps/s | **-6.44%** |
| test_keys_stack_nested | 0.2722ms | 0.1452ms | 6.8890 KOps/s | 6.8145 KOps/s | +1.09% |
| test_keys_stack_nested_leaf | 0.2242ms | 0.1279ms | 7.8184 KOps/s | 7.7602 KOps/s | +0.75% |
| test_keys_stack_nested_locked | 0.2905ms | 0.1508ms | 6.6293 KOps/s | 6.6134 KOps/s | +0.24% |
| test_values | 9.3148μs | 1.1720μs | 853.2132 KOps/s | 869.8257 KOps/s | -1.91% |
| test_values_nested | 0.1098ms | 50.2774μs | 19.8897 KOps/s | 19.9195 KOps/s | -0.15% |
| test_values_nested_locked | 0.1120ms | 52.6641μs | 18.9883 KOps/s | 19.8498 KOps/s | -4.34% |
| test_values_nested_leaf | 91.2800μs | 45.1893μs | 22.1291 KOps/s | 21.3161 KOps/s | +3.81% |
| test_values_stack_nested | 98.0640μs | 50.7072μs | 19.7211 KOps/s | 19.2933 KOps/s | +2.22% |
| test_values_stack_nested_leaf | 94.0850μs | 45.2638μs | 22.0927 KOps/s | 21.8018 KOps/s | +1.33% |
| test_values_stack_nested_locked | 93.2440μs | 50.4211μs | 19.8330 KOps/s | 19.3380 KOps/s | +2.56% |
| test_membership | 33.4220μs | 1.3602μs | 735.2031 KOps/s | 740.2316 KOps/s | -0.68% |
| test_membership_nested | 36.1780μs | 3.5015μs | 285.5899 KOps/s | 289.9956 KOps/s | -1.52% |
| test_membership_nested_leaf | 29.3150μs | 3.5049μs | 285.3189 KOps/s | 293.1181 KOps/s | -2.66% |
| test_membership_stacked_nested | 29.6760μs | 3.4622μs | 288.8349 KOps/s | 297.8264 KOps/s | -3.02% |
| test_membership_stacked_nested_leaf | 25.2970μs | 3.4663μs | 288.4951 KOps/s | 295.3611 KOps/s | -2.32% |
| test_membership_nested_last | 19.6160μs | 4.2486μs | 235.3705 KOps/s | 228.4426 KOps/s | +3.03% |
| test_membership_nested_leaf_last | 44.3330μs | 4.2732μs | 234.0189 KOps/s | 220.6480 KOps/s | **+6.06%** |
| test_membership_stacked_nested_last | 42.7300μs | 4.2483μs | 235.3866 KOps/s | 236.7873 KOps/s | -0.59% |
| test_membership_stacked_nested_leaf_last | 37.8700μs | 4.3088μs | 232.0856 KOps/s | 235.9441 KOps/s | -1.64% |
| test_nested_getleaf | 48.6300μs | 10.7457μs | 93.0601 KOps/s | 94.6065 KOps/s | -1.63% |
| test_nested_get | 36.9890μs | 10.1714μs | 98.3147 KOps/s | 100.2075 KOps/s | -1.89% |
| test_stacked_getleaf | 41.0470μs | 10.6788μs | 93.6436 KOps/s | 94.7532 KOps/s | -1.17% |
| test_stacked_get | 54.0710μs | 10.1867μs | 98.1670 KOps/s | 101.9039 KOps/s | -3.67% |
| test_nested_getitemleaf | 65.9730μs | 11.3664μs | 87.9783 KOps/s | 91.3916 KOps/s | -3.73% |
| test_nested_getitem | 37.4800μs | 10.3153μs | 96.9433 KOps/s | 99.1700 KOps/s | -2.25% |
| test_stacked_getitemleaf | 45.6750μs | 11.2523μs | 88.8709 KOps/s | 91.0168 KOps/s | -2.36% |
| test_stacked_getitem | 39.3140μs | 10.3874μs | 96.2703 KOps/s | 99.3457 KOps/s | -3.10% |
| test_lock_nested | 0.7618ms | 0.3442ms | 2.9055 KOps/s | 2.8336 KOps/s | +2.54% |
| test_lock_stack_nested | 0.4531ms | 0.3076ms | 3.2510 KOps/s | 3.3083 KOps/s | -1.73% |
| test_unlock_nested | 95.0257ms | 0.4411ms | 2.2673 KOps/s | 2.3430 KOps/s | -3.23% |
| test_unlock_stack_nested | 0.4680ms | 0.3156ms | 3.1688 KOps/s | 3.1924 KOps/s | -0.74% |
| test_flatten_speed | 0.5691ms | 0.2661ms | 3.7574 KOps/s | 3.7950 KOps/s | -0.99% |
| test_unflatten_speed | 0.7190ms | 0.4113ms | 2.4313 KOps/s | 2.4792 KOps/s | -1.93% |
| test_common_ops | 4.7291ms | 0.6933ms | 1.4423 KOps/s | 1.5767 KOps/s | **-8.53%** |
| test_creation | 27.3010μs | 1.8468μs | 541.4705 KOps/s | 550.8993 KOps/s | -1.71% |
| test_creation_empty | 28.6730μs | 9.1592μs | 109.1796 KOps/s | 127.3858 KOps/s | **-14.29%** |
| test_creation_nested_1 | 33.9740μs | 11.6299μs | 85.9856 KOps/s | 94.7013 KOps/s | **-9.20%** |
| test_creation_nested_2 | 52.6280μs | 15.2045μs | 65.7699 KOps/s | 72.6513 KOps/s | **-9.47%** |
| test_clone | 47.1680μs | 13.3174μs | 75.0899 KOps/s | 75.3692 KOps/s | -0.37% |
| test_getitem[int] | 33.2820μs | 11.2423μs | 88.9502 KOps/s | 89.4844 KOps/s | -0.60% |
| test_getitem[slice_int] | 59.5410μs | 22.9573μs | 43.5591 KOps/s | 43.8278 KOps/s | -0.61% |
| test_getitem[range] | 0.1116ms | 42.4776μs | 23.5418 KOps/s | 24.4939 KOps/s | -3.89% |
| test_getitem[tuple] | 48.8310μs | 18.7556μs | 53.3174 KOps/s | 54.0223 KOps/s | -1.30% |
| test_getitem[list] | 0.1929ms | 37.4572μs | 26.6971 KOps/s | 26.1327 KOps/s | +2.16% |
| test_setitem_dim[int] | 55.4730μs | 33.2861μs | 30.0426 KOps/s | 32.3086 KOps/s | **-7.01%** |
| test_setitem_dim[slice_int] | 0.1100ms | 60.0228μs | 16.6603 KOps/s | 17.7983 KOps/s | **-6.39%** |
| test_setitem_dim[range] | 0.1632ms | 80.9688μs | 12.3504 KOps/s | 13.4377 KOps/s | **-8.09%** |
| test_setitem_dim[tuple] | 92.5730μs | 48.8170μs | 20.4847 KOps/s | 21.1081 KOps/s | -2.95% |
| test_setitem | 67.1560μs | 19.7928μs | 50.5235 KOps/s | 53.9054 KOps/s | **-6.27%** |
| test_set | 70.2920μs | 19.2486μs | 51.9519 KOps/s | 56.4955 KOps/s | **-8.04%** |
| test_set_shared | 1.5018ms | 0.1382ms | 7.2368 KOps/s | 7.3115 KOps/s | -1.02% |
| test_update | 0.1495ms | 20.3964μs | 49.0284 KOps/s | 55.3096 KOps/s | **-11.36%** |
| test_update_nested | 80.8100μs | 28.6243μs | 34.9353 KOps/s | 38.7350 KOps/s | **-9.81%** |
| test_update__nested | 70.8220μs | 24.6559μs | 40.5582 KOps/s | 39.2831 KOps/s | +3.25% |
| test_set_nested | 0.1805ms | 21.3932μs | 46.7439 KOps/s | 50.7811 KOps/s | **-7.95%** |
| test_set_nested_new | 73.1460μs | 25.6610μs | 38.9697 KOps/s | 42.4103 KOps/s | **-8.11%** |
| test_select | 0.1024ms | 39.3516μs | 25.4119 KOps/s | 26.2009 KOps/s | -3.01% |
| test_select_nested | 0.2084ms | 59.8214μs | 16.7164 KOps/s | 16.8501 KOps/s | -0.79% |
| test_exclude_nested | 0.1954ms | 0.1184ms | 8.4443 KOps/s | 8.3317 KOps/s | +1.35% |
| test_empty[True] | 0.8521ms | 0.4233ms | 2.3625 KOps/s | 2.4223 KOps/s | -2.47% |
| test_empty[False] | 14.9538μs | 1.0478μs | 954.3879 KOps/s | 920.8033 KOps/s | +3.65% |
| test_unbind_speed | 0.4873ms | 0.2697ms | 3.7072 KOps/s | 4.0064 KOps/s | **-7.47%** |
| test_unbind_speed_stack0 | 0.5137ms | 0.2487ms | 4.0203 KOps/s | 4.0306 KOps/s | -0.25% |
| test_unbind_speed_stack1 | 0.1321s | 0.6946ms | 1.4398 KOps/s | 1.4353 KOps/s | +0.31% |
| test_split | 0.1277s | 1.6986ms | 588.7108 Ops/s | 612.6801 Ops/s | -3.91% |
| test_chunk | 2.9962ms | 1.5194ms | 658.1549 Ops/s | 685.7237 Ops/s | -4.02% |
| test_creation[device0] | 0.3358ms | 0.1025ms | 9.7559 KOps/s | 9.6927 KOps/s | +0.65% |
| test_creation_from_tensor | 4.6853ms | 81.8297μs | 12.2205 KOps/s | 12.0638 KOps/s | +1.30% |
| test_add_one[memmap_tensor0] | 79.1070μs | 5.8001μs | 172.4116 KOps/s | 193.1978 KOps/s | **-10.76%** |
| test_contiguous[memmap_tensor0] | 16.8310μs | 0.6426μs | 1.5562 MOps/s | 1.5102 MOps/s | +3.05% |
| test_stack[memmap_tensor0] | 40.3660μs | 3.7527μs | 266.4727 KOps/s | 287.7724 KOps/s | **-7.40%** |
| test_memmaptd_index | 1.1268ms | 0.2425ms | 4.1236 KOps/s | 4.2333 KOps/s | -2.59% |
| test_memmaptd_index_astensor | 0.5358ms | 0.3062ms | 3.2654 KOps/s | 3.3177 KOps/s | -1.57% |
| test_memmaptd_index_op | 1.1209ms | 0.6014ms | 1.6629 KOps/s | 1.8152 KOps/s | **-8.39%** |
| test_serialize_model | 0.2478s | 0.1179s | 8.4845 Ops/s | 8.3833 Ops/s | +1.21% |
| test_serialize_model_pickle | 0.4544s | 0.3797s | 2.6336 Ops/s | 2.6032 Ops/s | +1.17% |
| test_serialize_weights | 0.1047s | 98.9182ms | 10.1094 Ops/s | 9.8665 Ops/s | +2.46% |
| test_serialize_weights_returnearly | 0.1314s | 0.1251s | 7.9937 Ops/s | 8.1099 Ops/s | -1.43% |
| test_serialize_weights_pickle | 0.7542s | 0.5127s | 1.9504 Ops/s | 1.4419 Ops/s | **+35.27%** |
| test_serialize_weights_filesystem | 0.1084s | 95.9016ms | 10.4274 Ops/s | 9.8016 Ops/s | **+6.38%** |
| test_serialize_model_filesystem | 0.1029s | 94.4276ms | 10.5901 Ops/s | 10.6697 Ops/s | -0.75% |
| test_reshape_pytree | 70.0910μs | 20.8581μs | 47.9430 KOps/s | 48.7885 KOps/s | -1.73% |
| test_reshape_td | 73.4270μs | 32.8844μs | 30.4096 KOps/s | 30.5577 KOps/s | -0.48% |
| test_view_pytree | 65.4720μs | 21.4097μs | 46.7079 KOps/s | 48.2124 KOps/s | -3.12% |
| test_view_td | 0.1273s | 62.9896μs | 15.8756 KOps/s | 15.6140 KOps/s | +1.68% |
| test_unbind_pytree | 62.0260μs | 24.7463μs | 40.4100 KOps/s | 40.3206 KOps/s | +0.22% |
| test_unbind_td | 0.1048ms | 36.6855μs | 27.2587 KOps/s | 26.8189 KOps/s | +1.64% |
| test_split_pytree | 49.6330μs | 24.1115μs | 41.4741 KOps/s | 41.3340 KOps/s | +0.34% |
| test_split_td | 0.1124ms | 40.9350μs | 24.4290 KOps/s | 25.0965 KOps/s | -2.66% |
| test_add_pytree | 71.2130μs | 30.5968μs | 32.6831 KOps/s | 33.6483 KOps/s | -2.87% |
| test_add_td | 0.1754ms | 57.2260μs | 17.4746 KOps/s | 19.8377 KOps/s | **-11.91%** |
| test_distributed | 0.1873ms | 99.0378μs | 10.0972 KOps/s | 10.0026 KOps/s | +0.95% |
| test_tdmodule | 44.5430μs | 16.8023μs | 59.5156 KOps/s | 62.8893 KOps/s | **-5.36%** |
| test_tdmodule_dispatch | 63.7390μs | 32.9821μs | 30.3195 KOps/s | 31.7710 KOps/s | -4.57% |
| test_tdseq | 34.2040μs | 19.5049μs | 51.2691 KOps/s | 54.1408 KOps/s | **-5.30%** |
| test_tdseq_dispatch | 77.6330μs | 38.4812μs | 25.9867 KOps/s | 27.9449 KOps/s | **-7.01%** |
| test_instantiation_functorch | 2.1136ms | 1.3255ms | 754.4204 Ops/s | 770.4929 Ops/s | -2.09% |
| test_instantiation_td | 1.4341ms | 1.0132ms | 986.9964 Ops/s | 988.7741 Ops/s | -0.18% |
| test_exec_functorch | 0.2227ms | 0.1610ms | 6.2121 KOps/s | 6.2793 KOps/s | -1.07% |
| test_exec_functional_call | 0.3536ms | 0.1528ms | 6.5431 KOps/s | 6.9109 KOps/s | **-5.32%** |
| test_exec_td | 0.2726ms | 0.1449ms | 6.9007 KOps/s | 7.1599 KOps/s | -3.62% |
| test_exec_td_decorator | 0.7548ms | 0.1988ms | 5.0313 KOps/s | 5.1040 KOps/s | -1.42% |
| test_vmap_mlp_speed[True-True] | 0.7163ms | 0.4708ms | 2.1240 KOps/s | 2.2129 KOps/s | -4.02% |
| test_vmap_mlp_speed[True-False] | 0.7227ms | 0.4663ms | 2.1443 KOps/s | 2.2067 KOps/s | -2.83% |
| test_vmap_mlp_speed[False-True] | 0.5380ms | 0.3830ms | 2.6107 KOps/s | 2.6909 KOps/s | -2.98% |
| test_vmap_mlp_speed[False-False] | 0.6767ms | 0.3840ms | 2.6042 KOps/s | 2.6839 KOps/s | -2.97% |
| test_vmap_mlp_speed_decorator[True-True] | 1.0100ms | 0.4995ms | 2.0021 KOps/s | 2.1021 KOps/s | -4.76% |
| test_vmap_mlp_speed_decorator[True-False] | 0.8497ms | 0.4952ms | 2.0194 KOps/s | 2.1057 KOps/s | -4.10% |
| test_vmap_mlp_speed_decorator[False-True] | 0.6117ms | 0.3991ms | 2.5054 KOps/s | 2.5435 KOps/s | -1.50% |
| test_vmap_mlp_speed_decorator[False-False] | 0.7236ms | 0.4003ms | 2.4984 KOps/s | 2.5641 KOps/s | -2.56% |
| test_to_module_speed[True] | 2.0987ms | 1.4177ms | 705.3619 Ops/s | 709.6159 Ops/s | -0.60% |
| test_to_module_speed[False] | 1.6380ms | 1.4177ms | 705.3828 Ops/s | 711.1263 Ops/s | -0.81% |

vmoens (Contributor, Author) commented Mar 27, 2024

@janeyx99 @fmassa

Here is a version of Adam in 23 lines of code (no weight decay and no edge-case handling whatsoever) with tensordict.
With this you can easily recycle your Adam for a sub-sub-module: just index the weights accordingly, e.g. `self.weights["decoder", "0"]`.

```python
import torch

from tensordict import TensorDict
from tensordict.nn import TensorDictParams
from torch import nn

class Adam:
    def __init__(self, weights: TensorDictParams, alpha: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-6, weight_decay: float = 0.0):
        # Lock for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        self._mu = weights.data.clone()
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # bias correction must be out-of-place so the running moments are preserved
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

    def zero_grad(self):
        self.weights.grad.zero_()

device = "cpu" if not torch.cuda.is_available() else "cuda"
with torch.device(device):
    net = nn.Transformer()
    weights = TensorDict.from_module(net, as_module=True)
    adam = Adam(weights=weights)
    net(torch.randn(5, 5, 512), torch.randn(5, 5, 512)).mean().backward()
    adam.step()
    adam.zero_grad()

    from torch.utils.benchmark import Timer
    net(torch.randn(5, 5, 512), torch.randn(5, 5, 512)).mean().backward()
    print("td", Timer("adam.step()", globals=globals()).adaptive_autorange())
    adam.zero_grad()

    adam_orig = torch.optim.Adam(params=net.parameters())

    net(torch.randn(5, 5, 512), torch.randn(5, 5, 512)).mean().backward()
    adam_orig.step()
    adam_orig.zero_grad()

    net(torch.randn(5, 5, 512), torch.randn(5, 5, 512)).mean().backward()
    print("native", Timer("adam_orig.step()", globals=globals()).adaptive_autorange())
    adam_orig.zero_grad()
```

On my machine and on an A100 I get a similar runtime (again, this is much simpler than torch.optim; my point is that at least it isn't slower!).
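For clarity, the update that `step()` performs on every leaf is the textbook Adam rule. A scalar, dependency-free sketch (illustrative only; `adam_step` is a hypothetical helper, and as in the tensordict code above the bias correction is done out of place so the raw running moments are preserved):

```python
import math

def adam_step(w, g, mu, sigma, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-6):
    """One Adam update for a single scalar weight w with gradient g.

    mu/sigma are the running first and second moments; t is the 1-based
    step count used for bias correction. Returns the new (w, mu, sigma).
    """
    mu = beta1 * mu + (1 - beta1) * g
    sigma = beta2 * sigma + (1 - beta2) * g * g
    mu_hat = mu / (1 - beta1 ** t)        # bias-corrected first moment
    sigma_hat = sigma / (1 - beta2 ** t)  # bias-corrected second moment
    w = w - alpha * mu_hat / (math.sqrt(sigma_hat) + eps)
    return w, mu, sigma

w, mu, sigma = 1.0, 0.0, 0.0
w, mu, sigma = adam_step(w, g=0.5, mu=mu, sigma=sigma, t=1)
```

The tensordict version applies exactly this arithmetic, but on whole tensordicts at once, which is where the foreach-backed pointwise ops from this PR come in.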

vmoens (Contributor, Author) commented Mar 28, 2024

@matteobettini

With this you can do

td += 1

on a lazy stack and it's efficient (much more efficient than the previous default behaviour: stack, apply the op, then call `setitem`). That will allow us to do ops on lazy stacks as if they were regular contiguous stacks of tensors.
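To illustrate what forwarding an op to a lazy stack means, without pulling in torch: a toy pure-Python sketch (the `LazyStack` class here is hypothetical, not the tensordict API) where `+=` is applied leaf-wise to each stacked element in place, so no contiguous stacked copy is ever materialized.

```python
class LazyStack:
    """A 'stack' kept as a list of dicts; pointwise ops are forwarded
    leaf-wise to each element instead of materializing a real stack."""

    def __init__(self, dicts):
        self.dicts = dicts

    def __iadd__(self, value):
        for d in self.dicts:
            for k in d:
                d[k] = d[k] + value  # applied to each element's leaves
        return self

td = LazyStack([{"a": 0.0}, {"a": 1.0}])
td += 1
print([d["a"] for d in td.dicts])  # [1.0, 2.0]
```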

vmoens merged commit eabefcc into main on Apr 9, 2024 (44 of 48 checks passed) and deleted the foreach branch.
janeyx99 commented:

Just to make sure I understand your point: this is showing that foreach is built into how TensorDicts can get one operation applied to multiple params at once?
