-
Notifications
You must be signed in to change notification settings - Fork 76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Support classmethods in tensorclass #448
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! thanks! left couple of comments
tensordict/tensorclass.py
Outdated
return res | ||
func = res | ||
|
||
def _wrap_func(self, attr, func): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here you just moved the closure out, right? why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I created a _wrap_method that mirrors a new _wrap_func in such a way that we have 2 distinct but semantically equivalent methods for clarity
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, one minor comment on naming which I found a bit confusing
tensordict/tensorclass.py
Outdated
return res | ||
func = res | ||
|
||
def _wrap_func(self, attr, func): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we could call this _wrap_method
and rename _wrap_method
to _wrap_classmethod
?
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_items | 86.4010μs | 3.0524μs | 327.6082 KOps/s | 383.9771 KOps/s | |
test_items_nested | 2.8802ms | 0.7380ms | 1.3550 KOps/s | 1.2458 KOps/s | |
test_items_nested_locked | 0.5419ms | 0.2430ms | 4.1151 KOps/s | 3.7705 KOps/s | |
test_items_nested_leaf | 1.5446ms | 0.4139ms | 2.4159 KOps/s | 2.2330 KOps/s | |
test_items_stack_nested | 60.4390ms | 39.4375ms | 25.3566 Ops/s | 22.4709 Ops/s | |
test_items_stack_nested_leaf | 34.3456ms | 22.9753ms | 43.5250 Ops/s | 41.5916 Ops/s | |
test_items_stack_nested_locked | 8.9337ms | 2.6229ms | 381.2505 Ops/s | 353.3051 Ops/s | |
test_keys | 82.8000μs | 6.2537μs | 159.9065 KOps/s | 165.4656 KOps/s | |
test_keys_nested | 1.5435ms | 0.2211ms | 4.5226 KOps/s | 4.2741 KOps/s | |
test_keys_nested_locked | 1.9231ms | 0.2209ms | 4.5276 KOps/s | 3.9912 KOps/s | |
test_keys_nested_leaf | 2.3839ms | 0.2284ms | 4.3780 KOps/s | 4.0415 KOps/s | |
test_keys_stack_nested | 13.2699ms | 3.1388ms | 318.5900 Ops/s | 301.4084 Ops/s | |
test_keys_stack_nested_leaf | 9.5335ms | 2.9274ms | 341.6024 Ops/s | 317.1260 Ops/s | |
test_keys_stack_nested_locked | 10.6375ms | 1.1065ms | 903.7204 Ops/s | 840.1519 Ops/s | |
test_values | 18.9000μs | 1.3135μs | 761.3195 KOps/s | 559.9051 KOps/s | |
test_values_nested | 1.0915ms | 0.7543ms | 1.3257 KOps/s | 1.3204 KOps/s | |
test_values_nested_locked | 0.5576ms | 0.2381ms | 4.2005 KOps/s | 4.1939 KOps/s | |
test_values_nested_leaf | 1.1811ms | 0.4004ms | 2.4974 KOps/s | 2.3922 KOps/s | |
test_values_stack_nested | 53.1970ms | 38.9181ms | 25.6950 Ops/s | 23.2965 Ops/s | |
test_values_stack_nested_leaf | 36.5532ms | 21.6164ms | 46.2611 Ops/s | 44.0434 Ops/s | |
test_values_stack_nested_locked | 7.1219ms | 2.5772ms | 388.0119 Ops/s | 366.0238 Ops/s | |
test_membership | 63.0010μs | 2.2567μs | 443.1214 KOps/s | 469.0185 KOps/s | |
test_membership_nested | 80.0010μs | 4.1417μs | 241.4447 KOps/s | 245.8855 KOps/s | |
test_membership_nested_leaf | 95.2010μs | 3.9083μs | 255.8636 KOps/s | 264.1983 KOps/s | |
test_membership_stacked_nested | 18.9010μs | 1.9722μs | 507.0446 KOps/s | 529.4030 KOps/s | |
test_membership_stacked_nested_leaf | 0.1119ms | 2.3455μs | 426.3451 KOps/s | 367.7931 KOps/s | |
test_stacked_getleaf | 7.6902ms | 1.9270ms | 518.9492 Ops/s | 531.0652 Ops/s | |
test_stacked_get | 5.1825ms | 1.7751ms | 563.3412 Ops/s | 549.4165 Ops/s | |
test_lock_nested | 10.7928ms | 1.1722ms | 853.1194 Ops/s | 780.7064 Ops/s | |
test_lock_stack_nested | 0.1007s | 16.6499ms | 60.0603 Ops/s | 55.5661 Ops/s | |
test_unlock_nested | 7.3706ms | 1.2495ms | 800.3189 Ops/s | 720.7077 Ops/s | |
test_unlock_stack_nested | 0.1106s | 18.1067ms | 55.2281 Ops/s | 64.6418 Ops/s | |
test_flatten_speed | 4.4370ms | 1.2390ms | 807.0856 Ops/s | 798.9442 Ops/s | |
test_unflatten_speed | 9.7717ms | 2.2453ms | 445.3790 Ops/s | 432.4132 Ops/s | |
test_common_ops | 1.8263ms | 1.5881ms | 629.6971 Ops/s | 593.2949 Ops/s | |
test_creation | 16.4132μs | 6.4118μs | 155.9615 KOps/s | 135.8165 KOps/s | |
test_creation_empty | 32.5693μs | 16.2936μs | 61.3740 KOps/s | 59.4571 KOps/s | |
test_creation_nested_1 | 38.0763μs | 30.3319μs | 32.9686 KOps/s | 32.0038 KOps/s | |
test_creation_nested_2 | 78.7667μs | 31.9056μs | 31.3425 KOps/s | 30.0426 KOps/s | |
test_clone | 41.2414μs | 30.5439μs | 32.7397 KOps/s | 29.6958 KOps/s | |
test_getitem[int] | 53.1772μs | 38.2463μs | 26.1463 KOps/s | 22.9112 KOps/s | |
test_getitem[slice_int] | 0.1179ms | 89.0674μs | 11.2274 KOps/s | 10.5077 KOps/s | |
test_getitem[range] | 0.1467ms | 0.1027ms | 9.7357 KOps/s | 8.7394 KOps/s | |
test_getitem[tuple] | 98.6744μs | 78.7750μs | 12.6944 KOps/s | 11.7319 KOps/s | |
test_getitem[list] | 0.1329ms | 91.9909μs | 10.8706 KOps/s | 10.3778 KOps/s | |
test_setitem_dim[int] | 5.3230ms | 75.0739μs | 13.3202 KOps/s | 14.0774 KOps/s | |
test_setitem_dim[slice_int] | 6.7132ms | 0.1502ms | 6.6559 KOps/s | 7.2736 KOps/s | |
test_setitem_dim[range] | 5.5496ms | 0.1443ms | 6.9280 KOps/s | 7.2078 KOps/s | |
test_setitem_dim[tuple] | 5.9654ms | 0.1269ms | 7.8805 KOps/s | 8.5212 KOps/s | |
test_setitem | 61.4345μs | 46.0202μs | 21.7296 KOps/s | 23.2471 KOps/s | |
test_set | 72.7316μs | 44.8500μs | 22.2965 KOps/s | 23.4475 KOps/s | |
test_set_shared | 0.3807ms | 0.3046ms | 3.2827 KOps/s | 3.4387 KOps/s | |
test_update | 75.8417μs | 51.4382μs | 19.4408 KOps/s | 19.8490 KOps/s | |
test_update_nested | 91.9018μs | 71.1181μs | 14.0611 KOps/s | 14.4616 KOps/s | |
test_set_nested | 0.1208ms | 65.8283μs | 15.1910 KOps/s | 17.5465 KOps/s | |
test_set_nested_new | 0.1121ms | 86.5948μs | 11.5480 KOps/s | 12.0286 KOps/s | |
test_select | 0.1811ms | 0.1449ms | 6.9008 KOps/s | 6.9445 KOps/s | |
test_creation[device0] | 1.5660ms | 0.6507ms | 1.5367 KOps/s | 1.4850 KOps/s | |
test_creation_from_tensor | 0.7566ms | 0.5579ms | 1.7925 KOps/s | 1.6505 KOps/s | |
test_add_one[memmap_tensor0] | 0.1152ms | 62.4718μs | 16.0072 KOps/s | 15.7953 KOps/s | |
test_contiguous[memmap_tensor0] | 22.0682μs | 12.1228μs | 82.4895 KOps/s | 86.1344 KOps/s | |
test_stack[memmap_tensor0] | 0.1969ms | 58.0706μs | 17.2204 KOps/s | 16.7615 KOps/s | |
test_memmaptd_index | 1.3537ms | 0.3474ms | 2.8787 KOps/s | 2.6900 KOps/s | |
test_memmaptd_index_astensor | 5.3301ms | 1.7138ms | 583.4896 Ops/s | 560.4371 Ops/s | |
test_memmaptd_index_op | 12.6160ms | 4.4839ms | 223.0221 Ops/s | 224.8893 Ops/s | |
test_reshape_pytree | 72.6826μs | 47.7061μs | 20.9617 KOps/s | 20.7021 KOps/s | |
test_reshape_td | 0.1013ms | 61.1768μs | 16.3461 KOps/s | 16.9860 KOps/s | |
test_view_pytree | 56.4575μs | 44.5365μs | 22.4535 KOps/s | 20.8872 KOps/s | |
test_view_td | 15.8511μs | 11.8578μs | 84.3323 KOps/s | 76.9481 KOps/s | |
test_unbind_pytree | 95.7509μs | 53.4975μs | 18.6925 KOps/s | 19.9054 KOps/s | |
test_unbind_td | 0.2989ms | 0.1980ms | 5.0500 KOps/s | 5.2441 KOps/s | |
test_split_pytree | 93.8488μs | 56.7829μs | 17.6109 KOps/s | 16.8797 KOps/s | |
test_split_td | 0.1941ms | 0.1547ms | 6.4656 KOps/s | 6.0956 KOps/s | |
test_add_pytree | 0.1173ms | 66.7863μs | 14.9731 KOps/s | 14.4212 KOps/s | |
test_add_td | 0.1098ms | 86.8849μs | 11.5095 KOps/s | 10.6198 KOps/s | |
test_distributed | 88.0010μs | 88.0010μs | 11.3635 KOps/s | 11.6958 KOps/s | |
test_tdmodule | 5.3392ms | 39.6435μs | 25.2248 KOps/s | 23.1159 KOps/s | |
test_tdmodule_dispatch | 0.9818ms | 84.5577μs | 11.8262 KOps/s | 10.6447 KOps/s | |
test_tdseq | 0.7131ms | 46.1719μs | 21.6582 KOps/s | 20.4574 KOps/s | |
test_tdseq_dispatch | 4.1626ms | 0.1019ms | 9.8170 KOps/s | 8.9017 KOps/s | |
test_instantiation_functorch | 2.9208ms | 2.2446ms | 445.5141 Ops/s | 439.4855 Ops/s | |
test_instantiation_td | 2.4976ms | 1.7649ms | 566.6100 Ops/s | 568.0258 Ops/s | |
test_exec_functorch | 0.3784ms | 0.2777ms | 3.6016 KOps/s | 3.2494 KOps/s | |
test_exec_td | 0.3806ms | 0.2795ms | 3.5781 KOps/s | 3.3008 KOps/s | |
test_vmap_mlp_speed[True-True] | 2.9535ms | 2.5926ms | 385.7160 Ops/s | 366.5945 Ops/s | |
test_vmap_mlp_speed[True-False] | 1.4265ms | 1.1312ms | 884.0171 Ops/s | 846.5646 Ops/s | |
test_vmap_mlp_speed[False-True] | 2.6628ms | 2.2735ms | 439.8596 Ops/s | 410.8131 Ops/s | |
test_vmap_mlp_speed[False-False] | 1.1417ms | 0.9666ms | 1.0346 KOps/s | 986.2635 Ops/s | |
test_vmap_transformer_speed[True-True] | 37.5613ms | 30.1376ms | 33.1812 Ops/s | 31.7219 Ops/s | |
test_vmap_transformer_speed[True-False] | 17.0442ms | 15.6932ms | 63.7218 Ops/s | 62.2246 Ops/s | |
test_vmap_transformer_speed[False-True] | 32.9490ms | 29.2220ms | 34.2208 Ops/s | 33.3026 Ops/s | |
test_vmap_transformer_speed[False-False] | 16.9143ms | 15.2169ms | 65.7164 Ops/s | 63.6289 Ops/s |
No description provided.