From d00bc79cde73535b2c228e5755a2de9adc07cf10 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 6 Aug 2024 16:14:58 -0400 Subject: [PATCH 1/4] init --- tensordict/tensorclass.py | 28 ++-- test/test_nn.py | 274 ++++++++++++++++++++++++-------------- test/test_tensorclass.py | 3 + test/test_tensordict.py | 170 +++++++++++++++-------- 4 files changed, 306 insertions(+), 169 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 2fbc21ceb..698a0b412 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -128,10 +128,11 @@ def __subclasscheck__(self, subclass): "is_shared", "items", "keys", - # "ndim", "ndimension", "numel", + "size", "values", + # "ndim", ] # Methods to be executed from tensordict, any ref to self means 'self._tensordict' @@ -252,6 +253,7 @@ def __subclasscheck__(self, subclass): "new_zeros", "norm", "permute", + "pin_memory", "pow", "pow_", "prod", @@ -1104,12 +1106,13 @@ def check_out(kwargs, result): return wrapped_func -def _wrap_method(self, attr, func): - warnings.warn( - f"The method {func} wasn't explicitly implemented for tensorclass. " - f"This fallback will be deprecated in future releases because it is inefficient " - f"and non-compilable. Please raise an issue in tensordict repo to support this method!" - ) +def _wrap_method(self, attr, func, nowarn=False): + if not nowarn: + warnings.warn( + f"The method {func} wasn't explicitly implemented for tensorclass. " + f"This fallback will be deprecated in future releases because it is inefficient " + f"and non-compilable. Please raise an issue in tensordict repo to support this method!" + ) @functools.wraps(func) def wrapped_func(*args, **kwargs): @@ -2666,14 +2669,17 @@ def __torch_function__( def _fast_apply(self, *args, **kwargs): kwargs["filter_empty"] = False - return _wrap_method(self, "_fast_apply", self._tensordict._fast_apply)( - *args, **kwargs - ) + return _wrap_method( + self, "_fast_apply", self._tensordict._fast_apply, nowarn=True + )(*args, **kwargs) def _multithread_rebuild(self, *args, **kwargs): kwargs["filter_empty"] = False return _wrap_method( - self, "_multithread_rebuild", self._tensordict._multithread_rebuild + self, + "_multithread_rebuild", + self._tensordict._multithread_rebuild, + nowarn=True, )(*args, **kwargs) def tolist(self): diff --git a/test/test_nn.py b/test/test_nn.py index c6c50c21c..80d0c802f 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import contextlib import copy import pickle import unittest @@ -35,7 +36,6 @@ AddStateIndependentNormalScale, Delta, NormalParamExtractor, - NormalParamWrapper, ) from tensordict.nn.distributions.composite import CompositeDistribution from tensordict.nn.ensemble import EnsembleModule @@ -68,6 +68,13 @@ FUNCTORCH_ERR = str(err) +# Capture all warnings +pytestmark = [ + pytest.mark.filterwarnings("error"), + pytest.mark.filterwarnings("ignore:enable_nested_tensor is True"), +] + + class TestInteractionType: @pytest.mark.parametrize( "str_and_expected_type", @@ -310,7 +317,9 @@ def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys): in_keys = ["in"] net = TensorDictModule( - module=NormalParamWrapper(net), in_keys=in_keys, out_keys=out_keys + module=nn.Sequential(net, NormalParamExtractor()), + in_keys=in_keys, + out_keys=out_keys, ) kwargs = {"distribution_class": Normal} @@ -329,7 +338,12 @@ def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys): td = TensorDict({"in": torch.randn(3, 3)}, [3]) with set_interaction_type(interaction_type): - tensordict_module(td) + with ( + pytest.warns(UserWarning, match="deterministic_sample") + if interaction_type in (InteractionType.DETERMINISTIC, None) + else contextlib.nullcontext() + ): + tensordict_module(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -368,7 +382,12 @@ def test_stateful_probabilistic_kwargs( td = TensorDict({"in": torch.randn(3, 3)}, [3]) with set_interaction_type(interaction_type): - tensordict_module(td) + with ( + pytest.warns(UserWarning, match="deterministic_sample") + if interaction_type in (None, InteractionType.DETERMINISTIC) + else contextlib.nullcontext() + ): + tensordict_module(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -428,7 +447,12 @@ def test_stateful_probabilistic(self, lazy, interaction_type, out_keys): td = TensorDict({"in": torch.randn(3, 3)}, [3]) with set_interaction_type(interaction_type): - tensordict_module(td) + with ( + pytest.warns(UserWarning, match="deterministic_sample") + if interaction_type in (None, InteractionType.DETERMINISTIC) + else contextlib.nullcontext() + ): + tensordict_module(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -533,9 +557,10 @@ def test_functional_functorch(self): module=net, in_keys=["in"], out_keys=["out"] ) - tensordict_module, params, buffers = make_functional_functorch( - tensordict_module - ) + with pytest.warns(FutureWarning, match="integrated functorch"): + tensordict_module, params, buffers = make_functional_functorch( + tensordict_module + ) td = TensorDict({"in": torch.randn(3, 3)}, [3]) tensordict_module(params, buffers, td) @@ -550,7 +575,9 @@ def test_functional_probabilistic_deprec(self): param_multiplier = 2 tdnet = TensorDictModule( - module=NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)), + module=nn.Sequential( + nn.Linear(3, 4 * param_multiplier), NormalParamExtractor() + ), in_keys=["in"], out_keys=["loc", "scale"], ) @@ -564,7 +591,8 @@ def test_functional_probabilistic_deprec(self): params = make_functional(tensordict_module) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=params) + with pytest.warns(UserWarning, match="deterministic_sample"): + tensordict_module(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -595,7 +623,8 @@ def test_functional_probabilistic(self): params = make_functional(tensordict_module) td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=params) + with pytest.warns(UserWarning, match="deterministic_sample"): + tensordict_module(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -625,7 +654,9 @@ def test_functional_with_buffer_probabilistic_deprec(self): param_multiplier = 2 tdnet = TensorDictModule( - module=NormalParamWrapper(nn.BatchNorm1d(32 * param_multiplier)), + module=nn.Sequential( + nn.BatchNorm1d(32 * param_multiplier), NormalParamExtractor() + ), in_keys=["in"], out_keys=["loc", "scale"], ) @@ -639,7 +670,8 @@ def test_functional_with_buffer_probabilistic_deprec(self): params = make_functional(tdmodule) td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params) + with pytest.warns(UserWarning, match="deterministic_sample"): + tdmodule(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 32]) @@ -668,7 +700,8 @@ def test_functional_with_buffer_probabilistic(self): params = make_functional(tdmodule) td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params) + with pytest.warns(UserWarning, match="deterministic_sample"): + tdmodule(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 32]) @@ -707,7 +740,7 @@ def test_vmap_probabilistic_deprec(self): torch.manual_seed(0) param_multiplier = 2 - net = NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)) + net = nn.Sequential(nn.Linear(3, 4 * param_multiplier), NormalParamExtractor()) tdnet = TensorDictModule(module=net, in_keys=["in"], out_keys=["loc", "scale"]) @@ -722,7 +755,8 @@ def test_vmap_probabilistic_deprec(self): # vmap = True params = params.expand(10).lock_() td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = vmap(tdmodule, (None, 0))(td, params) + with pytest.warns(UserWarning, match="deterministic_sample"): + td_out = vmap(tdmodule, (None, 0))(td, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -730,7 +764,8 @@ def test_vmap_probabilistic_deprec(self): # vmap = (0, 0) td = TensorDict({"in": torch.randn(3, 3)}, [3]) td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) + with pytest.warns(UserWarning, match="deterministic_sample"): + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) assert td_out is not td_repeat assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -771,7 +806,8 @@ def test_vmap_probabilistic(self): # vmap = True params = params.expand(10).lock_() td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = vmap(tdmodule, (None, 0))(td, params) + with pytest.warns(UserWarning, match="deterministic_sample"): + td_out = vmap(tdmodule, (None, 0))(td, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -779,7 +815,8 @@ def test_vmap_probabilistic(self): # vmap = (0, 0) td = TensorDict({"in": torch.randn(3, 3)}, [3]) td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) + with pytest.warns(UserWarning, match="deterministic_sample"): + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) assert td_out is not td_repeat assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1056,7 +1093,7 @@ def test_stateful_probabilistic_deprec(self, lazy): net1 = nn.Linear(3, 4) dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) + net2 = nn.Sequential(net2, NormalParamExtractor()) kwargs = {"distribution_class": Normal} tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"]) @@ -1091,7 +1128,8 @@ def test_stateful_probabilistic_deprec(self, lazy): assert tdmodule[2] is prob_module td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) + with pytest.warns(UserWarning, match="deterministic_sample"): + tdmodule(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -1147,7 +1185,8 @@ def test_stateful_probabilistic(self, lazy): assert tdmodule[3] is prob_module td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) + with pytest.warns(UserWarning, match="deterministic_sample"): + tdmodule(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -1215,7 +1254,8 @@ def test_functional_functorch(self): tdmodule2 = TensorDictModule(net2, in_keys=["hidden"], out_keys=["out"]) tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) - ftdmodule, params, buffers = make_functional_functorch(tdmodule) + with pytest.warns(FutureWarning): + ftdmodule, params, buffers = make_functional_functorch(tdmodule) td = TensorDict({"in": torch.randn(3, 3)}, [3]) ftdmodule(params, buffers, td) @@ -1232,7 +1272,7 @@ def test_functional_probabilistic_deprec(self): net1 = nn.Linear(3, 4) dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) + net2 = nn.Sequential(net2, NormalParamExtractor()) tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"]) dummy_tdmodule = TensorDictModule( @@ -1276,7 +1316,8 @@ def test_functional_probabilistic_deprec(self): assert tdmodule[2] is prob_module td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params=params) + with pytest.warns(UserWarning, match="deterministic_sample"): + tdmodule(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -1340,7 +1381,8 @@ def test_functional_probabilistic(self): assert tdmodule[3] is prob_module td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params=params) + with pytest.warns(UserWarning, match="deterministic_sample"): + tdmodule(td, params=params) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -1405,7 +1447,7 @@ def test_functional_with_buffer_probabilistic_deprec(self): net2 = nn.Sequential( nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) ) - net2 = NormalParamWrapper(net2) + net2 = nn.Sequential(net2, NormalParamExtractor()) tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"]) dummy_tdmodule = TensorDictModule( @@ -1450,7 +1492,8 @@ def test_functional_with_buffer_probabilistic_deprec(self): assert tdmodule[2] is prob_module td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params) + with pytest.warns(UserWarning, match="deterministic_sample"): + tdmodule(td, params=params) dist = tdmodule.get_dist(td, params=params) assert dist.rsample().shape[: td.ndimension()] == td.shape @@ -1518,7 +1561,8 @@ def test_functional_with_buffer_probabilistic(self): assert tdmodule[3] is prob_module td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params) + with pytest.warns(UserWarning, match="deterministic_sample"): + tdmodule(td, params=params) dist = tdmodule.get_dist(td, params=params) assert dist.rsample().shape[: td.ndimension()] == td.shape @@ -1590,7 +1634,7 @@ def test_vmap_probabilistic_deprec(self): net1 = nn.Linear(3, 4) net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) + net2 = nn.Sequential(net2, NormalParamExtractor()) kwargs = {"distribution_class": Normal} tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"]) @@ -1612,7 +1656,8 @@ def test_vmap_probabilistic_deprec(self): # vmap = True params = params.expand(10).lock_() td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = vmap(tdmodule, (None, 0))(td, params) + with pytest.warns(UserWarning, match="deterministic_sample"): + td_out = vmap(tdmodule, (None, 0))(td, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1620,7 +1665,8 @@ def test_vmap_probabilistic_deprec(self): # vmap = (0, 0) td = TensorDict({"in": torch.randn(3, 3)}, [3]) td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) + with pytest.warns(UserWarning, match="deterministic_sample"): + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) assert td_out is not td_repeat assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1655,7 +1701,8 @@ def test_vmap_probabilistic(self): # vmap = True params = params.expand(10).lock_() td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_out = vmap(tdmodule, (None, 0))(td, params) + with pytest.warns(UserWarning, match="deterministic_sample"): + td_out = vmap(tdmodule, (None, 0))(td, params) assert td_out is not td assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1663,7 +1710,8 @@ def test_vmap_probabilistic(self): # vmap = (0, 0) td = TensorDict({"in": torch.randn(3, 3)}, [3]) td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) + with pytest.warns(UserWarning, match="deterministic_sample"): + td_out = vmap(tdmodule, (0, 0))(td_repeat, params) assert td_out is not td_repeat assert td_out.shape == torch.Size([10, 3]) assert td_out.get("out").shape == torch.Size([10, 3, 4]) @@ -1765,11 +1813,11 @@ def test_sequential_partial_deprec(self, stack, functional): net1 = nn.Linear(3, 4) net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) + net2 = nn.Sequential(net2, NormalParamExtractor()) net2 = TensorDictModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) net3 = nn.Linear(4, 4 * param_multiplier) - net3 = NormalParamWrapper(net3) + net3 = nn.Sequential(net3, NormalParamExtractor()) net3 = TensorDictModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) kwargs = {"distribution_class": Normal} @@ -1804,10 +1852,11 @@ def test_sequential_partial_deprec(self, stack, functional): ], 0, ) - if functional: - tdmodule(td, params=params) - else: - tdmodule(td) + with pytest.warns(UserWarning, match="deterministic_sample"): + if functional: + tdmodule(td, params=params) + else: + tdmodule(td) assert "loc" in td.keys() assert "scale" in td.keys() assert "out" in td.keys() @@ -1818,10 +1867,11 @@ def test_sequential_partial_deprec(self, stack, functional): assert "b" in td[0].keys() else: td = TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []) - if functional: - tdmodule(td, params=params) - else: - tdmodule(td) + with pytest.warns(UserWarning, match="deterministic_sample"): + if functional: + tdmodule(td, params=params) + else: + tdmodule(td) assert "loc" in td.keys() assert "scale" in td.keys() assert "out" in td.keys() @@ -1882,10 +1932,11 @@ def test_sequential_partial(self, stack, functional): ], 0, ) - if functional: - tdmodule(td, params=params) - else: - tdmodule(td) + with pytest.warns(UserWarning, match="deterministic_sample"): + if functional: + tdmodule(td, params=params) + else: + tdmodule(td) assert "loc" in td.keys() assert "scale" in td.keys() assert "out" in td.keys() @@ -1896,10 +1947,11 @@ def test_sequential_partial(self, stack, functional): assert "b" in td[0].keys() else: td = TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []) - if functional: - tdmodule(td, params=params) - else: - tdmodule(td) + with pytest.warns(UserWarning, match="deterministic_sample"): + if functional: + tdmodule(td, params=params) + else: + tdmodule(td) assert "loc" in td.keys() assert "scale" in td.keys() assert "out" in td.keys() @@ -1996,7 +2048,8 @@ def mycallable(): TensorDictModule(module, in_keys=[("_", "")], out_keys=[("_", "")]) TensorDictModule(module, in_keys=[("_", "")], out_keys=[("a", "a")]) TensorDictModule(module, in_keys=[""], out_keys=["_"]) - TensorDictModule(module, in_keys=["_"], out_keys=[""]) + with pytest.warns(UserWarning, match='key "_"'): + TensorDictModule(module, in_keys=["_"], out_keys=[""]) # this should raise for wrong_model in (MyModule, int, [123], 1, torch.randn(2)): @@ -2213,19 +2266,29 @@ def test_make_func(self, module_type, stateless, keyword, extra_kwargs): params = params.clone() params.zero_() td = self.td - if not keyword: - if not stateless and module_type == "nnModule": - with pytest.raises(TypeError, match="It seems you tried to provide"): - if extra_kwargs: - _ = module(td, params, extra=None) - else: - _ = module(td, params) - return - tdout = module(td, params) - assert (tdout == self.td_zero).all() - else: - tdout = module(td, params=params) - assert (tdout == self.td_zero).all(), tdout + with ( + pytest.warns( + UserWarning, + match="You are passing a tensordict/tensorclass instance to a module", + ) + if module_type in ("nnModule",) + else contextlib.nullcontext() + ): + if not keyword: + if not stateless and module_type == "nnModule": + with pytest.raises( + TypeError, match="It seems you tried to provide" + ): + if extra_kwargs: + _ = module(td, params, extra=None) + else: + _ = module(td, params) + return + tdout = module(td, params) + assert (tdout == self.td_zero).all() + else: + tdout = module(td, params=params) + assert (tdout == self.td_zero).all(), tdout @pytest.mark.parametrize("module_type", ["TDMBase", "nnModule", "TDM"]) @pytest.mark.parametrize("stateless", [True, False]) @@ -2237,21 +2300,31 @@ def test_make_func_vmap(self, module_type, stateless, keyword, extra_kwargs): params = params.expand(5).to_tensordict().lock_() params.zero_() td = self.td.expand(5, 3).to_tensordict() - if not keyword: - if not stateless and module_type == "nnModule": - with pytest.raises(TypeError, match="It seems you tried to provide"): - if extra_kwargs: - _ = vmap(module)(td, params, extra=None) - else: - _ = vmap(module)(td, params) - return - tdout = vmap(module)(td, params) - assert (tdout == self.td_zero).all() - else: - # this isn't supposed to work: keyword arguments are not expanded with vmap - with pytest.raises(Exception): - tdout = vmap(module)(td, params=params) - assert (tdout == self.td_zero).all(), tdout + with ( + pytest.warns( + UserWarning, + match="You are passing a tensordict/tensorclass instance to a module", + ) + if module_type in ("nnModule",) + else contextlib.nullcontext() + ): + if not keyword: + if not stateless and module_type == "nnModule": + with pytest.raises( + TypeError, match="It seems you tried to provide" + ): + if extra_kwargs: + _ = vmap(module)(td, params, extra=None) + else: + _ = vmap(module)(td, params) + return + tdout = vmap(module)(td, params) + assert (tdout == self.td_zero).all() + else: + # this isn't supposed to work: keyword arguments are not expanded with vmap + with pytest.raises(Exception): + tdout = vmap(module)(td, params=params) + assert (tdout == self.td_zero).all(), tdout class TestSkipExisting: @@ -2762,26 +2835,27 @@ def test_nested_keys_probabilistic_normal(log_prob_key): return_log_prob=True, log_prob_key=log_prob_key, ) - td_out = module(loc_module(scale_module(td))) - assert td_out["data", "action"].shape == (3, 4, 1) - if log_prob_key: - assert td_out[log_prob_key].shape == (3, 4, 1) - else: - assert td_out["sample_log_prob"].shape == (3, 4, 1) + with pytest.warns(UserWarning, match="deterministic_sample"): + td_out = module(loc_module(scale_module(td))) + assert td_out["data", "action"].shape == (3, 4, 1) + if log_prob_key: + assert td_out[log_prob_key].shape == (3, 4, 1) + else: + assert td_out["sample_log_prob"].shape == (3, 4, 1) - module = ProbabilisticTensorDictModule( - in_keys={"loc": ("data", "loc"), "scale": ("data", "scale")}, - out_keys=[("data", "action")], - distribution_class=Normal, - return_log_prob=True, - log_prob_key=log_prob_key, - ) - td_out = module(loc_module(scale_module(td))) - assert td_out["data", "action"].shape == (3, 4, 1) - if log_prob_key: - assert td_out[log_prob_key].shape == (3, 4, 1) - else: - assert td_out["sample_log_prob"].shape == (3, 4, 1) + module = ProbabilisticTensorDictModule( + in_keys={"loc": ("data", "loc"), "scale": ("data", "scale")}, + out_keys=[("data", "action")], + distribution_class=Normal, + return_log_prob=True, + log_prob_key=log_prob_key, + ) + td_out = module(loc_module(scale_module(td))) + assert td_out["data", "action"].shape == (3, 4, 1) + if log_prob_key: + assert td_out[log_prob_key].shape == (3, 4, 1) + else: + assert td_out["sample_log_prob"].shape == (3, 4, 1) class TestEnsembleModule: diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 17eca8292..87896b954 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -47,6 +47,9 @@ from tensordict._lazy import _PermutedTensorDict, _ViewedTensorDict from torch import Tensor +# Capture all warnings +pytestmark = pytest.mark.filterwarnings("error") + def _make_data(shape): return MyData( diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 98a1a374e..b6b51822f 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -98,6 +98,23 @@ getattr(torch, "_nested_compute_contiguous_strides_offsets", None) is not None ) +# Capture all warnings +pytestmark = [ + pytest.mark.filterwarnings("error"), + pytest.mark.filterwarnings( + "ignore:There is a performance drop because we have not yet implemented the batching rule" + ), + pytest.mark.filterwarnings( + "ignore:A destination should be provided when cloning a PersistentTensorDict" + ), + pytest.mark.filterwarnings( + "ignore:Replacing an array with another one is inefficient" + ), + pytest.mark.filterwarnings( + "ignore:Indexing an h5py.Dataset object with a boolean mask that needs broadcasting does not work directly" + ), +] + mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn" @@ -279,33 +296,38 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty (2, 3), 2, dtype=torch.float64 if hetdtype else torch.float32 ) else: - layout = torch.jagged - a = torch.nested.nested_tensor( - [torch.zeros((1,), device=device), torch.zeros((2,), device=device)], - layout=layout, - ) - c = torch.nested.nested_tensor( - [ - torch.ones( - (1,), - device=device, - dtype=torch.float16 if hetdtype else torch.float32, - ), - torch.ones( - (2,), - device=device, - dtype=torch.float16 if hetdtype else torch.float32, - ), - ], - layout=layout, - ) - g0 = torch.full( - (1, 3), 2, dtype=torch.float64 if hetdtype else torch.float32 - ) - g1 = torch.full( - (2, 3), 2, dtype=torch.float64 if hetdtype else torch.float32 - ) - g = torch.nested.nested_tensor([g0, g1], layout=layout) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + layout = torch.jagged + a = torch.nested.nested_tensor( + [ + torch.zeros((1,), device=device), + torch.zeros((2,), device=device), + ], + layout=layout, + ) + c = torch.nested.nested_tensor( + [ + torch.ones( + (1,), + device=device, + dtype=torch.float16 if hetdtype else torch.float32, + ), + torch.ones( + (2,), + device=device, + dtype=torch.float16 if hetdtype else torch.float32, + ), + ], + layout=layout, + ) + g0 = torch.full( + (1, 3), 2, dtype=torch.float64 if hetdtype else torch.float32 + ) + g1 = torch.full( + (2, 3), 2, dtype=torch.float64 if hetdtype else torch.float32 + ) + g = torch.nested.nested_tensor([g0, g1], layout=layout) td = TensorDict( { @@ -533,7 +555,7 @@ def test_data_grad(self): [3, 4], ) td1 = td + 1 - td1.apply(lambda x: x.sum().backward(retain_graph=True)) + td1.apply(lambda x: x.sum().backward(retain_graph=True), filter_empty=True) assert not td.grad.is_locked assert td.grad is not td.grad assert not td.data.is_locked @@ -935,6 +957,7 @@ def test_from_module(self, memmap, params): nhead=2, num_encoder_layers=3, dim_feedforward=12, + batch_first=True, ) td = TensorDict.from_module(net, as_module=params, filter_empty=False) # check that we have empty tensordicts, reflecting modules without params @@ -956,6 +979,7 @@ def test_from_module_state_dict(self): nhead=2, num_encoder_layers=3, dim_feedforward=12, + batch_first=True, ) def adder(module, *args, **kwargs): @@ -1008,9 +1032,17 @@ def exec_module(params, x): def get_leaf(leaf): leaves.append(leaf) - params.apply(get_leaf) + params.apply(get_leaf, filter_empty=True) assert all(param.grad is not None for param in leaves) - assert all(param.grad is None for param in params.values(True, True)) + with ( + pytest.warns( + UserWarning, + match="The .grad attribute of a Tensor that is not a leaf Tensor is being accessed.", + ) + if lazy_stack + else contextlib.nullcontext() + ): + assert all(param.grad is None for param in params.values(True, True)) else: for p in modules[0].parameters(): assert p.grad is None @@ -1285,7 +1317,11 @@ def test_keys_view(self): def test_load_device(self, tmpdir): t = nn.Transformer( - d_model=64, nhead=4, num_encoder_layers=3, dim_feedforward=128 + d_model=64, + nhead=4, + num_encoder_layers=3, + dim_feedforward=128, + batch_first=True, ) state_dict = TensorDict.from_module(t) @@ -1321,7 +1357,7 @@ def assert_device(item): def assert_fake(tensor): assert isinstance(tensor, FakeTensor) - fake_state_dict.apply(assert_fake) + fake_state_dict.apply(assert_fake, filter_empty=True) def test_load_state_dict_incomplete(self): data = TensorDict({"a": {"b": {"c": {}}}, "d": 1}, []) @@ -1432,10 +1468,14 @@ def test_make_memmap_from_tensor(self, tmpdir): if HAS_NESTED_TENSOR: # test update - td.make_memmap_from_tensor( - ("e", "f"), - torch.nested.nested_tensor([torch.zeros((1, 2)), torch.zeros((1, 3))]), - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + td.make_memmap_from_tensor( + ("e", "f"), + torch.nested.nested_tensor( + [torch.zeros((1, 2)), torch.zeros((1, 3))] + ), + ) td_load.memmap_refresh_() assert td_load["e", "f"].is_nested @@ -2554,12 +2594,14 @@ def test_to_module_state_dict(self): nhead=2, num_encoder_layers=3, dim_feedforward=12, + batch_first=True, ) net1 = nn.Transformer( d_model=16, nhead=2, num_encoder_layers=3, dim_feedforward=12, + batch_first=True, ) def hook( @@ -2594,14 +2636,16 @@ def hook( @pytest.mark.parametrize("mask_key", [None, "mask"]) def test_to_padded_tensor(self, mask_key): - td = TensorDict( - { - "nested": torch.nested.nested_tensor( - [torch.ones(3, 4, 5), torch.ones(3, 6, 5)] - ) - }, - batch_size=[2, 3, -1], - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + td = TensorDict( + { + "nested": torch.nested.nested_tensor( + [torch.ones(3, 4, 5), torch.ones(3, 6, 5)] + ) + }, + batch_size=[2, 3, -1], + ) assert td.shape == torch.Size([2, 3, -1]) td_padded = td.to_padded_tensor(padding=0, mask_key=mask_key) assert td_padded.shape == torch.Size([2, 3, 6]) @@ -4686,7 +4730,9 @@ def test_nestedtensor_stack(self, td_name, device, dim, key): if dim in (0, -5): # this will work if stack_dim is 0 (or equivalently -self.batch_dims) # it is the proper way to get that entry - td_stack.get_nestedtensor(key) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + td_stack.get_nestedtensor(key) else: # if the stack_dim is not zero, then calling get_nestedtensor is disallowed with pytest.raises( @@ -4744,18 +4790,19 @@ def test_new_tensor(self, td_name, device): td = getattr(self, td_name)(device) if td_name in ("td_params",): td = td.data - tdn = td.new_tensor(torch.zeros(0, device="cpu")) - assert tdn.device == torch.device("cpu") - assert tdn.shape == (0,) - tdn = td.new_tensor(torch.zeros(2, device="cpu")) - assert tdn.device == torch.device("cpu") - assert tdn.shape == (2,) - tdn = td.new_tensor(td[0] * 0) - assert tdn.device == td.device - assert (tdn == 0).all() - assert tdn.shape == td.shape[1:] - if td._has_non_tensor: - assert tdn._has_non_tensor + with pytest.warns(UserWarning, match="To copy construct from a tensor"): + tdn = td.new_tensor(torch.zeros(0, device="cpu")) + assert tdn.device == torch.device("cpu") + assert tdn.shape == (0,) + tdn = td.new_tensor(torch.zeros(2, device="cpu")) + assert tdn.device == torch.device("cpu") + assert tdn.shape == (2,) + tdn = td.new_tensor(td[0] * 0) + assert tdn.device == td.device + assert (tdn == 0).all() + assert tdn.shape == td.shape[1:] + if td._has_non_tensor: + assert tdn._has_non_tensor def test_new_zeros(self, td_name, device): td = getattr(self, td_name)(device) @@ -9605,7 +9652,14 @@ def test_modules(self, as_module): modules = [ lambda: nn.Linear(3, 4), lambda: nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4)), - lambda: nn.Transformer(16, 4, 2, 2, 8), + lambda: nn.Transformer( + 16, + 4, + 2, + 2, + 8, + batch_first=True, + ), lambda: nn.Sequential(nn.Conv2d(3, 4, 3), nn.Conv2d(4, 4, 3)), ] inputs = [ From 66203914defdbe095911381668b959ff9027d4a5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 6 Aug 2024 16:45:56 -0400 Subject: [PATCH 2/4] amend --- test/test_nn.py | 3 +++ test/test_tensorclass.py | 10 +++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/test/test_nn.py b/test/test_nn.py index 80d0c802f..220f3eacd 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -71,6 +71,9 @@ # Capture all warnings pytestmark = [ pytest.mark.filterwarnings("error"), + pytest.mark.filterwarnings( + "ignore:You are using `torch.load` with `weights_only=False`" + ), pytest.mark.filterwarnings("ignore:enable_nested_tensor is True"), ] diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 87896b954..7bbd69342 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -48,7 +48,15 @@ from torch import Tensor # Capture all warnings -pytestmark = pytest.mark.filterwarnings("error") +pytestmark = [ + pytest.mark.filterwarnings("error"), + pytest.mark.filterwarnings( + "ignore:type_hints are none, cannot perform auto-casting" + ), + pytest.mark.filterwarnings( + "ignore:You are using `torch.load` with `weights_only=False`" + ), +] def _make_data(shape): From d2265190135b241c4d123a92a00a9c1299aae306 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 6 Aug 2024 18:20:35 -0400 Subject: [PATCH 3/4] amend --- tensordict/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index ded5a1817..af496a9cf 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6973,9 +6973,10 @@ def map_iter( ) finally: try: + pool.close() + pool.join(timeout=60) + except TimeoutError: pool.terminate() - finally: - pool.join() else: yield from self._map( fn=fn, From 7fcb2024c92a377a48a613b6b7982755bed6f0fa Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 6 Aug 2024 18:48:55 -0400 Subject: [PATCH 4/4] amend --- tensordict/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index af496a9cf..3487ebb25 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6974,8 +6974,8 @@ def map_iter( finally: try: pool.close() - pool.join(timeout=60) - except TimeoutError: + pool.join() + except Exception: pool.terminate() else: yield from self._map(