Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Remove remaining MemmapTensor references #617

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions benchmarks/common/memmap_benchmarks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import torch

from tensordict import MemmapTensor, TensorDict
from tensordict import MemoryMappedTensor, TensorDict
from torch import nn


Expand All @@ -25,9 +25,9 @@ def tensor():
return torch.zeros(3, 4, 5)


@pytest.fixture(params=get_available_devices())
@pytest.fixture(params=[torch.device("cpu")])
def memmap_tensor(request):
return MemmapTensor(3, 4, 5, device=request.param)
return MemoryMappedTensor.zeros((3, 4, 5))


@pytest.fixture
Expand All @@ -37,14 +37,14 @@ def td_memmap():
).memmap_()


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("device", [torch.device("cpu")])
def test_creation(benchmark, device):
benchmark(MemmapTensor, 3, 4, 5, device=device)
benchmark(MemoryMappedTensor.empty, (3, 4, 5))


def test_creation_from_tensor(benchmark, tensor):
benchmark(
MemmapTensor.from_tensor,
MemoryMappedTensor.from_tensor,
tensor,
)

Expand Down
22 changes: 13 additions & 9 deletions benchmarks/distributed/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import torch
import tqdm

from tensordict import MemmapTensor
from tensordict import MemoryMappedTensor
from tensordict.prototype import tensorclass
from torch import multiprocessing as mp, nn
from torch.distributed import rpc
Expand Down Expand Up @@ -109,12 +109,14 @@ class ImageNetData:
@classmethod
def from_dataset(cls, dataset):
data = cls(
images=MemmapTensor(
len(dataset),
*dataset[0][0].squeeze().shape,
images=MemoryMappedTensor.empty(
(
len(dataset),
*dataset[0][0].squeeze().shape,
),
dtype=torch.uint8,
),
targets=MemmapTensor(len(dataset), dtype=torch.int64),
targets=MemoryMappedTensor.empty(len(dataset), dtype=torch.int64),
batch_size=[len(dataset)],
)
# locks the tensorclass and ensures that is_memmap will return True.
Expand All @@ -139,12 +141,14 @@ def load(cls, dataset, path):
import torchsnapshot

data = cls(
images=MemmapTensor(
len(dataset),
*dataset[0][0].squeeze().shape,
images=MemoryMappedTensor.empty(
(
len(dataset),
*dataset[0][0].squeeze().shape,
),
dtype=torch.uint8,
),
targets=MemmapTensor(len(dataset), dtype=torch.int64),
targets=MemoryMappedTensor(len(dataset), dtype=torch.int64),
batch_size=[len(dataset)],
)
# locks the tensorclass and ensures that is_memmap will return True.
Expand Down
112 changes: 59 additions & 53 deletions benchmarks/distributed/distributed_benchmark_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import pathlib
import tempfile
import time

import pytest
import torch

from tensordict import MemmapTensor, TensorDict
from tensordict import MemoryMappedTensor, TensorDict
from torch.distributed import rpc

MAIN_NODE = "Main"
Expand Down Expand Up @@ -44,58 +45,63 @@ def __call__(self, *args, **kwargs):


def exec_distributed_test(rank_node):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29549"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
str_init_method = "tcp://localhost:10001"
options = rpc.TensorPipeRpcBackendOptions(
num_worker_threads=16, init_method=str_init_method
)
rank = rank_node
if rank == 0:
rpc.init_rpc(
MAIN_NODE,
rank=rank,
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=options,
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29549"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
str_init_method = "tcp://localhost:10001"
options = rpc.TensorPipeRpcBackendOptions(
num_worker_threads=16, init_method=str_init_method
)
rank = rank_node
if rank == 0:
rpc.init_rpc(
MAIN_NODE,
rank=rank,
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=options,
)

# create a tensordict that is 1Gb big, stored on disk, assuming that both nodes have access to /tmp/
tensordict = TensorDict(
{
"memmap": MemmapTensor(
1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/"
# create a tensordict that is 1Gb big, stored on disk in the temporary directory, assuming that both nodes have access to it
tensordict = TensorDict(
{
"memmap": MemoryMappedTensor.empty(
(1000, 640, 640, 3),
dtype=torch.uint8,
filename=tmpdir / "mmap.memmap",
)
},
[1000],
_is_memmap=True,
)
assert tensordict.is_memmap()

while True:
try:
worker_info = rpc.get_worker_info("worker")
break
except RuntimeError:
time.sleep(0.1)

def fill_tensordict(tensordict, idx):
tensordict[idx] = TensorDict(
{"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
)
},
[1000],
)
assert tensordict.is_memmap()

while True:
try:
worker_info = rpc.get_worker_info("worker")
break
except RuntimeError:
time.sleep(0.1)

def fill_tensordict(tensordict, idx):
tensordict[idx] = TensorDict(
{"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
return tensordict

fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)
idx = [0, 1, 2, 3, 999]
rpc.rpc_sync(worker_info, fill_tensordict_cp, args=(tensordict, idx))

idx = [4, 5, 6, 7, 998]
rpc.rpc_sync(worker_info, fill_tensordict_cp, args=(tensordict, idx))

rpc.shutdown()
elif rank == 1:
rpc.init_rpc(
WORKER_NODE,
rank=rank,
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=options,
)
return tensordict

fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)
idx = [0, 1, 2, 3, 999]
rpc.rpc_sync(worker_info, fill_tensordict_cp, args=(tensordict, idx))

idx = [4, 5, 6, 7, 998]
rpc.rpc_sync(worker_info, fill_tensordict_cp, args=(tensordict, idx))

rpc.shutdown()
elif rank == 1:
rpc.init_rpc(
WORKER_NODE,
rank=rank,
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=options,
)
8 changes: 4 additions & 4 deletions test/test_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
@pytest.mark.parametrize("shape", [[2], [1, 2]])
def test_memmap_data_type(dtype, shape):
"""Test that MemmapTensor can be created with a given data type and shape."""
"""Test that MemoryMappedTensor can be created with a given data type and shape."""
t = torch.tensor([1, 0], dtype=dtype).reshape(shape)
m = MemoryMappedTensor.from_tensor(t)
assert m.dtype == t.dtype
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_memmap_new(index):
@pytest.mark.parametrize("device", get_available_devices())
def test_memmap_same_device_as_tensor(device):
"""
Created MemmapTensor should be on the same device as the input tensor.
Created MemoryMappedTensor should be on the same device as the input tensor.
Check if device is correct when .to(device) is called.
"""
t = torch.tensor([1], device=device)
Expand All @@ -78,7 +78,7 @@ def test_memmap_same_device_as_tensor(device):

@pytest.mark.parametrize("device", get_available_devices())
def test_memmap_create_on_same_device(device):
"""Test if the device arg for MemmapTensor init is respected."""
"""Test if the device arg for MemoryMappedTensor init is respected."""
with pytest.raises(ValueError) if device.type != "cpu" else nullcontext():
MemoryMappedTensor([3, 4], device=device)
# assert m.device == torch.device(device)
Expand All @@ -90,7 +90,7 @@ def test_memmap_create_on_same_device(device):
@pytest.mark.parametrize("shape", [[3, 4]])
def test_memmap_zero_value(value, shape):
"""
Test if all entries are zeros when MemmapTensor is created with size.
Test if all entries are zeros when MemoryMappedTensor is created with size.
"""
device = "cpu"
value = value.to(device)
Expand Down
Loading
Loading