[BugFix, Feature] tensorclass.to_dict and from_dict #707

Merged (12 commits) on Mar 13, 2024

@vmoens vmoens commented Mar 8, 2024

Nested tensorclasses will need some form of autocasting to make sure that a dict within a dict is interpreted as the right class.

The PR also proposes such auto-casting for tensor containers only; leaves are not auto-cast, as is the case with tensordict.

The doc has been updated to account for this feature, and tests have been written for it too.
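
To illustrate the auto-casting idea, here is a minimal sketch that uses plain Python dataclasses as stand-ins for tensorclasses. It is not tensordict's implementation: `from_dict`, `Inner`, and `Outer` are hypothetical names, and lists stand in for tensors. It only shows how a dict nested within a dict can be rebuilt as the annotated class while leaves (including non-tensor entries such as strings) pass through unchanged.

```python
from dataclasses import asdict, dataclass, fields, is_dataclass


def from_dict(cls, data: dict):
    """Build `cls` from a plain dict, casting nested dicts to the annotated class.

    Hypothetical helper for illustration only; not the tensordict API.
    """
    kwargs = {}
    for f in fields(cls):
        value = data[f.name]
        # Only containers are auto-cast: a dict whose field annotation is a
        # dataclass is rebuilt recursively. Leaves are passed through as-is.
        if isinstance(value, dict) and is_dataclass(f.type):
            value = from_dict(f.type, value)
        kwargs[f.name] = value
    return cls(**kwargs)


@dataclass
class Inner:
    x: list      # stand-in for a tensor leaf
    label: str   # non-tensor entry, must survive the round trip


@dataclass
class Outer:
    inner: Inner  # nested container, auto-cast from a dict
    y: list


outer = Outer(inner=Inner(x=[1, 2], label="a"), y=[3])
as_plain = asdict(outer)               # nested dataclasses become nested dicts
restored = from_dict(Outer, as_plain)  # nested dict is cast back to Inner
```

In this sketch the type annotation on each field decides which class a nested dict is cast to, which is one plausible way to resolve the ambiguity described above.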

cc @maximilianigl

@facebook-github-bot added the CLA Signed label Mar 8, 2024
@vmoens added the bug label Mar 8, 2024
@vmoens linked an issue Mar 8, 2024 that may be closed by this pull request
@vmoens marked this pull request as ready for review March 8, 2024 16:46

github-actions bot commented Mar 8, 2024

⚠ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 127. Improved: 11. Worsened: 8.

| Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
| --- | --- | --- | --- | --- | --- |
| test_plain_set_nested | 36.4680μs | 16.1555μs | 61.8984 KOps/s | 61.9716 KOps/s | -0.12% |
| test_plain_set_stack_nested | 36.3380μs | 16.3335μs | 61.2240 KOps/s | 61.0128 KOps/s | +0.35% |
| test_plain_set_nested_inplace | 78.2090μs | 18.5077μs | 54.0315 KOps/s | 54.0809 KOps/s | -0.09% |
| test_plain_set_stack_nested_inplace | 1.7262ms | 18.9092μs | 52.8843 KOps/s | 53.8382 KOps/s | -1.77% |
| test_items | 36.4980μs | 2.4085μs | 415.1939 KOps/s | 371.8137 KOps/s | **+11.67%** |
| test_items_nested | 0.9483ms | 0.2698ms | 3.7066 KOps/s | 3.7021 KOps/s | +0.12% |
| test_items_nested_locked | 0.4884ms | 0.2700ms | 3.7034 KOps/s | 3.7073 KOps/s | -0.11% |
| test_items_nested_leaf | 0.4709ms | 0.1682ms | 5.9448 KOps/s | 5.9950 KOps/s | -0.84% |
| test_items_stack_nested | 0.9418ms | 0.2724ms | 3.6706 KOps/s | 3.6749 KOps/s | -0.11% |
| test_items_stack_nested_leaf | 0.2820ms | 0.1670ms | 5.9867 KOps/s | 6.0094 KOps/s | -0.38% |
| test_items_stack_nested_locked | 0.4935ms | 0.2739ms | 3.6512 KOps/s | 3.6608 KOps/s | -0.26% |
| test_keys | 30.7670μs | 3.8801μs | 257.7243 KOps/s | 259.5754 KOps/s | -0.71% |
| test_keys_nested | 2.3066ms | 0.1457ms | 6.8633 KOps/s | 6.7361 KOps/s | +1.89% |
| test_keys_nested_locked | 0.3034ms | 0.1509ms | 6.6275 KOps/s | 6.5143 KOps/s | +1.74% |
| test_keys_nested_leaf | 41.9815ms | 0.1334ms | 7.4949 KOps/s | 7.6677 KOps/s | -2.25% |
| test_keys_stack_nested | 0.2520ms | 0.1445ms | 6.9192 KOps/s | 6.5447 KOps/s | **+5.72%** |
| test_keys_stack_nested_leaf | 0.2532ms | 0.1259ms | 7.9426 KOps/s | 7.4672 KOps/s | **+6.37%** |
| test_keys_stack_nested_locked | 0.1980ms | 0.1505ms | 6.6444 KOps/s | 6.3500 KOps/s | +4.64% |
| test_values | 7.9388μs | 1.1515μs | 868.4319 KOps/s | 863.2005 KOps/s | +0.61% |
| test_values_nested | 94.5660μs | 51.3212μs | 19.4851 KOps/s | 19.4247 KOps/s | +0.31% |
| test_values_nested_locked | 0.1003ms | 53.5086μs | 18.6886 KOps/s | 19.3172 KOps/s | -3.25% |
| test_values_nested_leaf | 0.1004ms | 46.2299μs | 21.6310 KOps/s | 21.2420 KOps/s | +1.83% |
| test_values_stack_nested | 92.8930μs | 50.8822μs | 19.6532 KOps/s | 18.4214 KOps/s | **+6.69%** |
| test_values_stack_nested_leaf | 92.1810μs | 45.9679μs | 21.7543 KOps/s | 21.6229 KOps/s | +0.61% |
| test_values_stack_nested_locked | 0.1005ms | 51.5918μs | 19.3829 KOps/s | 18.6987 KOps/s | +3.66% |
| test_membership | 35.4760μs | 1.3674μs | 731.3029 KOps/s | 730.0502 KOps/s | +0.17% |
| test_membership_nested | 45.1400μs | 3.4350μs | 291.1224 KOps/s | 295.4661 KOps/s | -1.47% |
| test_membership_nested_leaf | 43.3250μs | 3.5256μs | 283.6406 KOps/s | 294.4298 KOps/s | -3.66% |
| test_membership_stacked_nested | 19.0350μs | 3.4525μs | 289.6417 KOps/s | 292.6901 KOps/s | -1.04% |
| test_membership_stacked_nested_leaf | 27.9220μs | 3.5117μs | 284.7604 KOps/s | 294.3160 KOps/s | -3.25% |
| test_membership_nested_last | 40.2150μs | 4.2800μs | 233.6439 KOps/s | 232.3227 KOps/s | +0.57% |
| test_membership_nested_leaf_last | 26.6890μs | 4.3267μs | 231.1207 KOps/s | 233.5536 KOps/s | -1.04% |
| test_membership_stacked_nested_last | 27.7520μs | 4.2913μs | 233.0277 KOps/s | 71.3247 KOps/s | **+226.71%** |
| test_membership_stacked_nested_leaf_last | 27.9420μs | 4.3687μs | 228.9035 KOps/s | 71.0316 KOps/s | **+222.26%** |
| test_nested_getleaf | 32.3600μs | 10.7483μs | 93.0379 KOps/s | 89.6214 KOps/s | +3.81% |
| test_nested_get | 31.0280μs | 10.1641μs | 98.3853 KOps/s | 91.7988 KOps/s | **+7.17%** |
| test_stacked_getleaf | 63.0670μs | 10.8354μs | 92.2902 KOps/s | 88.3438 KOps/s | +4.47% |
| test_stacked_get | 43.2910μs | 9.9892μs | 100.1084 KOps/s | 93.0177 KOps/s | **+7.62%** |
| test_nested_getitemleaf | 40.7060μs | 11.3375μs | 88.2029 KOps/s | 84.2834 KOps/s | +4.65% |
| test_nested_getitem | 35.2950μs | 10.5706μs | 94.6020 KOps/s | 89.1463 KOps/s | **+6.12%** |
| test_stacked_getitemleaf | 55.9150μs | 11.3299μs | 88.2617 KOps/s | 84.1436 KOps/s | +4.89% |
| test_stacked_getitem | 0.1501ms | 10.6584μs | 93.8225 KOps/s | 90.7771 KOps/s | +3.35% |
| test_lock_nested | 0.7902ms | 0.3413ms | 2.9303 KOps/s | 2.8867 KOps/s | +1.51% |
| test_lock_stack_nested | 0.5082ms | 0.3051ms | 3.2781 KOps/s | 3.3688 KOps/s | -2.69% |
| test_unlock_nested | 93.0610ms | 0.4326ms | 2.3119 KOps/s | 2.3582 KOps/s | -1.96% |
| test_unlock_stack_nested | 0.6769ms | 0.3147ms | 3.1774 KOps/s | 3.2738 KOps/s | -2.94% |
| test_flatten_speed | 0.6080ms | 0.2648ms | 3.7766 KOps/s | 3.6138 KOps/s | +4.51% |
| test_unflatten_speed | 0.5490ms | 0.4068ms | 2.4583 KOps/s | 2.4412 KOps/s | +0.70% |
| test_common_ops | 5.2433ms | 0.6871ms | 1.4553 KOps/s | 1.5006 KOps/s | -3.01% |
| test_creation | 92.6620μs | 1.8588μs | 537.9705 KOps/s | 536.0752 KOps/s | +0.35% |
| test_creation_empty | 35.4360μs | 9.0127μs | 110.9551 KOps/s | 109.6528 KOps/s | +1.19% |
| test_creation_nested_1 | 42.5790μs | 11.8411μs | 84.4515 KOps/s | 84.4815 KOps/s | -0.04% |
| test_creation_nested_2 | 43.3610μs | 15.1114μs | 66.1752 KOps/s | 67.1883 KOps/s | -1.51% |
| test_clone | 0.1036ms | 13.0741μs | 76.4870 KOps/s | 75.9697 KOps/s | +0.68% |
| test_getitem[int] | 31.0980μs | 11.3092μs | 88.4235 KOps/s | 88.6494 KOps/s | -0.25% |
| test_getitem[slice_int] | 67.4260μs | 22.8787μs | 43.7089 KOps/s | 45.1350 KOps/s | -3.16% |
| test_getitem[range] | 98.1820μs | 42.6714μs | 23.4349 KOps/s | 24.4950 KOps/s | -4.33% |
| test_getitem[tuple] | 60.3520μs | 18.7079μs | 53.4533 KOps/s | 54.4452 KOps/s | -1.82% |
| test_getitem[list] | 0.2591ms | 36.9806μs | 27.0412 KOps/s | 27.3917 KOps/s | -1.28% |
| test_setitem_dim[int] | 72.8560μs | 34.6016μs | 28.9004 KOps/s | 31.7802 KOps/s | **-9.06%** |
| test_setitem_dim[slice_int] | 0.1654ms | 63.0394μs | 15.8631 KOps/s | 17.2053 KOps/s | **-7.80%** |
| test_setitem_dim[range] | 0.1658ms | 81.4293μs | 12.2806 KOps/s | 13.4073 KOps/s | **-8.40%** |
| test_setitem_dim[tuple] | 0.1084ms | 49.6128μs | 20.1561 KOps/s | 21.2876 KOps/s | **-5.32%** |
| test_setitem | 0.1254ms | 19.1792μs | 52.1399 KOps/s | 52.7244 KOps/s | -1.11% |
| test_set | 0.1236ms | 18.6094μs | 53.7364 KOps/s | 54.3597 KOps/s | -1.15% |
| test_set_shared | 4.1973ms | 0.1461ms | 6.8438 KOps/s | 7.2692 KOps/s | **-5.85%** |
| test_update | 0.1538ms | 21.0055μs | 47.6066 KOps/s | 48.2933 KOps/s | -1.42% |
| test_update_nested | 0.1465ms | 28.6306μs | 34.9277 KOps/s | 34.9273 KOps/s | +0.00% |
| test_update__nested | 0.1126ms | 24.8051μs | 40.3143 KOps/s | 41.5608 KOps/s | -3.00% |
| test_set_nested | 0.1194ms | 20.2597μs | 49.3590 KOps/s | 49.0647 KOps/s | +0.60% |
| test_set_nested_new | 0.1669ms | 24.9667μs | 40.0534 KOps/s | 41.0887 KOps/s | -2.52% |
| test_select | 0.9012ms | 38.1092μs | 26.2404 KOps/s | 25.6017 KOps/s | +2.49% |
| test_select_nested | 0.1132ms | 59.6346μs | 16.7688 KOps/s | 16.7163 KOps/s | +0.31% |
| test_exclude_nested | 0.2219ms | 0.1190ms | 8.3999 KOps/s | 8.4436 KOps/s | -0.52% |
| test_empty[True] | 0.7473ms | 0.4088ms | 2.4459 KOps/s | 2.4182 KOps/s | +1.14% |
| test_empty[False] | 7.3978μs | 1.0324μs | 968.6329 KOps/s | 936.1398 KOps/s | +3.47% |
| test_unbind_speed | 0.4315ms | 0.2461ms | 4.0639 KOps/s | 3.8512 KOps/s | **+5.52%** |
| test_unbind_speed_stack0 | 0.4311ms | 0.2426ms | 4.1213 KOps/s | 4.1697 KOps/s | -1.16% |
| test_unbind_speed_stack1 | 0.8302ms | 0.6048ms | 1.6534 KOps/s | 1.4696 KOps/s | **+12.50%** |
| test_split | 0.1342s | 1.6708ms | 598.5194 Ops/s | 601.6987 Ops/s | -0.53% |
| test_chunk | 2.3889ms | 1.4780ms | 676.6088 Ops/s | 681.9105 Ops/s | -0.78% |
| test_creation[device0] | 4.9138ms | 0.1059ms | 9.4386 KOps/s | 9.8778 KOps/s | -4.45% |
| test_creation_from_tensor | 0.2296ms | 82.9398μs | 12.0569 KOps/s | 11.9876 KOps/s | +0.58% |
| test_add_one[memmap_tensor0] | 77.1140μs | 5.3981μs | 185.2511 KOps/s | 187.9122 KOps/s | -1.42% |
| test_contiguous[memmap_tensor0] | 26.1890μs | 0.6258μs | 1.5980 MOps/s | 1.5743 MOps/s | +1.51% |
| test_stack[memmap_tensor0] | 40.1550μs | 3.6205μs | 276.2075 KOps/s | 274.3330 KOps/s | +0.68% |
| test_memmaptd_index | 1.0152ms | 0.2430ms | 4.1160 KOps/s | 4.1139 KOps/s | +0.05% |
| test_memmaptd_index_astensor | 0.5478ms | 0.3055ms | 3.2738 KOps/s | 3.2810 KOps/s | -0.22% |
| test_memmaptd_index_op | 0.8407ms | 0.5812ms | 1.7205 KOps/s | 1.7535 KOps/s | -1.88% |
| test_serialize_model | 0.2316s | 0.1163s | 8.5963 Ops/s | 8.6251 Ops/s | -0.33% |
| test_serialize_model_pickle | 0.4489s | 0.3798s | 2.6328 Ops/s | 2.6173 Ops/s | +0.59% |
| test_serialize_weights | 0.1028s | 98.6704ms | 10.1347 Ops/s | 10.0995 Ops/s | +0.35% |
| test_serialize_weights_returnearly | 0.3064s | 0.1552s | 6.4434 Ops/s | 8.2139 Ops/s | **-21.56%** |
| test_serialize_weights_pickle | 0.7428s | 0.5080s | 1.9683 Ops/s | 2.4347 Ops/s | **-19.15%** |
| test_serialize_weights_filesystem | 0.1004s | 93.4542ms | 10.7004 Ops/s | 10.3714 Ops/s | +3.17% |
| test_serialize_model_filesystem | 0.1043s | 93.8602ms | 10.6541 Ops/s | 10.6235 Ops/s | +0.29% |
| test_reshape_pytree | 84.8040μs | 21.3457μs | 46.8478 KOps/s | 47.8056 KOps/s | -2.00% |
| test_reshape_td | 68.2670μs | 32.3192μs | 30.9414 KOps/s | 31.5558 KOps/s | -1.95% |
| test_view_pytree | 55.0930μs | 21.3041μs | 46.9392 KOps/s | 47.5234 KOps/s | -1.23% |
| test_view_td | 0.1266s | 62.9820μs | 15.8776 KOps/s | 16.3101 KOps/s | -2.65% |
| test_unbind_pytree | 80.5500μs | 25.3723μs | 39.4130 KOps/s | 40.9889 KOps/s | -3.84% |
| test_unbind_td | 77.8750μs | 36.9111μs | 27.0921 KOps/s | 26.7512 KOps/s | +1.27% |
| test_split_pytree | 72.6650μs | 24.9584μs | 40.0666 KOps/s | 41.1333 KOps/s | -2.59% |
| test_split_td | 0.1247ms | 40.1187μs | 24.9261 KOps/s | 25.0084 KOps/s | -0.33% |
| test_add_pytree | 0.1138ms | 30.7429μs | 32.5279 KOps/s | 33.6854 KOps/s | -3.44% |
| test_add_td | 0.1333ms | 53.4099μs | 18.7231 KOps/s | 19.8077 KOps/s | **-5.48%** |
| test_distributed | 0.5127ms | 0.1044ms | 9.5787 KOps/s | 9.7429 KOps/s | -1.69% |
| test_tdmodule | 31.8390μs | 16.6232μs | 60.1570 KOps/s | 58.2447 KOps/s | +3.28% |
| test_tdmodule_dispatch | 59.0100μs | 32.7666μs | 30.5189 KOps/s | 30.8101 KOps/s | -0.95% |
| test_tdseq | 36.0370μs | 19.3246μs | 51.7476 KOps/s | 50.0877 KOps/s | +3.31% |
| test_tdseq_dispatch | 62.0350μs | 37.8308μs | 26.4335 KOps/s | 26.6560 KOps/s | -0.83% |
| test_instantiation_functorch | 1.9549ms | 1.3338ms | 749.7632 Ops/s | 769.9250 Ops/s | -2.62% |
| test_instantiation_td | 1.6629ms | 1.0121ms | 988.0488 Ops/s | 1.0029 KOps/s | -1.48% |
| test_exec_functorch | 0.3136ms | 0.1601ms | 6.2473 KOps/s | 6.2821 KOps/s | -0.55% |
| test_exec_functional_call | 0.3476ms | 0.1504ms | 6.6491 KOps/s | 6.7997 KOps/s | -2.21% |
| test_exec_td | 0.2151ms | 0.1422ms | 7.0334 KOps/s | 7.0100 KOps/s | +0.33% |
| test_exec_td_decorator | 0.6334ms | 0.1982ms | 5.0458 KOps/s | 5.1505 KOps/s | -2.03% |
| test_vmap_mlp_speed[True-True] | 1.0525ms | 0.4731ms | 2.1136 KOps/s | 2.0560 KOps/s | +2.80% |
| test_vmap_mlp_speed[True-False] | 0.7592ms | 0.4718ms | 2.1196 KOps/s | 2.0933 KOps/s | +1.26% |
| test_vmap_mlp_speed[False-True] | 0.5743ms | 0.3862ms | 2.5891 KOps/s | 2.5539 KOps/s | +1.38% |
| test_vmap_mlp_speed[False-False] | 0.5102ms | 0.3848ms | 2.5990 KOps/s | 2.5211 KOps/s | +3.09% |
| test_vmap_mlp_speed_decorator[True-True] | 0.9552ms | 0.4974ms | 2.0105 KOps/s | 1.9947 KOps/s | +0.79% |
| test_vmap_mlp_speed_decorator[True-False] | 0.7116ms | 0.4938ms | 2.0250 KOps/s | 1.9984 KOps/s | +1.33% |
| test_vmap_mlp_speed_decorator[False-True] | 0.7576ms | 0.4036ms | 2.4775 KOps/s | 2.4482 KOps/s | +1.20% |
| test_vmap_mlp_speed_decorator[False-False] | 0.7875ms | 0.4062ms | 2.4620 KOps/s | 2.4337 KOps/s | +1.16% |
| test_to_module_speed[True] | 1.4920ms | 1.3807ms | 724.2453 Ops/s | 721.2933 Ops/s | +0.41% |
| test_to_module_speed[False] | 1.4800ms | 1.3697ms | 730.0856 Ops/s | 737.6185 Ops/s | -1.02% |
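
The Change values are consistent with the relative difference between the PR's throughput (Ops) and the throughput on repo HEAD. A minimal sketch, assuming that formula (the benchmark bot's own code is not shown here, and `relative_change` is a hypothetical helper); both inputs must be in the same unit:

```python
def relative_change(ops_pr: float, ops_head: float) -> float:
    """Percent change of the PR's throughput relative to repo HEAD.

    Assumed formula for illustration; positive means the PR is faster.
    """
    return (ops_pr / ops_head - 1.0) * 100.0


# Values copied from the test_items and test_plain_set_nested rows (KOps/s).
items_change = relative_change(415.1939, 371.8137)
plain_set_change = relative_change(61.8984, 61.9716)

print(f"test_items: {items_change:+.2f}%")                  # +11.67%
print(f"test_plain_set_nested: {plain_set_change:+.2f}%")   # -0.12%
```

Because the inputs in the table are already rounded, a recomputed value can differ from the reported one in the last digit for some rows.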


vmoens commented Mar 8, 2024

As planned this isn't as easy as anticipated lol
Will try to fix it during the weekend!

@maximilianigl

Absolutely no rush from my side, it's easy to work around. Just wanted to flag it.

@vmoens changed the title from "[BugFix] tensorclass.to_dict and from_dict" to "[BugFix, Feature] tensorclass.to_dict and from_dict" Mar 11, 2024
@vmoens added the enhancement label Mar 11, 2024
@vmoens merged commit d6b6a4b into main Mar 13, 2024
44 of 48 checks passed
@vmoens deleted the tensorclass-todict branch March 13, 2024 20:38
vmoens added a commit that referenced this pull request Mar 24, 2024
vmoens added a commit that referenced this pull request Mar 25, 2024
Successfully merging this pull request may close these issues.

[BUG] @tensorclass.to_dict() removes non-tensor entries.