diff --git a/benchmarks/common/memmap_benchmarks_test.py b/benchmarks/common/memmap_benchmarks_test.py
index 455a01b6a..389febae6 100644
--- a/benchmarks/common/memmap_benchmarks_test.py
+++ b/benchmarks/common/memmap_benchmarks_test.py
@@ -7,7 +7,7 @@
 import pytest
 import torch
 
-from tensordict import MemmapTensor, TensorDict
+from tensordict import MemoryMappedTensor, TensorDict
 from torch import nn
 
 
@@ -25,9 +25,9 @@ def tensor():
     return torch.zeros(3, 4, 5)
 
 
-@pytest.fixture(params=get_available_devices())
+@pytest.fixture(params=[torch.device("cpu")])
 def memmap_tensor(request):
-    return MemmapTensor(3, 4, 5, device=request.param)
+    return MemoryMappedTensor.zeros((3, 4, 5))
 
 
 @pytest.fixture
@@ -37,14 +37,14 @@ def td_memmap():
     ).memmap_()
 
 
-@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("device", [torch.device("cpu")])
 def test_creation(benchmark, device):
-    benchmark(MemmapTensor, 3, 4, 5, device=device)
+    benchmark(MemoryMappedTensor.empty, (3, 4, 5))
 
 
 def test_creation_from_tensor(benchmark, tensor):
     benchmark(
-        MemmapTensor.from_tensor,
+        MemoryMappedTensor.from_tensor,
         tensor,
     )
diff --git a/benchmarks/distributed/dataloading.py b/benchmarks/distributed/dataloading.py
index dd79e23a1..24e799f14 100644
--- a/benchmarks/distributed/dataloading.py
+++ b/benchmarks/distributed/dataloading.py
@@ -28,7 +28,7 @@
 import torch
 import tqdm
 
-from tensordict import MemmapTensor
+from tensordict import MemoryMappedTensor
 from tensordict.prototype import tensorclass
 from torch import multiprocessing as mp, nn
 from torch.distributed import rpc
@@ -109,12 +109,14 @@ class ImageNetData:
     @classmethod
     def from_dataset(cls, dataset):
         data = cls(
-            images=MemmapTensor(
-                len(dataset),
-                *dataset[0][0].squeeze().shape,
+            images=MemoryMappedTensor.empty(
+                (
+                    len(dataset),
+                    *dataset[0][0].squeeze().shape,
+                ),
                 dtype=torch.uint8,
             ),
-            targets=MemmapTensor(len(dataset), dtype=torch.int64),
+            targets=MemoryMappedTensor.empty(len(dataset), dtype=torch.int64),
             batch_size=[len(dataset)],
         )
         # locks the tensorclass and ensures that is_memmap will return True.
@@ -139,12 +141,14 @@ def load(cls, dataset, path):
         import torchsnapshot
 
         data = cls(
-            images=MemmapTensor(
-                len(dataset),
-                *dataset[0][0].squeeze().shape,
+            images=MemoryMappedTensor.empty(
+                (
+                    len(dataset),
+                    *dataset[0][0].squeeze().shape,
+                ),
                 dtype=torch.uint8,
             ),
-            targets=MemmapTensor(len(dataset), dtype=torch.int64),
+            targets=MemoryMappedTensor.empty(len(dataset), dtype=torch.int64),
             batch_size=[len(dataset)],
         )
         # locks the tensorclass and ensures that is_memmap will return True.
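Note: `MemoryMappedTensor` replaces the old positional `MemmapTensor(*shape, ...)` constructor with class-method constructors that take the shape as a single argument, which is what the two benchmark files above migrate to. A minimal sketch of that pattern (shapes and dtypes here are illustrative, not taken from the benchmarks):

```python
import torch
from tensordict import MemoryMappedTensor

# Uninitialized memory-mapped storage (like torch.empty); the shape is
# passed as one tuple rather than unpacked positional sizes.
images = MemoryMappedTensor.empty((10, 3, 64, 64), dtype=torch.uint8)

# Zero-filled variant, mirroring torch.zeros.
z = MemoryMappedTensor.zeros((3, 4, 5))

# Copy an existing tensor into memory-mapped storage.
t = torch.arange(12).reshape(3, 4)
m = MemoryMappedTensor.from_tensor(t)
assert (m == t).all()
```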
diff --git a/benchmarks/distributed/distributed_benchmark_test.py b/benchmarks/distributed/distributed_benchmark_test.py
index 2b7beb11e..5e890e3d9 100644
--- a/benchmarks/distributed/distributed_benchmark_test.py
+++ b/benchmarks/distributed/distributed_benchmark_test.py
@@ -2,14 +2,15 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
 import os
+import pathlib
+import tempfile
 import time
 
 import pytest
 import torch
 
-from tensordict import MemmapTensor, TensorDict
+from tensordict import MemoryMappedTensor, TensorDict
 from torch.distributed import rpc
 
 MAIN_NODE = "Main"
@@ -44,58 +45,63 @@ def __call__(self, *args, **kwargs):
 
 
 def exec_distributed_test(rank_node):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29549"
-    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
-    str_init_method = "tcp://localhost:10001"
-    options = rpc.TensorPipeRpcBackendOptions(
-        num_worker_threads=16, init_method=str_init_method
-    )
-    rank = rank_node
-    if rank == 0:
-        rpc.init_rpc(
-            MAIN_NODE,
-            rank=rank,
-            backend=rpc.BackendType.TENSORPIPE,
-            rpc_backend_options=options,
+    with tempfile.TemporaryDirectory() as tmpdir:
+        tmpdir = pathlib.Path(tmpdir)
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "29549"
+        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+        str_init_method = "tcp://localhost:10001"
+        options = rpc.TensorPipeRpcBackendOptions(
+            num_worker_threads=16, init_method=str_init_method
         )
+        rank = rank_node
+        if rank == 0:
+            rpc.init_rpc(
+                MAIN_NODE,
+                rank=rank,
+                backend=rpc.BackendType.TENSORPIPE,
+                rpc_backend_options=options,
+            )
 
-        # create a tensordict is 1Gb big, stored on disk, assuming that both nodes have access to /tmp/
-        tensordict = TensorDict(
-            {
-                "memmap": MemmapTensor(
-                    1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/"
+            # create a tensordict that is 1Gb big, stored on disk, assuming that both nodes have access to the same filesystem
+            tensordict = TensorDict(
+                {
+                    "memmap": MemoryMappedTensor.empty(
+                        (1000, 640, 640, 3),
+                        dtype=torch.uint8,
+                        filename=tmpdir / "mmap.memmap",
+                    )
+                },
+                [1000],
+                _is_memmap=True,
+            )
+            assert tensordict.is_memmap()
+
+            while True:
+                try:
+                    worker_info = rpc.get_worker_info("worker")
+                    break
+                except RuntimeError:
+                    time.sleep(0.1)
+
+            def fill_tensordict(tensordict, idx):
+                tensordict[idx] = TensorDict(
+                    {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
                 )
-            },
-            [1000],
-        )
-        assert tensordict.is_memmap()
-
-        while True:
-            try:
-                worker_info = rpc.get_worker_info("worker")
-                break
-            except RuntimeError:
-                time.sleep(0.1)
-
-        def fill_tensordict(tensordict, idx):
-            tensordict[idx] = TensorDict(
-                {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
+                return tensordict
+
+            fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)
+            idx = [0, 1, 2, 3, 999]
+            rpc.rpc_sync(worker_info, fill_tensordict_cp, args=(tensordict, idx))
+
+            idx = [4, 5, 6, 7, 998]
+            rpc.rpc_sync(worker_info, fill_tensordict_cp, args=(tensordict, idx))
+
+            rpc.shutdown()
+        elif rank == 1:
+            rpc.init_rpc(
+                WORKER_NODE,
+                rank=rank,
+                backend=rpc.BackendType.TENSORPIPE,
+                rpc_backend_options=options,
             )
-            return tensordict
-
-        fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)
-        idx = [0, 1, 2, 3, 999]
-        rpc.rpc_sync(worker_info, fill_tensordict_cp, args=(tensordict, idx))
-
-        idx = [4, 5, 6, 7, 998]
-        rpc.rpc_sync(worker_info, fill_tensordict_cp, args=(tensordict, idx))
-
-        rpc.shutdown()
-    elif rank == 1:
-        rpc.init_rpc(
-            WORKER_NODE,
-            rank=rank,
-            backend=rpc.BackendType.TENSORPIPE,
-            rpc_backend_options=options,
-        )
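Note: with the removed `prefix="/tmp/"` argument gone, the backing file is now pinned explicitly via `filename` inside a `tempfile.TemporaryDirectory()`. Stripped of the RPC machinery, the allocation pattern reduces to roughly the sketch below (it assumes both processes can reach the same filesystem; the file name is illustrative):

```python
import pathlib
import tempfile

import torch
from tensordict import MemoryMappedTensor, TensorDict

with tempfile.TemporaryDirectory() as tmpdir:
    tmpdir = pathlib.Path(tmpdir)
    # ~1.2 GB of uint8 storage backed by an explicit file on disk.
    mmap = MemoryMappedTensor.empty(
        (1000, 640, 640, 3), dtype=torch.uint8, filename=tmpdir / "mmap.memmap"
    )
    td = TensorDict({"memmap": mmap}, [1000])
    # Writes through the TensorDict land in the memory-mapped file.
    td[0:5] = TensorDict(
        {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
    )
```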
shape.""" + """Test that MemoryMappedTensor can be created with a given data type and shape.""" t = torch.tensor([1, 0], dtype=dtype).reshape(shape) m = MemoryMappedTensor.from_tensor(t) assert m.dtype == t.dtype @@ -68,7 +68,7 @@ def test_memmap_new(index): @pytest.mark.parametrize("device", get_available_devices()) def test_memmap_same_device_as_tensor(device): """ - Created MemmapTensor should be on the same device as the input tensor. + Created MemoryMappedTensor should be on the same device as the input tensor. Check if device is correct when .to(device) is called. """ t = torch.tensor([1], device=device) @@ -78,7 +78,7 @@ def test_memmap_same_device_as_tensor(device): @pytest.mark.parametrize("device", get_available_devices()) def test_memmap_create_on_same_device(device): - """Test if the device arg for MemmapTensor init is respected.""" + """Test if the device arg for MemoryMappedTensor init is respected.""" with pytest.raises(ValueError) if device.type != "cpu" else nullcontext(): MemoryMappedTensor([3, 4], device=device) # assert m.device == torch.device(device) @@ -90,7 +90,7 @@ def test_memmap_create_on_same_device(device): @pytest.mark.parametrize("shape", [[3, 4]]) def test_memmap_zero_value(value, shape): """ - Test if all entries are zeros when MemmapTensor is created with size. + Test if all entries are zeros when MemoryMappedTensor is created with size. """ device = "cpu" value = value.to(device) diff --git a/test/test_memmap_deprec.py b/test/test_memmap_deprec.py deleted file mode 100644 index ac2f80b4a..000000000 --- a/test/test_memmap_deprec.py +++ /dev/null @@ -1,556 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -import argparse -import os.path -import pickle -import tempfile - -import numpy as np -import pytest -import torch -from _utils_internal import get_available_devices - -from tensordict import MemmapTensor -from torch import multiprocessing as mp - -TIMEOUT = 100 - - -def test_memmap_type(): - array = np.random.rand(1) - with pytest.raises( - TypeError, match="Convert input to torch.Tensor before calling MemmapTensor" - ): - MemmapTensor.from_tensor(array) - - -def test_grad(): - t = torch.tensor([1.0]) - MemmapTensor.from_tensor(t) - t = t.requires_grad_() - with pytest.raises( - RuntimeError, match="MemmapTensor is incompatible with tensor.requires_grad." - ): - MemmapTensor.from_tensor(t) - with pytest.raises( - RuntimeError, match="MemmapTensor is incompatible with tensor.requires_grad." 
diff --git a/test/test_memmap_deprec.py b/test/test_memmap_deprec.py
deleted file mode 100644
index ac2f80b4a..000000000
--- a/test/test_memmap_deprec.py
+++ /dev/null
@@ -1,556 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-import argparse
-import os.path
-import pickle
-import tempfile
-
-import numpy as np
-import pytest
-import torch
-from _utils_internal import get_available_devices
-
-from tensordict import MemmapTensor
-from torch import multiprocessing as mp
-
-TIMEOUT = 100
-
-
-def test_memmap_type():
-    array = np.random.rand(1)
-    with pytest.raises(
-        TypeError, match="Convert input to torch.Tensor before calling MemmapTensor"
-    ):
-        MemmapTensor.from_tensor(array)
-
-
-def test_grad():
-    t = torch.tensor([1.0])
-    MemmapTensor.from_tensor(t)
-    t = t.requires_grad_()
-    with pytest.raises(
-        RuntimeError, match="MemmapTensor is incompatible with tensor.requires_grad."
-    ):
-        MemmapTensor.from_tensor(t)
-    with pytest.raises(
-        RuntimeError, match="MemmapTensor is incompatible with tensor.requires_grad."
-    ):
-        MemmapTensor.from_tensor(t + 1)
-
-
-@pytest.mark.parametrize(
-    "dtype",
-    [
-        torch.half,
-        torch.float,
-        torch.double,
-        torch.int,
-        torch.uint8,
-        torch.long,
-        torch.bool,
-    ],
-)
-@pytest.mark.parametrize("shape", [[2], [1, 2]])
-def test_memmap_data_type(dtype, shape):
-    """Test that MemmapTensor can be created with a given data type and shape."""
-    t = torch.tensor([1, 0], dtype=dtype).reshape(shape)
-    m = MemmapTensor.from_tensor(t)
-    assert m.dtype == t.dtype
-    assert (m == t).all()
-    assert m.shape == t.shape
-
-    assert m.contiguous().dtype == t.dtype
-    assert (m.contiguous() == t).all()
-    assert m.contiguous().shape == t.shape
-
-    assert m.clone().dtype == t.dtype
-    assert (m.clone() == t).all()
-    assert m.clone().shape == t.shape
-
-
-def test_memmap_del():
-    t = torch.tensor([1])
-    m = MemmapTensor.from_tensor(t)
-    filename = m.filename
-    assert os.path.isfile(filename)
-    del m
-    assert not os.path.isfile(filename)
-
-
-# @pytest.mark.parametrize("transfer_ownership", [True, False])
-# def test_memmap_ownership(transfer_ownership):
-#     t = torch.tensor([1])
-#     m = MemmapTensor.from_tensor(t, transfer_ownership=transfer_ownership)
-#     assert not m.file.delete
-#     with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp:
-#         pickle.dump(m, tmp)
-#         assert m._has_ownership is not m.transfer_ownership
-#         m2 = pickle.load(open(tmp.name, "rb"))
-#         assert m2._memmap_array is None  # assert data is not actually loaded
-#         assert isinstance(m2, MemmapTensor)
-#         assert m2.filename == m.filename
-#         # assert m2.file.name == m2.filename
-#         # assert m2.file._closer.name == m2.filename
-#         assert (
-#             m._has_ownership is not m2._has_ownership
-#         )  # delete attributes must have changed
-#         # assert (
-#         #     m.file._closer.delete is not m2.file._closer.delete
-#         # )  # delete attributes must have changed
-#         del m
-#         if transfer_ownership:
-#             assert os.path.isfile(m2.filename)
-#         else:
-#             # m2 should point to a non-existing file
-#             assert not os.path.isfile(m2.filename)
-#             with pytest.raises(FileNotFoundError):
-#                 m2.contiguous()
-#
-#
-@pytest.mark.parametrize("value", [True, False])
-def test_memmap_ownership_2pass(value):
-    t = torch.tensor([1])
-    m1 = MemmapTensor.from_tensor(t, transfer_ownership=value)
-    filename = m1.filename
-    with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp2:
-        pickle.dump(m1, tmp2)
-        # after we dump m1, m1 has lost ownership and waits for m2 to pick it up
-        # if m1 is deleted and m2 is never created, the file is not cleared.
-        if value:
-            assert not m1._has_ownership
-        else:
-            assert m1._has_ownership
-
-        m2 = pickle.load(open(tmp2.name, "rb"))
-        assert m2.filename == m1.filename
-        with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp3:
-            pickle.dump(m2, tmp3)
-            m3 = pickle.load(open(tmp3.name, "rb"))
-            assert m3.filename == m1.filename
-
-    del m1, m2, m3
-    assert not os.path.isfile(filename)
-
-
-class TestMP:
-    @staticmethod
-    def getdata(data, queue):
-        queue.put(("has_ownership", data._has_ownership))
-        queue.put(("transfer_ownership", data.transfer_ownership))
-
-    @pytest.mark.parametrize("transfer_ownership", [True, False])
-    def test(self, transfer_ownership, tmp_path):
-        m = MemmapTensor(
-            3, transfer_ownership=transfer_ownership, filename=tmp_path / "tensor.mp"
-        )
-        queue = mp.Queue(1)
-        p = mp.Process(target=TestMP.getdata, args=(m, queue))
-        p.start()
-        try:
-            msg, val = queue.get()
-            assert msg == "has_ownership"
-            assert val is transfer_ownership
-            if transfer_ownership:
-                assert not m._has_ownership
-            else:
-                assert m._has_ownership
-            msg, val = queue.get()
-            assert msg == "transfer_ownership"
-            assert val is transfer_ownership
-        finally:
-            p.join()
-            queue.close()
-
-
-@pytest.mark.parametrize(
-    "index",
-    [
-        None,
-        [
-            0,
-        ],
-    ],
-)
-def test_memmap_new(index):
-    t = torch.tensor([1])
-    m = MemmapTensor.from_tensor(t)
-    if index is not None:
-        m1 = m[index]
-    else:
-        m1 = m
-    m2 = MemmapTensor.from_tensor(m1)
-    assert isinstance(m2, MemmapTensor)
-    assert m2.filename == m1.filename
-    assert m2.filename == m2.file.name
-    assert m2.filename == m2.file._closer.name
-    if index is not None:
-        assert m2.contiguous() == t[index]
-    m2c = m2.contiguous()
-    assert isinstance(m2c, torch.Tensor)
-    assert m2c == m1
-
-
-@pytest.mark.parametrize("device", get_available_devices())
-def test_memmap_same_device_as_tensor(device):
-    """
-    Created MemmapTensor should be on the same device as the input tensor.
-    Check if device is correct when .to(device) is called.
-    """
-    t = torch.tensor([1], device=device)
-    m = MemmapTensor.from_tensor(t)
-    assert t.device == torch.device(device)
-    assert m.device == torch.device(device)
-    for other_device in get_available_devices():
-        if other_device != device:
-            with pytest.raises(
-                RuntimeError,
-                match="Expected all tensors to be on the same device, "
-                + "but found at least two devices",
-            ):
-                assert torch.all(m + torch.ones([3, 4], device=other_device) == 1)
-            m = m.to(other_device)
-            assert m.device == torch.device(other_device)
-
-
-@pytest.mark.parametrize("device", get_available_devices())
-def test_memmap_create_on_same_device(device):
-    """Test if the device arg for MemmapTensor init is respected."""
-    m = MemmapTensor([3, 4], device=device)
-    assert m.device == torch.device(device)
-
-
-@pytest.mark.parametrize("device", get_available_devices())
-@pytest.mark.parametrize(
-    "value", [torch.zeros([3, 4]), MemmapTensor.from_tensor(torch.zeros([3, 4]))]
-)
-@pytest.mark.parametrize("shape", [[3, 4], [[3, 4]]])
-def test_memmap_zero_value(device, value, shape):
-    """
-    Test if all entries are zeros when MemmapTensor is created with size.
- """ - value = value.to(device) - expected_memmap_tensor = MemmapTensor.from_tensor(value) - m = MemmapTensor(*shape, device=device) - assert m.shape == (3, 4) - assert torch.all(m == expected_memmap_tensor) - assert torch.all(m + torch.ones([3, 4], device=device) == 1) - - -class TestIndexing: - @staticmethod - def _recv_and_send( - queue_out, - queue_in, - filename, - shape, - ): - t = queue_in.get(timeout=TIMEOUT) - assert isinstance(t, MemmapTensor) - assert t.filename == filename - assert t.shape == shape - assert (t == 0).all() - msg = "done" - queue_out.put(msg) - - msg = queue_in.get(timeout=TIMEOUT) - assert msg == "modified" - assert (t == 1).all() - queue_out.put("done!!") - - msg = queue_in.get(timeout=TIMEOUT) - assert msg == "deleted" - assert not os.path.isfile(filename) - with pytest.raises(FileNotFoundError, match="No such file or directory"): - t + 1 - queue_out.put("done again") - del queue_in, queue_out - - def test_simple_index(self): - t = MemmapTensor.from_tensor(torch.zeros(10)) - # int - assert isinstance(t[0], MemmapTensor) - assert t[0].filename == t.filename - assert t[0].shape == torch.Size([]) - assert t.shape == torch.Size([10]) - - def test_range_index(self): - t = MemmapTensor.from_tensor(torch.zeros(10)) - # int - assert isinstance(t[:2], MemmapTensor) - assert t[:2].filename == t.filename - assert t[:2].shape == torch.Size([2]) - assert t.shape == torch.Size([10]) - - def test_double_index(self): - t = MemmapTensor.from_tensor(torch.zeros(10)) - y = t[:2][-1:] - # int - assert isinstance(y, MemmapTensor) - assert y.filename == t.filename - assert y.shape == torch.Size([1]) - assert t.shape == torch.Size([10]) - - def test_ownership(self): - t = MemmapTensor.from_tensor(torch.zeros(10)) - filename = t.filename - y = t[:2][-1:] - del t - # this would fail if t was gone with its file - assert (y * 0 + 1 == 1).all() - del y - # check that file has gone - assert not os.path.isfile(filename) - - @pytest.mark.flaky(reruns=5, reruns_delay=5) - def test_send_across_procs(self): - t = MemmapTensor.from_tensor(torch.zeros(10), transfer_ownership=False) - queue_in = mp.Queue(1) - queue_out = mp.Queue(1) - filename = t.filename - p = mp.Process( - target=TestIndexing._recv_and_send, - args=(queue_in, queue_out, filename, torch.Size([10])), - ) - try: - p.start() - queue_out.put(t, block=True) - msg = queue_in.get(timeout=TIMEOUT) - assert msg == "done" - - t.fill_(1.0) - queue_out.put("modified", block=True) - msg = queue_in.get(timeout=TIMEOUT) - assert msg == "done!!" - - del t - queue_out.put("deleted") - msg = queue_in.get(timeout=TIMEOUT) - assert msg == "done again" - p.join() - except Exception as e: - p.join() - raise e - - @pytest.mark.flaky(reruns=5, reruns_delay=5) - def test_send_across_procs_index(self): - t = MemmapTensor.from_tensor(torch.zeros(10), transfer_ownership=False) - queue_in = mp.Queue(1) - queue_out = mp.Queue(1) - filename = t.filename - p = mp.Process( - target=TestIndexing._recv_and_send, - args=(queue_in, queue_out, filename, torch.Size([3])), - ) - try: - p.start() - queue_out.put(t[:3], block=True) - msg = queue_in.get(timeout=TIMEOUT) - assert msg == "done" - - t.fill_(1.0) - queue_out.put("modified", block=True) - msg = queue_in.get(timeout=TIMEOUT) - assert msg == "done!!" 
-
-            del t
-            queue_out.put("deleted")
-            msg = queue_in.get(timeout=TIMEOUT)
-            assert msg == "done again"
-            p.join()
-        except Exception as e:
-            p.join()
-            raise e
-
-    def test_iteration(self):
-        t = MemmapTensor.from_tensor(torch.rand(10))
-        for i, _t in enumerate(t):
-            assert _t == t[i]
-
-    def test_iteration_nd(self):
-        t = MemmapTensor.from_tensor(torch.rand(10, 5))
-        for i, _t in enumerate(t):
-            assert (_t == t[i]).all()
-
-    @staticmethod
-    def _test_copy_onto_subproc(queue):
-        t = MemmapTensor.from_tensor(torch.rand(10, 5))
-        idx = torch.tensor([1, 2])
-        queue.put(t[idx], block=True)
-        while queue.full():
-            continue
-
-        idx = torch.tensor([3, 4])
-        queue.put(t[idx], block=True)
-        while queue.full():
-            continue
-        msg = queue.get(timeout=TIMEOUT)
-        assert msg == "done"
-        del queue
-
-    def test_copy_onto(self):
-        queue = mp.Queue(1)
-        p = mp.Process(target=TestIndexing._test_copy_onto_subproc, args=(queue,))
-        p.start()
-        try:
-            t_indexed1 = queue.get(timeout=TIMEOUT)
-            assert (t_indexed1._index[0] == torch.tensor([1, 2])).all()
-            # check that file is not opened if we did not access it
-            assert t_indexed1._memmap_array is None
-            _ = t_indexed1 + 1
-            # check that file is now opened
-            assert t_indexed1._memmap_array is not None
-
-            # receive 2nd copy
-            t_indexed2 = queue.get(timeout=TIMEOUT)
-            assert t_indexed2.filename == t_indexed1.filename
-            assert (t_indexed2._index[0] == torch.tensor([3, 4])).all()
-            # check that file is open only once
-            assert t_indexed1._memmap_array is not None
-            assert t_indexed2._memmap_array is None
-            t_indexed1.copy_(t_indexed2)
-            # same assertion: after copying we should only have one file opened
-            assert t_indexed1._memmap_array is not None
-            assert t_indexed2._memmap_array is None
-            _ = t_indexed2 + 1
-            # now we should find 2 opened files
-            assert t_indexed1._memmap_array is not None
-            assert t_indexed2._memmap_array is not None
-            queue.put("done", block=True)
-            queue.close()
-            p.join()
-        except Exception as e:
-            p.join()
-            raise e
-
-
-def test_as_tensor():
-    num_samples = 300
-    rows, cols = 48, 48
-    idx = torch.randint(num_samples, (128,))
-    y = MemmapTensor(num_samples, rows, cols, dtype=torch.uint8)
-    y.copy_(y + torch.randn(num_samples, rows, cols))
-    assert isinstance(y, MemmapTensor)
-    assert isinstance(y[idx], MemmapTensor)
-    assert (y[idx] == y.as_tensor()[idx]).all()
-
-
-def test_filename(tmp_path):
-    mt = MemmapTensor(10, dtype=torch.float32, filename=tmp_path / "test.memmap")
-    assert mt.filename == str(tmp_path / "test.memmap")
-
-    mt2 = MemmapTensor.from_tensor(mt)
-    assert mt2.filename == str(tmp_path / "test.memmap")
-    assert mt2 is mt
-
-    mt3 = MemmapTensor.from_tensor(mt, filename=tmp_path / "test.memmap")
-    assert mt3.filename == str(tmp_path / "test.memmap")
-    assert mt3 is mt
-
-    mt4 = MemmapTensor.from_tensor(mt, filename=tmp_path / "test2.memmap")
-    assert mt4.filename == str(tmp_path / "test2.memmap")
-    assert mt4 is not mt
-
-    del mt
-    del mt4
-    # files should persist
-    assert (tmp_path / "test.memmap").exists()
-    assert (tmp_path / "test2.memmap").exists()
-
-
-@pytest.mark.parametrize(
-    "mode", ["r", "r+", "w+", "c", "readonly", "readwrite", "write", "copyonwrite"]
-)
-def test_mode(mode, tmp_path):
-    mt = MemmapTensor(10, dtype=torch.float32, filename=tmp_path / "test.memmap")
-    mt[:] = torch.ones(10) * 1.5
-    del mt
-
-    if mode in ("r", "readonly"):
-        with pytest.raises(ValueError, match=r"Accepted values for mode are"):
-            MemmapTensor(
-                10, dtype=torch.float32, filename=tmp_path / "test.memmap", mode=mode
-            )
-        return
-    mt = MemmapTensor(
-        10, dtype=torch.float32, filename=tmp_path / "test.memmap", mode=mode
-    )
-    if mode in ("r+", "readwrite", "c", "copyonwrite"):
-        # data in memmap persists
-        assert (mt.as_tensor() == 1.5).all()
-    elif mode in ("w+", "write"):
-        # memmap is initialized to zero
-        assert (mt.as_tensor() == 0).all()
-
-    mt[:] = torch.ones(10) * 2.5
-    assert (mt.as_tensor() == 2.5).all()
-    del mt
-
-    mt2 = MemmapTensor(10, dtype=torch.float32, filename=tmp_path / "test.memmap")
-    if mode in ("c", "copyonwrite"):
-        # tensor was only mutated in memory, not on disk
-        assert (mt2.as_tensor() == 1.5).all()
-    else:
-        assert (mt2.as_tensor() == 2.5).all()
-
-
-def test_memmap_from_memmap():
-    mt2 = MemmapTensor.from_tensor(MemmapTensor(4, 3, 2, 1))
-    assert mt2.squeeze(-1).shape == torch.Size([4, 3, 2])
-
-
-def test_memmap_cast():
-    # ensure memmap can be cast to tensor and viceversa
-    x = torch.zeros(3, 4, 5)
-    y = MemmapTensor.from_tensor(torch.ones(3, 4, 5))
-
-    x[:2] = y[:2]
-    assert (x[:2] == 1).all()
-    y[2:] = x[2:]
-    assert (y[2:] == 0).all()
-
-
-@pytest.fixture
-def dummy_memmap():
-    return MemmapTensor.from_tensor(torch.zeros(10, 11))
-
-
-@pytest.mark.parametrize("device", get_available_devices())
-class TestOps:
-    def test_eq(self, device, dummy_memmap):
-        dummy_memmap.device = device
-        assert (dummy_memmap == dummy_memmap.clone()).all()
-        assert (dummy_memmap.clone() == dummy_memmap).all()
-        if device.type == "cpu":
-            assert (dummy_memmap == dummy_memmap.as_tensor()).all()
-            assert (dummy_memmap.as_tensor() == dummy_memmap).all()
-        else:
-            assert (dummy_memmap == dummy_memmap._tensor).all()
-            assert (dummy_memmap._tensor == dummy_memmap).all()
-
-    def test_fill_(self, device, dummy_memmap):
-        memmap = dummy_memmap.to(device)
-        assert (memmap.fill_(1.0) == 1).all()
-
-    def test_copy_(self, device, dummy_memmap):
-        memmap = dummy_memmap.to(device)
-        assert (memmap.copy_(torch.ones(10, 11, device=device)) == 1).all()
-        assert (torch.ones(10, 11, device=device).copy_(memmap) == 1).all()
-
-    def test_or(self, device):
-        memmap = MemmapTensor.from_tensor(torch.ones(10, 11, dtype=torch.bool)).to(
-            device
-        )
-        assert (memmap | (~memmap)).all()
-
-    def test_ne(self, device):
-        memmap = MemmapTensor.from_tensor(torch.ones(10, 11, dtype=torch.bool)).to(
-            device
-        )
-        assert (memmap != ~memmap).all()
-
-
-if __name__ == "__main__":
-    args, unknown = argparse.ArgumentParser().parse_known_args()
-    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py
index 28797971c..09f93d33c 100644
--- a/test/test_tensorclass.py
+++ b/test/test_tensorclass.py
@@ -372,7 +372,7 @@ class MyDataNested:
 
 def test_setitem_memmap():
     # regression test PR #203
-    # We should be able to set tensors items with MemmapTensors and viceversa
+    # We should be able to set tensor items with MemoryMappedTensors and vice versa
     @tensorclass
     class MyDataMemMap1:
         x: torch.Tensor
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 4510eef7a..5c8fb40f5 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -47,7 +47,7 @@
 )
 from functorch import dim as ftdim
 
-from tensordict import LazyStackedTensorDict, make_tensordict, MemmapTensor, TensorDict
+from tensordict import LazyStackedTensorDict, make_tensordict, TensorDict
 from tensordict._lazy import _CustomOpTensorDict
 from tensordict._td import _SubTensorDict, is_tensor_collection
 from tensordict._torch_func import _stack as stack_td
@@ -1895,10 +1895,6 @@ def test_rename_key(self, td_name, device) -> None:
         assert "a" not in td.keys()
 
         z = td.get("z")
-        if isinstance(a, MemmapTensor):
-            a = a._tensor
-        if isinstance(z, MemmapTensor):
-            z = z._tensor
         torch.testing.assert_close(a, z)
 
         new_z = torch.randn_like(z)
@@ -2023,13 +2019,13 @@ def test_setitem_string(self, td_name, device):
     def test_getitem_string(self, td_name, device):
         torch.manual_seed(1)
         td = getattr(self, td_name)(device)
-        assert isinstance(td["a"], (MemmapTensor, torch.Tensor))
+        assert isinstance(td["a"], torch.Tensor)
 
     def test_getitem_nestedtuple(self, td_name, device):
         torch.manual_seed(1)
         td = getattr(self, td_name)(device)
-        assert isinstance(td[(("a",))], (MemmapTensor, torch.Tensor))
-        assert isinstance(td.get((("a",))), (MemmapTensor, torch.Tensor))
+        assert isinstance(td[(("a",))], torch.Tensor)
+        assert isinstance(td.get((("a",))), torch.Tensor)
 
     def test_setitem_nestedtuple(self, td_name, device):
         torch.manual_seed(1)
@@ -2361,7 +2357,7 @@ def test_as_tensor(self, td_name, device):
             assert (tdt == td).all()
         elif "memmap" in td_name:
             with pytest.raises(
-                RuntimeError, match="can only be called with MemmapTensors stored"
+                RuntimeError, match="can only be called with MemoryMappedTensors stored"
             ):
                 td.as_tensor()
         else:
@@ -3076,7 +3072,7 @@ def test_add_batch_dim_cache(self, td_name, device, nested):
         if td_name == "memmap_td" and device.type != "cpu":
             with pytest.raises(
                 RuntimeError,
-                match="MemmapTensor with non-cpu device are not supported in vmap ops",
+                match="MemoryMappedTensor with non-cpu device are not supported in vmap ops",
             ):
                 fun(td)
             return
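Note: the `isinstance(..., (MemmapTensor, torch.Tensor))` unions in the tests above collapse to plain `torch.Tensor` because, unlike the old class, `MemoryMappedTensor` is implemented as a `torch.Tensor` subclass. A short sketch of that assumption:

```python
import torch
from tensordict import MemoryMappedTensor

m = MemoryMappedTensor.zeros((2, 3))
# A single isinstance check now covers both plain tensors and
# memory-mapped ones, so the old two-class union is redundant.
assert isinstance(m, torch.Tensor)
```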