[Feature] Pointwise arithmetic operations using foreach #722
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
---|---|---|---|---|---
test_plain_set_nested | 31.8990μs | 16.3345μs | 61.2201 KOps/s | 64.9443 KOps/s | |
test_plain_set_stack_nested | 37.6510μs | 16.7309μs | 59.7698 KOps/s | 63.8743 KOps/s | |
test_plain_set_nested_inplace | 95.0570μs | 18.6417μs | 53.6431 KOps/s | 56.2384 KOps/s | |
test_plain_set_stack_nested_inplace | 69.5780μs | 18.3914μs | 54.3732 KOps/s | 56.5358 KOps/s | |
test_items | 27.0410μs | 2.5584μs | 390.8669 KOps/s | 394.5466 KOps/s | |
test_items_nested | 0.4198ms | 0.2699ms | 3.7048 KOps/s | 3.6495 KOps/s | |
test_items_nested_locked | 1.1383ms | 0.2703ms | 3.6992 KOps/s | 3.6968 KOps/s | |
test_items_nested_leaf | 0.6037ms | 0.1652ms | 6.0525 KOps/s | 5.9241 KOps/s | |
test_items_stack_nested | 0.4865ms | 0.2709ms | 3.6921 KOps/s | 3.6581 KOps/s | |
test_items_stack_nested_leaf | 0.7317ms | 0.1662ms | 6.0158 KOps/s | 5.9779 KOps/s | |
test_items_stack_nested_locked | 0.6052ms | 0.2735ms | 3.6562 KOps/s | 3.6377 KOps/s | |
test_keys | 21.1090μs | 3.9081μs | 255.8790 KOps/s | 255.4344 KOps/s | |
test_keys_nested | 0.7543ms | 0.1451ms | 6.8900 KOps/s | 6.9657 KOps/s | |
test_keys_nested_locked | 0.2914ms | 0.1481ms | 6.7521 KOps/s | 6.6836 KOps/s | |
test_keys_nested_leaf | 39.7538ms | 0.1332ms | 7.5096 KOps/s | 8.0266 KOps/s | |
test_keys_stack_nested | 0.2722ms | 0.1452ms | 6.8890 KOps/s | 6.8145 KOps/s | |
test_keys_stack_nested_leaf | 0.2242ms | 0.1279ms | 7.8184 KOps/s | 7.7602 KOps/s | |
test_keys_stack_nested_locked | 0.2905ms | 0.1508ms | 6.6293 KOps/s | 6.6134 KOps/s | |
test_values | 9.3148μs | 1.1720μs | 853.2132 KOps/s | 869.8257 KOps/s | |
test_values_nested | 0.1098ms | 50.2774μs | 19.8897 KOps/s | 19.9195 KOps/s | |
test_values_nested_locked | 0.1120ms | 52.6641μs | 18.9883 KOps/s | 19.8498 KOps/s | |
test_values_nested_leaf | 91.2800μs | 45.1893μs | 22.1291 KOps/s | 21.3161 KOps/s | |
test_values_stack_nested | 98.0640μs | 50.7072μs | 19.7211 KOps/s | 19.2933 KOps/s | |
test_values_stack_nested_leaf | 94.0850μs | 45.2638μs | 22.0927 KOps/s | 21.8018 KOps/s | |
test_values_stack_nested_locked | 93.2440μs | 50.4211μs | 19.8330 KOps/s | 19.3380 KOps/s | |
test_membership | 33.4220μs | 1.3602μs | 735.2031 KOps/s | 740.2316 KOps/s | |
test_membership_nested | 36.1780μs | 3.5015μs | 285.5899 KOps/s | 289.9956 KOps/s | |
test_membership_nested_leaf | 29.3150μs | 3.5049μs | 285.3189 KOps/s | 293.1181 KOps/s | |
test_membership_stacked_nested | 29.6760μs | 3.4622μs | 288.8349 KOps/s | 297.8264 KOps/s | |
test_membership_stacked_nested_leaf | 25.2970μs | 3.4663μs | 288.4951 KOps/s | 295.3611 KOps/s | |
test_membership_nested_last | 19.6160μs | 4.2486μs | 235.3705 KOps/s | 228.4426 KOps/s | |
test_membership_nested_leaf_last | 44.3330μs | 4.2732μs | 234.0189 KOps/s | 220.6480 KOps/s | |
test_membership_stacked_nested_last | 42.7300μs | 4.2483μs | 235.3866 KOps/s | 236.7873 KOps/s | |
test_membership_stacked_nested_leaf_last | 37.8700μs | 4.3088μs | 232.0856 KOps/s | 235.9441 KOps/s | |
test_nested_getleaf | 48.6300μs | 10.7457μs | 93.0601 KOps/s | 94.6065 KOps/s | |
test_nested_get | 36.9890μs | 10.1714μs | 98.3147 KOps/s | 100.2075 KOps/s | |
test_stacked_getleaf | 41.0470μs | 10.6788μs | 93.6436 KOps/s | 94.7532 KOps/s | |
test_stacked_get | 54.0710μs | 10.1867μs | 98.1670 KOps/s | 101.9039 KOps/s | |
test_nested_getitemleaf | 65.9730μs | 11.3664μs | 87.9783 KOps/s | 91.3916 KOps/s | |
test_nested_getitem | 37.4800μs | 10.3153μs | 96.9433 KOps/s | 99.1700 KOps/s | |
test_stacked_getitemleaf | 45.6750μs | 11.2523μs | 88.8709 KOps/s | 91.0168 KOps/s | |
test_stacked_getitem | 39.3140μs | 10.3874μs | 96.2703 KOps/s | 99.3457 KOps/s | |
test_lock_nested | 0.7618ms | 0.3442ms | 2.9055 KOps/s | 2.8336 KOps/s | |
test_lock_stack_nested | 0.4531ms | 0.3076ms | 3.2510 KOps/s | 3.3083 KOps/s | |
test_unlock_nested | 95.0257ms | 0.4411ms | 2.2673 KOps/s | 2.3430 KOps/s | |
test_unlock_stack_nested | 0.4680ms | 0.3156ms | 3.1688 KOps/s | 3.1924 KOps/s | |
test_flatten_speed | 0.5691ms | 0.2661ms | 3.7574 KOps/s | 3.7950 KOps/s | |
test_unflatten_speed | 0.7190ms | 0.4113ms | 2.4313 KOps/s | 2.4792 KOps/s | |
test_common_ops | 4.7291ms | 0.6933ms | 1.4423 KOps/s | 1.5767 KOps/s | |
test_creation | 27.3010μs | 1.8468μs | 541.4705 KOps/s | 550.8993 KOps/s | |
test_creation_empty | 28.6730μs | 9.1592μs | 109.1796 KOps/s | 127.3858 KOps/s | |
test_creation_nested_1 | 33.9740μs | 11.6299μs | 85.9856 KOps/s | 94.7013 KOps/s | |
test_creation_nested_2 | 52.6280μs | 15.2045μs | 65.7699 KOps/s | 72.6513 KOps/s | |
test_clone | 47.1680μs | 13.3174μs | 75.0899 KOps/s | 75.3692 KOps/s | |
test_getitem[int] | 33.2820μs | 11.2423μs | 88.9502 KOps/s | 89.4844 KOps/s | |
test_getitem[slice_int] | 59.5410μs | 22.9573μs | 43.5591 KOps/s | 43.8278 KOps/s | |
test_getitem[range] | 0.1116ms | 42.4776μs | 23.5418 KOps/s | 24.4939 KOps/s | |
test_getitem[tuple] | 48.8310μs | 18.7556μs | 53.3174 KOps/s | 54.0223 KOps/s | |
test_getitem[list] | 0.1929ms | 37.4572μs | 26.6971 KOps/s | 26.1327 KOps/s | |
test_setitem_dim[int] | 55.4730μs | 33.2861μs | 30.0426 KOps/s | 32.3086 KOps/s | |
test_setitem_dim[slice_int] | 0.1100ms | 60.0228μs | 16.6603 KOps/s | 17.7983 KOps/s | |
test_setitem_dim[range] | 0.1632ms | 80.9688μs | 12.3504 KOps/s | 13.4377 KOps/s | |
test_setitem_dim[tuple] | 92.5730μs | 48.8170μs | 20.4847 KOps/s | 21.1081 KOps/s | |
test_setitem | 67.1560μs | 19.7928μs | 50.5235 KOps/s | 53.9054 KOps/s | |
test_set | 70.2920μs | 19.2486μs | 51.9519 KOps/s | 56.4955 KOps/s | |
test_set_shared | 1.5018ms | 0.1382ms | 7.2368 KOps/s | 7.3115 KOps/s | |
test_update | 0.1495ms | 20.3964μs | 49.0284 KOps/s | 55.3096 KOps/s | |
test_update_nested | 80.8100μs | 28.6243μs | 34.9353 KOps/s | 38.7350 KOps/s | |
test_update__nested | 70.8220μs | 24.6559μs | 40.5582 KOps/s | 39.2831 KOps/s | |
test_set_nested | 0.1805ms | 21.3932μs | 46.7439 KOps/s | 50.7811 KOps/s | |
test_set_nested_new | 73.1460μs | 25.6610μs | 38.9697 KOps/s | 42.4103 KOps/s | |
test_select | 0.1024ms | 39.3516μs | 25.4119 KOps/s | 26.2009 KOps/s | |
test_select_nested | 0.2084ms | 59.8214μs | 16.7164 KOps/s | 16.8501 KOps/s | |
test_exclude_nested | 0.1954ms | 0.1184ms | 8.4443 KOps/s | 8.3317 KOps/s | |
test_empty[True] | 0.8521ms | 0.4233ms | 2.3625 KOps/s | 2.4223 KOps/s | |
test_empty[False] | 14.9538μs | 1.0478μs | 954.3879 KOps/s | 920.8033 KOps/s | |
test_unbind_speed | 0.4873ms | 0.2697ms | 3.7072 KOps/s | 4.0064 KOps/s | |
test_unbind_speed_stack0 | 0.5137ms | 0.2487ms | 4.0203 KOps/s | 4.0306 KOps/s | |
test_unbind_speed_stack1 | 0.1321s | 0.6946ms | 1.4398 KOps/s | 1.4353 KOps/s | |
test_split | 0.1277s | 1.6986ms | 588.7108 Ops/s | 612.6801 Ops/s | |
test_chunk | 2.9962ms | 1.5194ms | 658.1549 Ops/s | 685.7237 Ops/s | |
test_creation[device0] | 0.3358ms | 0.1025ms | 9.7559 KOps/s | 9.6927 KOps/s | |
test_creation_from_tensor | 4.6853ms | 81.8297μs | 12.2205 KOps/s | 12.0638 KOps/s | |
test_add_one[memmap_tensor0] | 79.1070μs | 5.8001μs | 172.4116 KOps/s | 193.1978 KOps/s | |
test_contiguous[memmap_tensor0] | 16.8310μs | 0.6426μs | 1.5562 MOps/s | 1.5102 MOps/s | |
test_stack[memmap_tensor0] | 40.3660μs | 3.7527μs | 266.4727 KOps/s | 287.7724 KOps/s | |
test_memmaptd_index | 1.1268ms | 0.2425ms | 4.1236 KOps/s | 4.2333 KOps/s | |
test_memmaptd_index_astensor | 0.5358ms | 0.3062ms | 3.2654 KOps/s | 3.3177 KOps/s | |
test_memmaptd_index_op | 1.1209ms | 0.6014ms | 1.6629 KOps/s | 1.8152 KOps/s | |
test_serialize_model | 0.2478s | 0.1179s | 8.4845 Ops/s | 8.3833 Ops/s | |
test_serialize_model_pickle | 0.4544s | 0.3797s | 2.6336 Ops/s | 2.6032 Ops/s | |
test_serialize_weights | 0.1047s | 98.9182ms | 10.1094 Ops/s | 9.8665 Ops/s | |
test_serialize_weights_returnearly | 0.1314s | 0.1251s | 7.9937 Ops/s | 8.1099 Ops/s | |
test_serialize_weights_pickle | 0.7542s | 0.5127s | 1.9504 Ops/s | 1.4419 Ops/s | |
test_serialize_weights_filesystem | 0.1084s | 95.9016ms | 10.4274 Ops/s | 9.8016 Ops/s | |
test_serialize_model_filesystem | 0.1029s | 94.4276ms | 10.5901 Ops/s | 10.6697 Ops/s | |
test_reshape_pytree | 70.0910μs | 20.8581μs | 47.9430 KOps/s | 48.7885 KOps/s | |
test_reshape_td | 73.4270μs | 32.8844μs | 30.4096 KOps/s | 30.5577 KOps/s | |
test_view_pytree | 65.4720μs | 21.4097μs | 46.7079 KOps/s | 48.2124 KOps/s | |
test_view_td | 0.1273s | 62.9896μs | 15.8756 KOps/s | 15.6140 KOps/s | |
test_unbind_pytree | 62.0260μs | 24.7463μs | 40.4100 KOps/s | 40.3206 KOps/s | |
test_unbind_td | 0.1048ms | 36.6855μs | 27.2587 KOps/s | 26.8189 KOps/s | |
test_split_pytree | 49.6330μs | 24.1115μs | 41.4741 KOps/s | 41.3340 KOps/s | |
test_split_td | 0.1124ms | 40.9350μs | 24.4290 KOps/s | 25.0965 KOps/s | |
test_add_pytree | 71.2130μs | 30.5968μs | 32.6831 KOps/s | 33.6483 KOps/s | |
test_add_td | 0.1754ms | 57.2260μs | 17.4746 KOps/s | 19.8377 KOps/s | |
test_distributed | 0.1873ms | 99.0378μs | 10.0972 KOps/s | 10.0026 KOps/s | |
test_tdmodule | 44.5430μs | 16.8023μs | 59.5156 KOps/s | 62.8893 KOps/s | |
test_tdmodule_dispatch | 63.7390μs | 32.9821μs | 30.3195 KOps/s | 31.7710 KOps/s | |
test_tdseq | 34.2040μs | 19.5049μs | 51.2691 KOps/s | 54.1408 KOps/s | |
test_tdseq_dispatch | 77.6330μs | 38.4812μs | 25.9867 KOps/s | 27.9449 KOps/s | |
test_instantiation_functorch | 2.1136ms | 1.3255ms | 754.4204 Ops/s | 770.4929 Ops/s | |
test_instantiation_td | 1.4341ms | 1.0132ms | 986.9964 Ops/s | 988.7741 Ops/s | |
test_exec_functorch | 0.2227ms | 0.1610ms | 6.2121 KOps/s | 6.2793 KOps/s | |
test_exec_functional_call | 0.3536ms | 0.1528ms | 6.5431 KOps/s | 6.9109 KOps/s | |
test_exec_td | 0.2726ms | 0.1449ms | 6.9007 KOps/s | 7.1599 KOps/s | |
test_exec_td_decorator | 0.7548ms | 0.1988ms | 5.0313 KOps/s | 5.1040 KOps/s | |
test_vmap_mlp_speed[True-True] | 0.7163ms | 0.4708ms | 2.1240 KOps/s | 2.2129 KOps/s | |
test_vmap_mlp_speed[True-False] | 0.7227ms | 0.4663ms | 2.1443 KOps/s | 2.2067 KOps/s | |
test_vmap_mlp_speed[False-True] | 0.5380ms | 0.3830ms | 2.6107 KOps/s | 2.6909 KOps/s | |
test_vmap_mlp_speed[False-False] | 0.6767ms | 0.3840ms | 2.6042 KOps/s | 2.6839 KOps/s | |
test_vmap_mlp_speed_decorator[True-True] | 1.0100ms | 0.4995ms | 2.0021 KOps/s | 2.1021 KOps/s | |
test_vmap_mlp_speed_decorator[True-False] | 0.8497ms | 0.4952ms | 2.0194 KOps/s | 2.1057 KOps/s | |
test_vmap_mlp_speed_decorator[False-True] | 0.6117ms | 0.3991ms | 2.5054 KOps/s | 2.5435 KOps/s | |
test_vmap_mlp_speed_decorator[False-False] | 0.7236ms | 0.4003ms | 2.4984 KOps/s | 2.5641 KOps/s | |
test_to_module_speed[True] | 2.0987ms | 1.4177ms | 705.3619 Ops/s | 709.6159 Ops/s | |
test_to_module_speed[False] | 1.6380ms | 1.4177ms | 705.3828 Ops/s | 711.1263 Ops/s |
Here is a version of Adam in 23 lines of code (no weight decay and no edge cases whatsoever) with tensordict:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams
from torch import nn
class Adam:
    def __init__(self, weights: TensorDictParams, alpha: float = 1e-3, beta1: float = 0.9,
                 beta2: float = 0.999, eps: float = 1e-6, weight_decay: float = 0.0):
        # Lock the tensordict for efficiency
        weights = weights.lock_()
        self.weights = weights
        self.t = 0
        # Running first- and second-moment estimates, stored as tensordicts
        self._mu = weights.data.clone()
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        # Exponential moving averages of the gradient and its square,
        # computed with pointwise (foreach-backed) tensordict ops
        self._mu.mul_(self.beta1).add_(self.weights.grad, alpha=1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), alpha=1 - self.beta2)
        self.t += 1
        # Bias correction, out of place so the running estimates are preserved
        mu = self._mu.div(1 - self.beta1 ** self.t)
        sigma = self._sigma.div(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

    def zero_grad(self):
        self.weights.grad.zero_()
device = "cpu" if not torch.cuda.is_available() else "cuda"
with torch.device(device):
    net = nn.Transformer()
    weights = TensorDict.from_module(net, as_module=True)
    adam = Adam(weights=weights)
    net(torch.randn(5, 5, 512), torch.randn(5, 5, 512)).mean().backward()
    adam.step()
    adam.zero_grad()

    from torch.utils.benchmark import Timer

    net(torch.randn(5, 5, 512), torch.randn(5, 5, 512)).mean().backward()
    print("td", Timer("adam.step()", globals=globals()).adaptive_autorange())
    adam.zero_grad()

    adam_orig = torch.optim.Adam(params=net.parameters())
    net(torch.randn(5, 5, 512), torch.randn(5, 5, 512)).mean().backward()
    adam_orig.step()
    adam_orig.zero_grad()
    net(torch.randn(5, 5, 512), torch.randn(5, 5, 512)).mean().backward()
    print("native", Timer("adam_orig.step()", globals=globals()).adaptive_autorange())
    adam_orig.zero_grad()

On my machine and on an A100 I get a similar runtime (again, this is much simpler than torch.optim - my point is that at least it isn't slower!).
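For reference, step() above follows the usual bias-corrected Adam update (without weight decay), with _mu and _sigma playing the roles of the first and second moment estimates:

$$
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t, &
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^{2},\\
\hat m_t &= \frac{m_t}{1-\beta_1^{t}}, &
\hat v_t &= \frac{v_t}{1-\beta_2^{t}},\\
\theta_t &= \theta_{t-1} - \alpha\, \frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}.
\end{aligned}
$$

Every line of the update is a pointwise op over the whole tensordict of parameters, which is the kind of call the foreach-backed arithmetic in this PR is meant to speed up.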
With this you can do the same pointwise ops with a lazy stack, and it's efficient (much more efficient than stack -> do the op -> call setitem, which was the default behaviour before). That will allow us to do ops on lazy stacks as if they were regular contiguous stacks of tensors.
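A minimal sketch of that usage, assuming a lazy stack built with torch.stack over TensorDicts and that the new pointwise ops accept both tensordict and scalar operands:

```python
import torch
from tensordict import TensorDict

td0 = TensorDict({"w": torch.randn(3), "b": torch.randn(3)}, [])
td1 = TensorDict({"w": torch.randn(3), "b": torch.randn(3)}, [])

# torch.stack over TensorDicts returns a LazyStackedTensorDict:
# the leaves stay as separate tensors, no contiguous copy is made.
stacked = torch.stack([td0, td1])

# The pointwise op runs leaf-by-leaf (foreach-style) directly on the lazy stack,
# instead of stack -> do the op -> call setitem.
stacked.mul_(0.5)
out = stacked + 1.0  # out-of-place variant rebuilds the (lazy) structure
```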
Just to make sure I understand your point: this is showing that foreach is built into how TensorDicts can apply one operation to multiple params at once?
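The pattern in question, applying one op to every parameter leaf at once instead of looping in Python, can be sketched as follows (hypothetical model, assuming the foreach-backed ops of this PR accept scalar operands):

```python
import torch
from tensordict import TensorDict
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))  # hypothetical model
params = TensorDict.from_module(net)

with torch.no_grad():
    # Python-level loop: one op (and one kernel launch) per parameter tensor.
    for p in params.values(include_nested=True, leaves_only=True):
        p.mul_(0.9)

    # Single TensorDict call: dispatches to torch._foreach_* under the hood,
    # grouping the leaves into one (or a few) fused calls.
    params.mul_(0.9)
```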
Test the feature:
Gives these results:
So there is some overhead for add due to the fact that we need to rebuild the tensordict, but if we don't (i.e. with the in-place add_) it's as fast as plain foreach, and you can keep the nice structure of your params.
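A minimal sketch of such a comparison (hypothetical sizes, assuming td.add/td.add_ dispatch to foreach as described):

```python
import torch
from tensordict import TensorDict
from torch.utils.benchmark import Timer

# Hypothetical setup: 100 flat leaves standing in for a model's parameters.
td = TensorDict({f"p{i}": torch.randn(1024) for i in range(100)}, [])
other = td.clone()
leaves = list(td.values())
other_leaves = list(other.values())

# Out-of-place add: the output tensordict has to be rebuilt, hence some overhead.
print("add    ", Timer("td.add(other)", globals=globals()).adaptive_autorange())
# In-place add_: no new structure to build, expected to match plain foreach.
print("add_   ", Timer("td.add_(other)", globals=globals()).adaptive_autorange())
# Baseline: raw foreach on the flat lists of leaves.
print("foreach", Timer("torch._foreach_add_(leaves, other_leaves)", globals=globals()).adaptive_autorange())
```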
cc @fmassa @ahmed-touati @teopir @MateuszGuzek @zou3519 @ezyang @albanD