[BugFix] Fix lazy stack / stack_onto + masking lazy stacks #497
Conversation
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
---|---|---|---|---|---
test_plain_set_nested | 32.4000μs | 20.1408μs | 49.6505 KOps/s | 50.3156 KOps/s | |
test_plain_set_stack_nested | 0.2077ms | 0.1830ms | 5.4657 KOps/s | 5.4494 KOps/s | |
test_plain_set_nested_inplace | 85.2990μs | 23.3856μs | 42.7613 KOps/s | 42.6522 KOps/s | |
test_plain_set_stack_nested_inplace | 0.2486ms | 0.2167ms | 4.6157 KOps/s | 4.5944 KOps/s | |
test_items | 48.1990μs | 3.0539μs | 327.4492 KOps/s | 315.0255 KOps/s | |
test_items_nested | 0.4513ms | 0.3740ms | 2.6740 KOps/s | 2.7922 KOps/s | |
test_items_nested_locked | 0.4256ms | 0.3744ms | 2.6706 KOps/s | 2.7929 KOps/s | |
test_items_nested_leaf | 1.5318ms | 0.2251ms | 4.4425 KOps/s | 4.5620 KOps/s | |
test_items_stack_nested | 1.9870ms | 1.9026ms | 525.6037 Ops/s | 526.5199 Ops/s | |
test_items_stack_nested_leaf | 1.8278ms | 1.7217ms | 580.8200 Ops/s | 579.6780 Ops/s | |
test_items_stack_nested_locked | 1.0305ms | 0.9532ms | 1.0491 KOps/s | 1.0475 KOps/s | |
test_keys | 27.9990μs | 5.0537μs | 197.8745 KOps/s | 223.4656 KOps/s | |
test_keys_nested | 1.6990ms | 0.1721ms | 5.8115 KOps/s | 5.8152 KOps/s | |
test_keys_nested_locked | 0.2173ms | 0.1701ms | 5.8799 KOps/s | 5.8390 KOps/s | |
test_keys_nested_leaf | 0.2865ms | 0.1679ms | 5.9554 KOps/s | 5.5701 KOps/s | |
test_keys_stack_nested | 1.8939ms | 1.6772ms | 596.2231 Ops/s | 594.0117 Ops/s | |
test_keys_stack_nested_leaf | 1.7856ms | 1.6836ms | 593.9517 Ops/s | 590.3709 Ops/s | |
test_keys_stack_nested_locked | 0.8132ms | 0.7203ms | 1.3883 KOps/s | 1.3744 KOps/s | |
test_values | 25.1000μs | 1.3475μs | 742.1046 KOps/s | 839.1745 KOps/s | |
test_values_nested | 0.1157ms | 65.0700μs | 15.3681 KOps/s | 15.2546 KOps/s | |
test_values_nested_locked | 0.1214ms | 64.7235μs | 15.4503 KOps/s | 15.2536 KOps/s | |
test_values_nested_leaf | 0.1055ms | 56.6414μs | 17.6549 KOps/s | 17.2796 KOps/s | |
test_values_stack_nested | 1.5679ms | 1.5161ms | 659.5697 Ops/s | 651.0261 Ops/s | |
test_values_stack_nested_leaf | 1.6009ms | 1.5115ms | 661.6056 Ops/s | 655.8951 Ops/s | |
test_values_stack_nested_locked | 0.7022ms | 0.6273ms | 1.5941 KOps/s | 1.5854 KOps/s | |
test_membership | 15.1000μs | 1.8199μs | 549.4862 KOps/s | 553.5809 KOps/s | |
test_membership_nested | 18.5990μs | 3.6526μs | 273.7801 KOps/s | 279.6026 KOps/s | |
test_membership_nested_leaf | 53.2000μs | 3.6392μs | 274.7868 KOps/s | 275.7908 KOps/s | |
test_membership_stacked_nested | 70.2990μs | 14.4368μs | 69.2674 KOps/s | 69.0185 KOps/s | |
test_membership_stacked_nested_leaf | 38.1990μs | 14.5344μs | 68.8022 KOps/s | 68.6611 KOps/s | |
test_membership_nested_last | 55.6990μs | 7.4891μs | 133.5279 KOps/s | 135.2427 KOps/s | |
test_membership_nested_leaf_last | 29.7000μs | 7.5289μs | 132.8222 KOps/s | 132.5553 KOps/s | |
test_membership_stacked_nested_last | 0.2738ms | 0.2213ms | 4.5187 KOps/s | 4.5136 KOps/s | |
test_membership_stacked_nested_leaf_last | 62.6990μs | 16.9410μs | 59.0284 KOps/s | 58.7053 KOps/s | |
test_nested_getleaf | 63.0990μs | 15.3945μs | 64.9584 KOps/s | 65.3763 KOps/s | |
test_nested_get | 65.7990μs | 14.5028μs | 68.9520 KOps/s | 68.6920 KOps/s | |
test_stacked_getleaf | 0.9323ms | 0.8207ms | 1.2184 KOps/s | 1.2085 KOps/s | |
test_stacked_get | 0.8572ms | 0.7877ms | 1.2694 KOps/s | 1.2666 KOps/s | |
test_nested_getitemleaf | 59.8990μs | 15.3654μs | 65.0812 KOps/s | 64.6364 KOps/s | |
test_nested_getitem | 37.0000μs | 14.4886μs | 69.0200 KOps/s | 68.6508 KOps/s | |
test_stacked_getitemleaf | 0.9304ms | 0.8181ms | 1.2223 KOps/s | 1.2099 KOps/s | |
test_stacked_getitem | 0.8624ms | 0.7881ms | 1.2689 KOps/s | 1.2595 KOps/s | |
test_lock_nested | 51.6790ms | 1.3901ms | 719.3869 Ops/s | 750.7129 Ops/s | |
test_lock_stack_nested | 72.2050ms | 16.5421ms | 60.4519 Ops/s | 60.7748 Ops/s | |
test_unlock_nested | 53.1124ms | 1.3983ms | 715.1684 Ops/s | 720.2584 Ops/s | |
test_unlock_stack_nested | 74.1677ms | 16.9112ms | 59.1323 Ops/s | 59.3689 Ops/s | |
test_flatten_speed | 1.0609ms | 0.9919ms | 1.0082 KOps/s | 1.0253 KOps/s | |
test_unflatten_speed | 1.8294ms | 1.7323ms | 577.2562 Ops/s | 580.9737 Ops/s | |
test_common_ops | 1.1392ms | 1.0114ms | 988.7715 Ops/s | 991.4985 Ops/s | |
test_creation | 28.0000μs | 6.0237μs | 166.0118 KOps/s | 165.0383 KOps/s | |
test_creation_empty | 37.1000μs | 13.2044μs | 75.7321 KOps/s | 75.4620 KOps/s | |
test_creation_nested_1 | 76.0990μs | 22.7194μs | 44.0152 KOps/s | 44.0789 KOps/s | |
test_creation_nested_2 | 41.7990μs | 25.5154μs | 39.1920 KOps/s | 39.1718 KOps/s | |
test_clone | 84.0990μs | 24.7957μs | 40.3295 KOps/s | 41.0557 KOps/s | |
test_getitem[int] | 83.2990μs | 26.6535μs | 37.5185 KOps/s | 38.5269 KOps/s | |
test_getitem[slice_int] | 96.6990μs | 49.5494μs | 20.1819 KOps/s | 20.6035 KOps/s | |
test_getitem[range] | 0.1159ms | 76.9438μs | 12.9965 KOps/s | 13.3560 KOps/s | |
test_getitem[tuple] | 0.1054ms | 41.0445μs | 24.3638 KOps/s | 24.7117 KOps/s | |
test_getitem[list] | 0.3828ms | 71.9009μs | 13.9080 KOps/s | 14.1749 KOps/s | |
test_setitem_dim[int] | 52.5990μs | 31.6770μs | 31.5687 KOps/s | 31.8685 KOps/s | |
test_setitem_dim[slice_int] | 87.9980μs | 55.8484μs | 17.9056 KOps/s | 18.0450 KOps/s | |
test_setitem_dim[range] | 0.1004ms | 76.5881μs | 13.0569 KOps/s | 13.2589 KOps/s | |
test_setitem_dim[tuple] | 68.7990μs | 47.1672μs | 21.2012 KOps/s | 21.2904 KOps/s | |
test_setitem | 0.1176ms | 29.9903μs | 33.3442 KOps/s | 33.6026 KOps/s | |
test_set | 0.1084ms | 29.3052μs | 34.1237 KOps/s | 34.4824 KOps/s | |
test_set_shared | 0.3020ms | 0.1526ms | 6.5549 KOps/s | 6.5763 KOps/s | |
test_update | 0.1446ms | 32.3861μs | 30.8774 KOps/s | 31.3479 KOps/s | |
test_update_nested | 0.1442ms | 49.7536μs | 20.0991 KOps/s | 20.4879 KOps/s | |
test_set_nested | 79.2000μs | 31.1432μs | 32.1097 KOps/s | 32.2282 KOps/s | |
test_set_nested_new | 0.1433ms | 50.9031μs | 19.6452 KOps/s | 20.1844 KOps/s | |
test_select | 0.2243ms | 96.1792μs | 10.3973 KOps/s | 10.5643 KOps/s | |
test_unbind_speed | 0.7074ms | 0.6386ms | 1.5659 KOps/s | 1.5612 KOps/s | |
test_unbind_speed_stack0 | 66.2258ms | 8.1277ms | 123.0353 Ops/s | 339.5416 Ops/s | |
test_unbind_speed_stack1 | 8.2833μs | 0.9402μs | 1.0636 MOps/s | 2.1610 MOps/s | |
test_creation[device0] | 0.4710ms | 0.3315ms | 3.0163 KOps/s | 3.0701 KOps/s | |
test_creation_from_tensor | 0.4796ms | 0.3731ms | 2.6803 KOps/s | 2.7153 KOps/s | |
test_add_one[memmap_tensor0] | 1.2798ms | 29.9717μs | 33.3648 KOps/s | 32.6919 KOps/s | |
test_contiguous[memmap_tensor0] | 55.4000μs | 8.3159μs | 120.2522 KOps/s | 121.5111 KOps/s | |
test_stack[memmap_tensor0] | 53.0000μs | 24.4782μs | 40.8526 KOps/s | 40.2633 KOps/s | |
test_memmaptd_index | 0.3547ms | 0.2978ms | 3.3581 KOps/s | 3.3859 KOps/s | |
test_memmaptd_index_astensor | 1.2826ms | 1.2199ms | 819.7404 Ops/s | 822.2220 Ops/s | |
test_memmaptd_index_op | 2.6091ms | 2.3395ms | 427.4371 Ops/s | 421.2483 Ops/s | |
test_reshape_pytree | 96.5990μs | 36.7167μs | 27.2356 KOps/s | 27.5264 KOps/s | |
test_reshape_td | 60.9990μs | 43.6560μs | 22.9064 KOps/s | 23.1049 KOps/s | |
test_view_pytree | 0.1133ms | 34.0290μs | 29.3867 KOps/s | 29.5121 KOps/s | |
test_view_td | 30.0000μs | 8.6278μs | 115.9040 KOps/s | 116.9530 KOps/s | |
test_unbind_pytree | 70.1000μs | 38.9481μs | 25.6752 KOps/s | 26.3324 KOps/s | |
test_unbind_td | 0.1658ms | 94.1183μs | 10.6249 KOps/s | 10.6958 KOps/s | |
test_split_pytree | 77.2990μs | 42.7426μs | 23.3959 KOps/s | 23.3647 KOps/s | |
test_split_td | 0.7623ms | 0.1113ms | 8.9833 KOps/s | 8.8423 KOps/s | |
test_add_pytree | 79.7000μs | 45.3360μs | 22.0575 KOps/s | 21.8107 KOps/s | |
test_add_td | 0.1583ms | 70.1342μs | 14.2584 KOps/s | 14.5557 KOps/s | |
test_distributed | 29.6990μs | 8.2062μs | 121.8598 KOps/s | 106.8608 KOps/s | |
test_tdmodule | 0.1860ms | 25.9697μs | 38.5063 KOps/s | 38.4651 KOps/s | |
test_tdmodule_dispatch | 0.2378ms | 50.5921μs | 19.7659 KOps/s | 9.7424 KOps/s | |
test_tdseq | 0.5065ms | 27.4398μs | 36.4434 KOps/s | 36.3817 KOps/s | |
test_tdseq_dispatch | 0.4797ms | 54.6829μs | 18.2872 KOps/s | 18.5082 KOps/s | |
test_instantiation_functorch | 1.7477ms | 1.5098ms | 662.3235 Ops/s | 652.5992 Ops/s | |
test_instantiation_td | 1.8760ms | 1.2465ms | 802.2517 Ops/s | 791.1354 Ops/s | |
test_exec_functorch | 0.2467ms | 0.1763ms | 5.6729 KOps/s | 5.5953 KOps/s | |
test_exec_td | 0.2087ms | 0.1674ms | 5.9754 KOps/s | 5.9312 KOps/s | |
test_vmap_mlp_speed[True-True] | 1.6863ms | 1.0229ms | 977.6408 Ops/s | 972.8381 Ops/s | |
test_vmap_mlp_speed[True-False] | 0.7302ms | 0.5044ms | 1.9826 KOps/s | 1.9962 KOps/s | |
test_vmap_mlp_speed[False-True] | 1.2040ms | 0.8733ms | 1.1450 KOps/s | 1.1559 KOps/s | |
test_vmap_mlp_speed[False-False] | 1.8131ms | 0.3977ms | 2.5142 KOps/s | 2.5596 KOps/s | |
test_vmap_transformer_speed[True-True] | 13.6153ms | 12.1010ms | 82.6379 Ops/s | 81.7625 Ops/s | |
test_vmap_transformer_speed[True-False] | 67.7795ms | 8.0917ms | 123.5841 Ops/s | 131.4784 Ops/s | |
test_vmap_transformer_speed[False-True] | 12.0132ms | 11.8384ms | 84.4712 Ops/s | 83.8651 Ops/s | |
test_vmap_transformer_speed[False-False] | 9.5686ms | 7.4378ms | 134.4477 Ops/s | 134.6608 Ops/s | |
Here is a test for it:

```python
from typing import List

import pytest
import torch

from tensordict import TensorDict, TensorDictBase


class TestNestedLazyStacks:
    def get_agent_tensors(self, i):
        camera = torch.zeros(32, 32, 3)
        vector_3d = torch.zeros(3)
        vector_2d = torch.zeros(2)
        lidar = torch.zeros(20)
        agent_0_obs = torch.zeros(1)
        agent_1_obs = torch.zeros(1, 2)
        agent_2_obs = torch.zeros(1, 2, 3)
        # All agents share the same camera.
        # All have a "vector" entry, but with different shapes.
        # The first two have lidar; the last does not.
        # Each has its own agent_i_obs key with a different number of dims.
        if i == 0:
            return TensorDict(
                {
                    "camera": camera,
                    "lidar": lidar,
                    "vector": vector_3d,
                    "agent_0_obs": agent_0_obs,
                },
                [],
            )
        elif i == 1:
            return TensorDict(
                {
                    "camera": camera,
                    "lidar": lidar,
                    "vector": vector_2d,
                    "agent_1_obs": agent_1_obs,
                },
                [],
            )
        elif i == 2:
            return TensorDict(
                {
                    "camera": camera,
                    "vector": vector_2d,
                    "agent_2_obs": agent_2_obs,
                },
                [],
            )
        else:
            raise ValueError(f"Index {i} undefined for 3 agents")

    def get_lazy_stack(self, batch_size):
        agent_obs = []
        for agent_id in range(3):
            agent_obs.append(self.get_agent_tensors(agent_id))
        agent_obs = torch.stack(agent_obs, dim=0)
        obs = TensorDict(
            {
                "agents": agent_obs,
                "state": torch.zeros(64, 64, 3),
            },
            [],
        )
        obs = obs.expand(batch_size)
        return obs

    def dense_stack_tds_v2(
        self, td_list: List[TensorDictBase], stack_dim: int
    ) -> TensorDictBase:
        shape = list(td_list[0].shape)
        shape.insert(stack_dim, len(td_list))
        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
        return torch.stack(td_list, dim=stack_dim, out=out)

    @pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)])
    def test(self, batch_size):
        obs = self.get_lazy_stack(batch_size)
        for stack_dim in range(len(batch_size) + 1):
            res = self.dense_stack_tds_v2([obs, obs], stack_dim=stack_dim)
```
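For orientation, `dense_stack_tds_v2` above mirrors the plain `torch.stack(..., out=...)` pattern on regular tensors; here is a minimal tensor-only sketch of that pattern (illustrative, not part of the PR):

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 3)
out = torch.empty(2, 2, 3)  # pre-allocated destination: a stack of two along dim 0
torch.stack([a, b], dim=0, out=out)
assert (out[0] == 0).all() and (out[1] == 1).all()
```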
If you replace …, you can also see that the heterogeneous shape bug also comes up in …
On it. I think I've found where the problem was.
A few things that I found when trying to understand this: …
This isn't related to `stack(..., out=something)`, right?
Nope, general lazy-stack bugs.
But they might affect part of this downstream.
I don't think so (though I would not rule it out completely).
I think `__setitem__` is used here https://github.com/pytorch-labs/tensordict/blob/main/tensordict/utils.py#L825, called from here https://github.com/pytorch-labs/tensordict/blob/main/tensordict/tensordict.py#L4142 on nested lazy stacks (that function receives LazyStacks as input even though the type hint suggests it should not). Here is a set of tests that I think will cover some of the interesting cases:

```python
from typing import List

import pytest
import torch

from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase


class TestNestedLazyStacks:
    @staticmethod
    def nested_lazy_het_td(batch_size):
        shared = torch.zeros(4, 4, 2)
        hetero_3d = torch.zeros(3)
        hetero_2d = torch.zeros(2)
        individual_0_tensor = torch.zeros(1)
        individual_1_tensor = torch.zeros(1, 2)
        individual_2_tensor = torch.zeros(1, 2, 3)
        td_list = [
            TensorDict(
                {
                    "shared": shared,
                    "hetero": hetero_3d,
                    "individual_0_tensor": individual_0_tensor,
                },
                [],
            ),
            TensorDict(
                {
                    "shared": shared,
                    "hetero": hetero_3d,
                    "individual_1_tensor": individual_1_tensor,
                },
                [],
            ),
            TensorDict(
                {
                    "shared": shared,
                    "hetero": hetero_2d,
                    "individual_2_tensor": individual_2_tensor,
                },
                [],
            ),
        ]
        for i, td in enumerate(td_list):
            td[f"individual_{i}_td"] = td.clone()
            td["shared_td"] = td.clone()
        td_stack = torch.stack(td_list, dim=0)
        obs = TensorDict(
            {"lazy": td_stack, "dense": torch.zeros(3, 3, 2)},
            [],
        )
        obs = obs.expand(batch_size)
        return obs

    def recursively_check_key(self, td, value: int):
        if isinstance(td, LazyStackedTensorDict):
            for t in td.tensordicts:
                if not self.recursively_check_key(t, value):
                    return False
        elif isinstance(td, TensorDict):
            for i in td.values():
                if not self.recursively_check_key(i, value):
                    return False
        elif isinstance(td, torch.Tensor):
            return (td == value).all()
        else:
            return False
        return True

    def dense_stack_tds_v1(
        self, td_list: List[TensorDictBase], stack_dim: int
    ) -> TensorDictBase:
        shape = list(td_list[0].shape)
        shape.insert(stack_dim, len(td_list))
        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
        for i in range(1, len(td_list)):
            index = (slice(None),) * stack_dim + (i,)  # index_select-style write
            out[index] = td_list[i]
        return out

    def dense_stack_tds_v2(
        self, td_list: List[TensorDictBase], stack_dim: int
    ) -> TensorDictBase:
        shape = list(td_list[0].shape)
        shape.insert(stack_dim, len(td_list))
        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
        return torch.stack(td_list, dim=stack_dim, out=out)

    @pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
    @pytest.mark.parametrize("stack_dim", [0, 1, 2])
    def test(self, batch_size, stack_dim):
        obs = self.nested_lazy_het_td(batch_size)
        obs1 = obs.clone()
        obs1.apply_(lambda x: x + 1)
        if stack_dim > len(batch_size):
            return
        res1 = self.dense_stack_tds_v1([obs, obs1], stack_dim=stack_dim)
        res2 = self.dense_stack_tds_v2([obs, obs1], stack_dim=stack_dim)
        index = (slice(None),) * stack_dim + (0,)  # get the first element of the stack
        assert self.recursively_check_key(res1[index], 0)  # check all 0
        assert self.recursively_check_key(res2[index], 0)  # check all 0
        index = (slice(None),) * stack_dim + (1,)  # get the second element of the stack
        assert self.recursively_check_key(res1[index], 1)  # check all 1
        assert self.recursively_check_key(res2[index], 1)  # check all 1
```
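As a side note, the tuple index built in `dense_stack_tds_v1` is the ordinary PyTorch pattern for writing slice `i` along dimension `stack_dim`; a tensor-only sketch with illustrative values:

```python
import torch

stack_dim = 1
out = torch.zeros(4, 2, 3)  # pre-allocated "stack" of two along dim 1
src = torch.ones(4, 3)
index = (slice(None),) * stack_dim + (1,)  # equivalent to out[:, 1]
out[index] = src
assert (out[:, 1] == 1).all() and (out[:, 0] == 0).all()
```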
Nit: the case of lazy stacks of lazy stacks is not tested; I only test lazy stacks with keys that are themselves lazy stacks.
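For reference, a hypothetical sketch of that untested case, a lazy stack whose elements are themselves lazy stacks (the names here are illustrative, not from the PR):

```python
import torch

from tensordict import TensorDict

# Two inner lazy stacks with heterogeneous "x" shapes.
inner_a = torch.stack(
    [TensorDict({"x": torch.zeros(3)}, []), TensorDict({"x": torch.zeros(2)}, [])],
    dim=0,
)
inner_b = inner_a.clone()
# An outer lazy stack whose elements are the inner lazy stacks.
outer = torch.stack([inner_a, inner_b], dim=0)
```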
Tests are passing now, and I'll fix this too.
LGTM
Can you add this to the tests?
```python
# Methods to add to TestNestedLazyStacks above (uses typing.Union in addition
# to the imports already listed there).
def all_eq(
    self,
    td: Union[TensorDictBase, torch.Tensor],
    other: Union[TensorDictBase, torch.Tensor],
):
    if td.__class__ != other.__class__:
        return False
    if td.shape != other.shape or td.device != other.device:
        return False
    if isinstance(td, LazyStackedTensorDict):
        if td.stack_dim != other.stack_dim:
            return False
        for t, o in zip(td.tensordicts, other.tensordicts):
            if not self.all_eq(t, o):
                return False
    elif isinstance(td, TensorDictBase):
        td_keys = list(td.keys())
        other_keys = list(other.keys())
        if td_keys != other_keys:
            return False
        for k in td_keys:
            if not self.all_eq(td[k], other[k]):
                return False
    elif isinstance(td, torch.Tensor):
        return torch.equal(td, other)
    else:
        raise AssertionError()
    return True

@pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
@pytest.mark.parametrize("stack_dim", [0, 1, 2])
def test_setitem_hetero(self, batch_size, stack_dim):
    obs = self.nested_lazy_het_td(batch_size)
    obs1 = obs.clone()
    obs1.apply_(lambda x: x + 1)
    if stack_dim > len(batch_size):
        return
    res1 = self.dense_stack_tds_v1([obs, obs1], stack_dim=stack_dim)
    res2 = self.dense_stack_tds_v2([obs, obs1], stack_dim=stack_dim)
    index = (slice(None),) * stack_dim + (0,)  # get the first element of the stack
    assert self.recursively_check_key(res1[index], 0)  # check all 0
    assert self.recursively_check_key(res2[index], 0)  # check all 0
    index = (slice(None),) * stack_dim + (1,)  # get the second element of the stack
    assert self.recursively_check_key(res1[index], 1)  # check all 1
    assert self.recursively_check_key(res2[index], 1)  # check all 1
    assert self.all_eq(res1, res2)
```
It is an extra consistency check. The second function is the one we already have, with an extra line.
I'd rather not add that all_eq method; what does it do that a simple eq does not? The reason is that if we add a helper for each new test, we'll end up with tests that are hard to modify, or with lots of legacy code when we remove a test, for example.
Oh yes, of course. But a normal equal returns some lazy keys in case there are any, right? How do I check that all lazy keys match?
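For context, `==` on tensordicts compares entry-wise and returns a boolean structure that `.all()` can reduce; what it does not check is the structural metadata (class, stack_dim, key order) that all_eq inspects. A minimal sketch of the entry-wise check (illustrative values):

```python
import torch

from tensordict import TensorDict

td1 = TensorDict(
    {"a": torch.zeros(2), "nested": TensorDict({"b": torch.ones(2)}, [])}, []
)
td2 = td1.clone()
# `==` returns a tensordict of boolean tensors; `.all()` with no dim
# reduces it to a single boolean.
assert (td1 == td2).all()
```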
# Conflicts:
#   tensordict/tensordict.py
#   test/test_tensordict.py
LGTM