diff --git a/benchmarks/common/common_ops_test.py b/benchmarks/common/common_ops_test.py index 15b04b4f4..b7c4d6e8e 100644 --- a/benchmarks/common/common_ops_test.py +++ b/benchmarks/common/common_ops_test.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import argparse import pytest diff --git a/benchmarks/common/memmap_benchmarks_test.py b/benchmarks/common/memmap_benchmarks_test.py index 389febae6..5f5e2849c 100644 --- a/benchmarks/common/memmap_benchmarks_test.py +++ b/benchmarks/common/memmap_benchmarks_test.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import argparse import pathlib import time diff --git a/benchmarks/common/pytree_benchmarks_test.py b/benchmarks/common/pytree_benchmarks_test.py index 88bd931e0..0b2900c0f 100644 --- a/benchmarks/common/pytree_benchmarks_test.py +++ b/benchmarks/common/pytree_benchmarks_test.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import pytest import torch diff --git a/benchmarks/nn/functional_benchmarks_test.py b/benchmarks/nn/functional_benchmarks_test.py index d9784ed90..96263ff61 100644 --- a/benchmarks/nn/functional_benchmarks_test.py +++ b/benchmarks/nn/functional_benchmarks_test.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + # we use deepcopy as our implementation modifies the modules in-place import argparse from copy import deepcopy diff --git a/benchmarks/tensorclass/test_tensorclass_speed.py b/benchmarks/tensorclass/test_tensorclass_speed.py new file mode 100644 index 000000000..0c8bf8c7d --- /dev/null +++ b/benchmarks/tensorclass/test_tensorclass_speed.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse + +import pytest +import torch + +from tensordict.prototype import tensorclass + + +@tensorclass +class MyData: + a: torch.Tensor + b: torch.Tensor + c: str + d: "MyData" = None + + +def test_tc_init(benchmark): + z = torch.zeros(()) + o = torch.ones(()) + benchmark(lambda: MyData(a=z, b=o, c="a string", d=None)) + + +def test_tc_init_nested(benchmark): + z = torch.zeros(()) + o = torch.ones(()) + benchmark( + lambda: MyData(a=z, b=o, c="a string", d=MyData(a=z, b=o, c="a string", d=None)) + ) + + +def test_tc_first_layer_tensor(benchmark): + d = MyData(a=0, b=1, c="a string", d=MyData(None, None, None)) + benchmark(lambda: d.a) + + +def test_tc_first_layer_nontensor(benchmark): + d = MyData(a=0, b=1, c="a string", d=MyData(None, None, None)) + benchmark(lambda: d.c) + + +def test_tc_second_layer_tensor(benchmark): + d = MyData(a=0, b=1, c="a string", d=MyData(torch.zeros(()), None, None)) + benchmark(lambda: d.d.a) + + +def test_tc_second_layer_nontensor(benchmark): + d = MyData(a=0, b=1, c="a string", d=MyData(torch.zeros(()), None, "a string")) + benchmark(lambda: d.d.c) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/benchmarks/tensorclass/torch_functions.py b/benchmarks/tensorclass/test_torch_functions.py similarity index 88% rename from benchmarks/tensorclass/torch_functions.py rename to benchmarks/tensorclass/test_torch_functions.py index 4e3977c10..0dc3c560c 100644 --- a/benchmarks/tensorclass/torch_functions.py +++ b/benchmarks/tensorclass/test_torch_functions.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import pytest import torch