
[BugFix] Fix lazy stack / stack_onto + masking lazy stacks #497

Merged

merged 9 commits into main from fix_lazy_stack_onto on Jul 25, 2023

Conversation

vmoens
Contributor

@vmoens vmoens commented Jul 24, 2023

No description provided.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 24, 2023
@github-actions

github-actions bot commented Jul 24, 2023

⚠️ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 109. Improved: 3. Worsened: 5.

Detailed results:

| Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
|---|---|---|---|---|---|
| test_plain_set_nested | 32.4000μs | 20.1408μs | 49.6505 KOps/s | 50.3156 KOps/s | -1.32% |
| test_plain_set_stack_nested | 0.2077ms | 0.1830ms | 5.4657 KOps/s | 5.4494 KOps/s | +0.30% |
| test_plain_set_nested_inplace | 85.2990μs | 23.3856μs | 42.7613 KOps/s | 42.6522 KOps/s | +0.26% |
| test_plain_set_stack_nested_inplace | 0.2486ms | 0.2167ms | 4.6157 KOps/s | 4.5944 KOps/s | +0.46% |
| test_items | 48.1990μs | 3.0539μs | 327.4492 KOps/s | 315.0255 KOps/s | +3.94% |
| test_items_nested | 0.4513ms | 0.3740ms | 2.6740 KOps/s | 2.7922 KOps/s | -4.23% |
| test_items_nested_locked | 0.4256ms | 0.3744ms | 2.6706 KOps/s | 2.7929 KOps/s | -4.38% |
| test_items_nested_leaf | 1.5318ms | 0.2251ms | 4.4425 KOps/s | 4.5620 KOps/s | -2.62% |
| test_items_stack_nested | 1.9870ms | 1.9026ms | 525.6037 Ops/s | 526.5199 Ops/s | -0.17% |
| test_items_stack_nested_leaf | 1.8278ms | 1.7217ms | 580.8200 Ops/s | 579.6780 Ops/s | +0.20% |
| test_items_stack_nested_locked | 1.0305ms | 0.9532ms | 1.0491 KOps/s | 1.0475 KOps/s | +0.15% |
| test_keys | 27.9990μs | 5.0537μs | 197.8745 KOps/s | 223.4656 KOps/s | **-11.45%** |
| test_keys_nested | 1.6990ms | 0.1721ms | 5.8115 KOps/s | 5.8152 KOps/s | -0.06% |
| test_keys_nested_locked | 0.2173ms | 0.1701ms | 5.8799 KOps/s | 5.8390 KOps/s | +0.70% |
| test_keys_nested_leaf | 0.2865ms | 0.1679ms | 5.9554 KOps/s | 5.5701 KOps/s | **+6.92%** |
| test_keys_stack_nested | 1.8939ms | 1.6772ms | 596.2231 Ops/s | 594.0117 Ops/s | +0.37% |
| test_keys_stack_nested_leaf | 1.7856ms | 1.6836ms | 593.9517 Ops/s | 590.3709 Ops/s | +0.61% |
| test_keys_stack_nested_locked | 0.8132ms | 0.7203ms | 1.3883 KOps/s | 1.3744 KOps/s | +1.01% |
| test_values | 25.1000μs | 1.3475μs | 742.1046 KOps/s | 839.1745 KOps/s | **-11.57%** |
| test_values_nested | 0.1157ms | 65.0700μs | 15.3681 KOps/s | 15.2546 KOps/s | +0.74% |
| test_values_nested_locked | 0.1214ms | 64.7235μs | 15.4503 KOps/s | 15.2536 KOps/s | +1.29% |
| test_values_nested_leaf | 0.1055ms | 56.6414μs | 17.6549 KOps/s | 17.2796 KOps/s | +2.17% |
| test_values_stack_nested | 1.5679ms | 1.5161ms | 659.5697 Ops/s | 651.0261 Ops/s | +1.31% |
| test_values_stack_nested_leaf | 1.6009ms | 1.5115ms | 661.6056 Ops/s | 655.8951 Ops/s | +0.87% |
| test_values_stack_nested_locked | 0.7022ms | 0.6273ms | 1.5941 KOps/s | 1.5854 KOps/s | +0.55% |
| test_membership | 15.1000μs | 1.8199μs | 549.4862 KOps/s | 553.5809 KOps/s | -0.74% |
| test_membership_nested | 18.5990μs | 3.6526μs | 273.7801 KOps/s | 279.6026 KOps/s | -2.08% |
| test_membership_nested_leaf | 53.2000μs | 3.6392μs | 274.7868 KOps/s | 275.7908 KOps/s | -0.36% |
| test_membership_stacked_nested | 70.2990μs | 14.4368μs | 69.2674 KOps/s | 69.0185 KOps/s | +0.36% |
| test_membership_stacked_nested_leaf | 38.1990μs | 14.5344μs | 68.8022 KOps/s | 68.6611 KOps/s | +0.21% |
| test_membership_nested_last | 55.6990μs | 7.4891μs | 133.5279 KOps/s | 135.2427 KOps/s | -1.27% |
| test_membership_nested_leaf_last | 29.7000μs | 7.5289μs | 132.8222 KOps/s | 132.5553 KOps/s | +0.20% |
| test_membership_stacked_nested_last | 0.2738ms | 0.2213ms | 4.5187 KOps/s | 4.5136 KOps/s | +0.11% |
| test_membership_stacked_nested_leaf_last | 62.6990μs | 16.9410μs | 59.0284 KOps/s | 58.7053 KOps/s | +0.55% |
| test_nested_getleaf | 63.0990μs | 15.3945μs | 64.9584 KOps/s | 65.3763 KOps/s | -0.64% |
| test_nested_get | 65.7990μs | 14.5028μs | 68.9520 KOps/s | 68.6920 KOps/s | +0.38% |
| test_stacked_getleaf | 0.9323ms | 0.8207ms | 1.2184 KOps/s | 1.2085 KOps/s | +0.82% |
| test_stacked_get | 0.8572ms | 0.7877ms | 1.2694 KOps/s | 1.2666 KOps/s | +0.22% |
| test_nested_getitemleaf | 59.8990μs | 15.3654μs | 65.0812 KOps/s | 64.6364 KOps/s | +0.69% |
| test_nested_getitem | 37.0000μs | 14.4886μs | 69.0200 KOps/s | 68.6508 KOps/s | +0.54% |
| test_stacked_getitemleaf | 0.9304ms | 0.8181ms | 1.2223 KOps/s | 1.2099 KOps/s | +1.03% |
| test_stacked_getitem | 0.8624ms | 0.7881ms | 1.2689 KOps/s | 1.2595 KOps/s | +0.75% |
| test_lock_nested | 51.6790ms | 1.3901ms | 719.3869 Ops/s | 750.7129 Ops/s | -4.17% |
| test_lock_stack_nested | 72.2050ms | 16.5421ms | 60.4519 Ops/s | 60.7748 Ops/s | -0.53% |
| test_unlock_nested | 53.1124ms | 1.3983ms | 715.1684 Ops/s | 720.2584 Ops/s | -0.71% |
| test_unlock_stack_nested | 74.1677ms | 16.9112ms | 59.1323 Ops/s | 59.3689 Ops/s | -0.40% |
| test_flatten_speed | 1.0609ms | 0.9919ms | 1.0082 KOps/s | 1.0253 KOps/s | -1.67% |
| test_unflatten_speed | 1.8294ms | 1.7323ms | 577.2562 Ops/s | 580.9737 Ops/s | -0.64% |
| test_common_ops | 1.1392ms | 1.0114ms | 988.7715 Ops/s | 991.4985 Ops/s | -0.28% |
| test_creation | 28.0000μs | 6.0237μs | 166.0118 KOps/s | 165.0383 KOps/s | +0.59% |
| test_creation_empty | 37.1000μs | 13.2044μs | 75.7321 KOps/s | 75.4620 KOps/s | +0.36% |
| test_creation_nested_1 | 76.0990μs | 22.7194μs | 44.0152 KOps/s | 44.0789 KOps/s | -0.14% |
| test_creation_nested_2 | 41.7990μs | 25.5154μs | 39.1920 KOps/s | 39.1718 KOps/s | +0.05% |
| test_clone | 84.0990μs | 24.7957μs | 40.3295 KOps/s | 41.0557 KOps/s | -1.77% |
| test_getitem[int] | 83.2990μs | 26.6535μs | 37.5185 KOps/s | 38.5269 KOps/s | -2.62% |
| test_getitem[slice_int] | 96.6990μs | 49.5494μs | 20.1819 KOps/s | 20.6035 KOps/s | -2.05% |
| test_getitem[range] | 0.1159ms | 76.9438μs | 12.9965 KOps/s | 13.3560 KOps/s | -2.69% |
| test_getitem[tuple] | 0.1054ms | 41.0445μs | 24.3638 KOps/s | 24.7117 KOps/s | -1.41% |
| test_getitem[list] | 0.3828ms | 71.9009μs | 13.9080 KOps/s | 14.1749 KOps/s | -1.88% |
| test_setitem_dim[int] | 52.5990μs | 31.6770μs | 31.5687 KOps/s | 31.8685 KOps/s | -0.94% |
| test_setitem_dim[slice_int] | 87.9980μs | 55.8484μs | 17.9056 KOps/s | 18.0450 KOps/s | -0.77% |
| test_setitem_dim[range] | 0.1004ms | 76.5881μs | 13.0569 KOps/s | 13.2589 KOps/s | -1.52% |
| test_setitem_dim[tuple] | 68.7990μs | 47.1672μs | 21.2012 KOps/s | 21.2904 KOps/s | -0.42% |
| test_setitem | 0.1176ms | 29.9903μs | 33.3442 KOps/s | 33.6026 KOps/s | -0.77% |
| test_set | 0.1084ms | 29.3052μs | 34.1237 KOps/s | 34.4824 KOps/s | -1.04% |
| test_set_shared | 0.3020ms | 0.1526ms | 6.5549 KOps/s | 6.5763 KOps/s | -0.33% |
| test_update | 0.1446ms | 32.3861μs | 30.8774 KOps/s | 31.3479 KOps/s | -1.50% |
| test_update_nested | 0.1442ms | 49.7536μs | 20.0991 KOps/s | 20.4879 KOps/s | -1.90% |
| test_set_nested | 79.2000μs | 31.1432μs | 32.1097 KOps/s | 32.2282 KOps/s | -0.37% |
| test_set_nested_new | 0.1433ms | 50.9031μs | 19.6452 KOps/s | 20.1844 KOps/s | -2.67% |
| test_select | 0.2243ms | 96.1792μs | 10.3973 KOps/s | 10.5643 KOps/s | -1.58% |
| test_unbind_speed | 0.7074ms | 0.6386ms | 1.5659 KOps/s | 1.5612 KOps/s | +0.30% |
| test_unbind_speed_stack0 | 66.2258ms | 8.1277ms | 123.0353 Ops/s | 339.5416 Ops/s | **-63.76%** |
| test_unbind_speed_stack1 | 8.2833μs | 0.9402μs | 1.0636 MOps/s | 2.1610 MOps/s | **-50.78%** |
| test_creation[device0] | 0.4710ms | 0.3315ms | 3.0163 KOps/s | 3.0701 KOps/s | -1.75% |
| test_creation_from_tensor | 0.4796ms | 0.3731ms | 2.6803 KOps/s | 2.7153 KOps/s | -1.29% |
| test_add_one[memmap_tensor0] | 1.2798ms | 29.9717μs | 33.3648 KOps/s | 32.6919 KOps/s | +2.06% |
| test_contiguous[memmap_tensor0] | 55.4000μs | 8.3159μs | 120.2522 KOps/s | 121.5111 KOps/s | -1.04% |
| test_stack[memmap_tensor0] | 53.0000μs | 24.4782μs | 40.8526 KOps/s | 40.2633 KOps/s | +1.46% |
| test_memmaptd_index | 0.3547ms | 0.2978ms | 3.3581 KOps/s | 3.3859 KOps/s | -0.82% |
| test_memmaptd_index_astensor | 1.2826ms | 1.2199ms | 819.7404 Ops/s | 822.2220 Ops/s | -0.30% |
| test_memmaptd_index_op | 2.6091ms | 2.3395ms | 427.4371 Ops/s | 421.2483 Ops/s | +1.47% |
| test_reshape_pytree | 96.5990μs | 36.7167μs | 27.2356 KOps/s | 27.5264 KOps/s | -1.06% |
| test_reshape_td | 60.9990μs | 43.6560μs | 22.9064 KOps/s | 23.1049 KOps/s | -0.86% |
| test_view_pytree | 0.1133ms | 34.0290μs | 29.3867 KOps/s | 29.5121 KOps/s | -0.42% |
| test_view_td | 30.0000μs | 8.6278μs | 115.9040 KOps/s | 116.9530 KOps/s | -0.90% |
| test_unbind_pytree | 70.1000μs | 38.9481μs | 25.6752 KOps/s | 26.3324 KOps/s | -2.50% |
| test_unbind_td | 0.1658ms | 94.1183μs | 10.6249 KOps/s | 10.6958 KOps/s | -0.66% |
| test_split_pytree | 77.2990μs | 42.7426μs | 23.3959 KOps/s | 23.3647 KOps/s | +0.13% |
| test_split_td | 0.7623ms | 0.1113ms | 8.9833 KOps/s | 8.8423 KOps/s | +1.59% |
| test_add_pytree | 79.7000μs | 45.3360μs | 22.0575 KOps/s | 21.8107 KOps/s | +1.13% |
| test_add_td | 0.1583ms | 70.1342μs | 14.2584 KOps/s | 14.5557 KOps/s | -2.04% |
| test_distributed | 29.6990μs | 8.2062μs | 121.8598 KOps/s | 106.8608 KOps/s | **+14.04%** |
| test_tdmodule | 0.1860ms | 25.9697μs | 38.5063 KOps/s | 38.4651 KOps/s | +0.11% |
| test_tdmodule_dispatch | 0.2378ms | 50.5921μs | 19.7659 KOps/s | 9.7424 KOps/s | **+102.89%** |
| test_tdseq | 0.5065ms | 27.4398μs | 36.4434 KOps/s | 36.3817 KOps/s | +0.17% |
| test_tdseq_dispatch | 0.4797ms | 54.6829μs | 18.2872 KOps/s | 18.5082 KOps/s | -1.19% |
| test_instantiation_functorch | 1.7477ms | 1.5098ms | 662.3235 Ops/s | 652.5992 Ops/s | +1.49% |
| test_instantiation_td | 1.8760ms | 1.2465ms | 802.2517 Ops/s | 791.1354 Ops/s | +1.41% |
| test_exec_functorch | 0.2467ms | 0.1763ms | 5.6729 KOps/s | 5.5953 KOps/s | +1.39% |
| test_exec_td | 0.2087ms | 0.1674ms | 5.9754 KOps/s | 5.9312 KOps/s | +0.74% |
| test_vmap_mlp_speed[True-True] | 1.6863ms | 1.0229ms | 977.6408 Ops/s | 972.8381 Ops/s | +0.49% |
| test_vmap_mlp_speed[True-False] | 0.7302ms | 0.5044ms | 1.9826 KOps/s | 1.9962 KOps/s | -0.68% |
| test_vmap_mlp_speed[False-True] | 1.2040ms | 0.8733ms | 1.1450 KOps/s | 1.1559 KOps/s | -0.94% |
| test_vmap_mlp_speed[False-False] | 1.8131ms | 0.3977ms | 2.5142 KOps/s | 2.5596 KOps/s | -1.78% |
| test_vmap_transformer_speed[True-True] | 13.6153ms | 12.1010ms | 82.6379 Ops/s | 81.7625 Ops/s | +1.07% |
| test_vmap_transformer_speed[True-False] | 67.7795ms | 8.0917ms | 123.5841 Ops/s | 131.4784 Ops/s | **-6.00%** |
| test_vmap_transformer_speed[False-True] | 12.0132ms | 11.8384ms | 84.4712 Ops/s | 83.8651 Ops/s | +0.72% |
| test_vmap_transformer_speed[False-False] | 9.5686ms | 7.4378ms | 134.4477 Ops/s | 134.6608 Ops/s | -0.16% |

@matteobettini
Contributor

Here is a test for it

import pytest
import torch
from typing import List

from tensordict import TensorDict, TensorDictBase


class TestNestedLazyStacks:
    def get_agent_tensors(
        self,
        i,
    ):
        camera = torch.zeros(32, 32, 3)
        vector_3d = torch.zeros(3)
        vector_2d = torch.zeros(2)
        lidar = torch.zeros(20)

        agent_0_obs = torch.zeros(1)
        agent_1_obs = torch.zeros(1, 2)
        agent_2_obs = torch.zeros(1, 2, 3)

        # Agents all have the same camera
        # All have vector entry but different shapes
        # First 2 have lidar and last sonar
        # All have a different key agent_i_obs with different n_dims
        if i == 0:
            return TensorDict(
                {
                    "camera": camera,
                    "lidar": lidar,
                    "vector": vector_3d,
                    "agent_0_obs": agent_0_obs,
                },
                [],
            )
        elif i == 1:
            return TensorDict(
                {
                    "camera": camera,
                    "lidar": lidar,
                    "vector": vector_2d,
                    "agent_1_obs": agent_1_obs,
                },
                [],
            )
        elif i == 2:
            return TensorDict(
                {
                    "camera": camera,
                    "vector": vector_2d,
                    "agent_2_obs": agent_2_obs,
                },
                [],
            )
        else:
            raise ValueError(f"Index {i} undefined for 3 agents")

    def get_lazy_stack(self, batch_size):
        agent_obs = []
        for agent_id in range(3):
            agent_obs.append(self.get_agent_tensors(agent_id))
        agent_obs = torch.stack(agent_obs, dim=0)
        obs = TensorDict(
            {
                "agents": agent_obs,
                "state": torch.zeros(
                    64,
                    64,
                    3,
                ),
            },
            [],
        )
        obs = obs.expand(batch_size)
        return obs

 
    def dense_stack_tds_v2(
        self, td_list: List[TensorDictBase], stack_dim: int
    ) -> TensorDictBase:
        shape = list(td_list[0].shape)
        shape.insert(stack_dim, len(td_list))

        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()

        return torch.stack(td_list, dim=stack_dim, out=out)

    @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)])
    def test(self, batch_size):
        obs = self.get_lazy_stack(batch_size)
        for stack_dim in range(len(batch_size) + 1):
            res = self.dense_stack_tds_v2(
                [obs, obs], stack_dim=stack_dim
            )  

@matteobettini
Contributor

If you replace dense_stack_tds_v2 with

def dense_stack_tds_v1(td_list: List[TensorDictBase], stack_dim: int) -> TensorDictBase:
    shape = list(td_list[0].shape)
    shape.insert(stack_dim, len(td_list))

    out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
    for i in range(1, len(td_list)):
        index = (slice(None),) * stack_dim + (i,)  # this is index_select
        out[index] = td_list[i]

    return out

you can also see that the heterogeneous shape bug also comes up in __setitem__
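For context, the index built in dense_stack_tds_v1, `(slice(None),) * stack_dim + (i,)`, selects element i along dimension stack_dim. A minimal pure-Python analogue (nested lists standing in for tensors; purely illustrative, not tensordict code):

```python
# Pure-Python analogue of x[(slice(None),) * stack_dim + (i,)]:
# keep every element along the first `stack_dim` dimensions, then pick
# index i. Nested lists stand in for tensors here (illustrative only).
def select_along_dim(nested, stack_dim, i):
    if stack_dim == 0:
        return nested[i]
    return [select_along_dim(sub, stack_dim - 1, i) for sub in nested]

data = [[10, 11], [20, 21]]          # "shape" (2, 2)
print(select_along_dim(data, 0, 1))  # [20, 21] — row 1, i.e. data[1]
print(select_along_dim(data, 1, 0))  # [10, 20] — column 0, i.e. data[:, 0]
```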

@vmoens
Contributor Author

vmoens commented Jul 24, 2023

On it. I think I've found where the problem was.
Nits for the tests:

  • we should avoid domain-specific terms
  • we should avoid looping over use cases in the test
  • we should avoid batch sizes of 32 or similar; they quickly pile up into large tensors, which is unnecessary

@matteobettini
Contributor

matteobettini commented Jul 24, 2023

A few things that I found when trying to understand this:

@vmoens
Contributor Author

vmoens commented Jul 24, 2023

A few things that I found when trying to understand this:

This isn't related to stack(..., out=smth), right?

@matteobettini
Contributor

nope, general lazystack bugs

@matteobettini
Contributor

but they might affect part of this downstream

@vmoens
Contributor Author

vmoens commented Jul 24, 2023

I don't think so (though I would not rule it out completely).
Usually, after you get through an entry point like __setitem__ or similar, you will only pass from private method to private method.

@matteobettini
Contributor

matteobettini commented Jul 24, 2023

I don't think so (though I would not rule it out completely).
Usually, after you get through an entry point like __setitem__ or similar, you will only pass from private method to private method.

I think setitem is used here https://github.com/pytorch-labs/tensordict/blob/main/tensordict/utils.py#L825 , called from here https://github.com/pytorch-labs/tensordict/blob/main/tensordict/tensordict.py#L4142 on nested lazy stacks (that function is receiving LazyStacks as input even though the type hint suggests it should not).

Here is a set of tests that I think will cover some of the interesting cases:

import pytest
import torch
from typing import List

from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase


class TestNestedLazyStacks:
    @staticmethod
    def nested_lazy_het_td(batch_size):
        shared = torch.zeros(4, 4, 2)
        hetero_3d = torch.zeros(3)
        hetero_2d = torch.zeros(2)

        individual_0_tensor = torch.zeros(1)
        individual_1_tensor = torch.zeros(1, 2)
        individual_2_tensor = torch.zeros(1, 2, 3)

       
        td_list = [
            TensorDict(
                {
                    "shared": shared,
                    "hetero": hetero_3d,
                    "individual_0_tensor": individual_0_tensor,
                },
                [],
            ),
            TensorDict(
                {
                    "shared": shared,
                    "hetero": hetero_3d,
                    "individual_1_tensor": individual_1_tensor,
                },
                [],
            ),
            TensorDict(
                {
                    "shared": shared,
                    "hetero": hetero_2d,
                    "individual_2_tensor": individual_2_tensor,
                },
                [],
            ),
        ]
        for i, td in enumerate(td_list):
            td[f"individual_{i}_td"] = td.clone()
            td["shared_td"] = td.clone()

        td_stack = torch.stack(td_list, dim=0)
        obs = TensorDict(
            {"lazy": td_stack, "dense": torch.zeros(3, 3, 2)},
            [],
        )
        obs = obs.expand(batch_size)
        return obs

    def recursively_check_key(self, td, value: int):
        if isinstance(td, LazyStackedTensorDict):
            for t in td.tensordicts:
                if not self.recursively_check_key(t, value):
                    return False
        elif isinstance(td, TensorDict):
            for i in td.values():
                if not self.recursively_check_key(i, value):
                    return False
        elif isinstance(td, torch.Tensor):
            return (td == value).all()
        else:
            return False

        return True

    def dense_stack_tds_v1(
        self, td_list: List[TensorDictBase], stack_dim: int
    ) -> TensorDictBase:
        shape = list(td_list[0].shape)
        shape.insert(stack_dim, len(td_list))

        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
        for i in range(1, len(td_list)):
            index = (slice(None),) * stack_dim + (i,)  # this is index_select
            out[index] = td_list[i]

        return out

    def dense_stack_tds_v2(
        self, td_list: List[TensorDictBase], stack_dim: int
    ) -> TensorDictBase:
        shape = list(td_list[0].shape)
        shape.insert(stack_dim, len(td_list))

        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()

        return torch.stack(td_list, dim=stack_dim, out=out)

    @pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
    @pytest.mark.parametrize("stack_dim", [0, 1, 2])
    def test(self, batch_size, stack_dim):
        obs = self.nested_lazy_het_td(batch_size)
        obs1 = obs.clone()
        obs1.apply_(lambda x: x + 1)

        if stack_dim > len(batch_size):
            return

        res1 = self.dense_stack_tds_v1([obs, obs1], stack_dim=stack_dim)
        res2 = self.dense_stack_tds_v2([obs, obs1], stack_dim=stack_dim)

        index = (slice(None),) * stack_dim + (0,)  # get the first in the stack
        assert self.recursively_check_key(res1[index], 0)  # check all 0
        assert self.recursively_check_key(res2[index], 0)  # check all 0
        index = (slice(None),) * stack_dim + (1,)  # get the second in the stack
        assert self.recursively_check_key(res1[index], 1)  # check all 1
        assert self.recursively_check_key(res2[index], 1)  # check all 1

  • To test just the feature in this PR you can drop the res1 lines.
  • The res2 lines will test setitem.

Nit: the case of lazy stacks of lazy stacks is not tested; I only test lazy stacks with keys that are lazy stacks.
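To make the distinction in that nit concrete, here is a pure-Python model (dicts standing in for TensorDicts, lists for lazy stacks; purely illustrative, not the tensordict API):

```python
# Lists stand in for lazy stacks, dicts for TensorDicts (illustration only).
# Case A: a TensorDict whose *value* is a lazy stack — this is what the tests
# above build. Case B: a stack whose *elements* are themselves stacks — the
# untested "lazy lazy stack" shape.
stack = [{"x": 1}, {"x": 1, "extra": 2}]   # heterogeneous lazy stack
case_a = {"lazy": stack, "dense": 0}       # tested: stack stored under a key
case_b = [stack, stack]                    # untested: stack of stacks

def depth_of_stacks(obj):
    """Count how many stack (list) levels wrap the first leaf reached."""
    if isinstance(obj, list):
        return 1 + depth_of_stacks(obj[0])
    if isinstance(obj, dict):
        return depth_of_stacks(next(iter(obj.values())))
    return 0

print(depth_of_stacks(case_a))  # 1 — a single lazy level, under a key
print(depth_of_stacks(case_b))  # 2 — two lazy levels stacked directly
```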

@vmoens
Contributor Author

vmoens commented Jul 24, 2023

Tests are passing now, and I'll fix this too.

Contributor

@matteobettini matteobettini left a comment


LGTM

Can you add this to the tests?

    def all_eq(
        self,
        td: Union[TensorDictBase, torch.Tensor],
        other: Union[TensorDictBase, torch.Tensor],
    ):
        if td.__class__ != other.__class__:
            return False

        if td.shape != other.shape or td.device != other.device:
            return False

        if isinstance(td, LazyStackedTensorDict):
            if td.stack_dim != other.stack_dim:
                return False
            for t, o in zip(td.tensordicts, other.tensordicts):
                if not self.all_eq(t, o):
                    return False
        elif isinstance(td, TensorDictBase):
            td_keys = list(td.keys())
            other_keys = list(other.keys())
            if td_keys != other_keys:
                return False
            for k in td_keys:
                if not self.all_eq(td[k], other[k]):
                    return False
        elif isinstance(td, torch.Tensor):
            return torch.equal(td, other)
        else:
            raise AssertionError()

        return True

    @pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
    @pytest.mark.parametrize("stack_dim", [0, 1, 2])
    def test_setitem_hetero(self, batch_size, stack_dim):
        obs = self.nested_lazy_het_td(batch_size)
        obs1 = obs.clone()
        obs1.apply_(lambda x: x + 1)

        if stack_dim > len(batch_size):
            return

        res1 = self.dense_stack_tds_v1([obs, obs1], stack_dim=stack_dim)
        res2 = self.dense_stack_tds_v2([obs, obs1], stack_dim=stack_dim)

        index = (slice(None),) * stack_dim + (0,)  # get the first in the stack
        assert self.recursively_check_key(res1[index], 0)  # check all 0
        assert self.recursively_check_key(res2[index], 0)  # check all 0
        index = (slice(None),) * stack_dim + (1,)  # get the second in the stack
        assert self.recursively_check_key(res1[index], 1)  # check all 1
        assert self.recursively_check_key(res2[index], 1)  # check all 1

        assert self.all_eq(res1, res2)

It is an extra consistency check.

The second function is the one we already have, with an extra line.

@vmoens
Contributor Author

vmoens commented Jul 24, 2023

I'd rather not add that all_eq method; what does it do that a simple eq does not?

The reason is that if we add a helper for each new test, we'll end up with tests that are hard to modify, or with lots of legacy code when we remove a test, for example.

@matteobettini
Contributor

The reason is that if we add a helper for each new test, we'll end up with tests that are hard to modify, or with lots of legacy code when we remove a test, for example.

Oh yes, of course. A normal equal returns some lazy keys in case there are any, though, right? How do I check that all lazy keys match?
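For illustration, the recursive walk being asked about can be sketched in pure Python (dicts standing in for TensorDicts, lists for lazy stacks; not the tensordict API): a plain equality stops at the first lazy level, so checking that every lazy leaf matches means walking both structures in lockstep.

```python
# Illustrative sketch only: dicts model TensorDicts, lists model lazy stacks,
# plain values model leaf tensors. Compare structure and leaves recursively.
def deep_match(a, b):
    if type(a) is not type(b):
        return False
    if isinstance(a, dict):   # TensorDict-like: same keys, matching values
        return a.keys() == b.keys() and all(deep_match(a[k], b[k]) for k in a)
    if isinstance(a, list):   # lazy-stack-like: same length, entry-wise match
        return len(a) == len(b) and all(deep_match(x, y) for x, y in zip(a, b))
    return a == b             # leaf stand-in

x = {"lazy": [{"a": 1}, {"a": 1, "b": 2}], "dense": 3}
y = {"lazy": [{"a": 1}, {"a": 1, "b": 2}], "dense": 3}
print(deep_match(x, y))  # True — every lazy leaf matches
```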

# Conflicts:
#	tensordict/tensordict.py
#	test/test_tensordict.py
@vmoens vmoens marked this pull request as ready for review July 24, 2023 17:10
Contributor

@matteobettini matteobettini left a comment


LGTM

@vmoens vmoens merged commit 150b510 into main Jul 25, 2023
@vmoens vmoens changed the title [WIP] Fix lazy stack / stack_onto [BugFix] Fix lazy stack / stack_onto + masking lazy stacks Jul 25, 2023
@vmoens vmoens deleted the fix_lazy_stack_onto branch July 25, 2023 19:41