[Quality] non_blocking_pin instead of pin_memory #915

Merged Jul 24, 2024 · 2 commits
12 changes: 6 additions & 6 deletions tensordict/base.py
@@ -9233,15 +9233,15 @@ def to(self, *args, **kwargs) -> T:
                 a dtype, the dtype is gathered from the example leaves.
                 If there are more than one dtype, then no dtype
                 casting is undertook.
-            pin_memory (bool, optional): if ``True``, the tensors are pinned before
+            non_blocking_pin (bool, optional): if ``True``, the tensors are pinned before
                 being sent to device. This will be done asynchronously but can be
                 controlled via the ``num_threads`` argument.
 
                 .. note:: Calling ``tensordict.pin_memory().to("cuda")`` will usually
-                    be much slower than ``tensordict.to("cuda", pin_memory=True)`` as
+                    be much slower than ``tensordict.to("cuda", non_blocking_pin=True)`` as
                     the pin_memory is called asynchronously in the second case.
 
-            num_threads (int or None, optional): if ``pin_memory=True``, the number
+            num_threads (int or None, optional): if ``non_blocking_pin=True``, the number
                 of threads to be used for ``pin_memory``. By default, multithreading
                 will be used with ``num_threads=None`` in
                 :meth:`~concurrent.futures.ThreadPoolExecutor(max_workers=None)`, which will
@@ -9269,7 +9269,7 @@ def to(self, *args, **kwargs) -> T:
             _,
             convert_to_format,
             batch_size,
-            pin_memory,
+            non_blocking_pin,
             num_threads,
         ) = _parse_to(*args, **kwargs)
         result = self
@@ -9302,7 +9302,7 @@ def to(tensor):
 
         apply_kwargs = {}
         if device is not None or dtype is not None:
-            if pin_memory and num_threads != 0:
+            if non_blocking_pin and num_threads != 0:
                 result = self._multithread_apply_nest(
                     lambda x: x.pin_memory(),
                     num_threads=num_threads,
@@ -9311,7 +9311,7 @@ def to(tensor):
                     checked=True,
                 )
             else:
-                if pin_memory:
+                if non_blocking_pin:
                     result = result.pin_memory()
             apply_kwargs["device"] = device if device is not None else self.device
             apply_kwargs["batch_size"] = batch_size
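For reference, a minimal usage sketch of the renamed keyword as documented above (not part of the diff; it assumes a CUDA device is available, and the shapes and key names are illustrative):

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"obs": torch.randn(128, 3, 84, 84), "reward": torch.randn(128, 1)},
        batch_size=[128],
    )

    if torch.cuda.is_available():
        # Pin the leaves asynchronously across worker threads, then copy to CUDA;
        # per the docstring above this is usually faster than td.pin_memory().to("cuda").
        td_cuda = td.to("cuda", non_blocking_pin=True, num_threads=4)
        # num_threads=0 falls back to a plain, single-threaded pin_memory() call.
        td_cuda_serial = td.to("cuda", non_blocking_pin=True, num_threads=0)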
6 changes: 3 additions & 3 deletions tensordict/persistent.py
@@ -973,12 +973,12 @@ def to(self, *args, **kwargs: Any) -> PersistentTensorDict:
             non_blocking,
             convert_to_format,
             batch_size,
-            pin_memory,
+            non_blocking_pin,
             num_threads,
         ) = _parse_to(*args, **kwargs)
-        if pin_memory:
+        if non_blocking_pin:
             raise RuntimeError(
-                f"Cannot call pin_memory {type(self).__name__}.to(). Call "
+                f"Cannot use non_blocking_pin=True {type(self).__name__}.to(). Call "
                 f"`to_tensordict()` before executing this code."
             )
         result = self
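A hedged sketch of the guard above and its suggested workaround (not part of the diff). A PersistentTensorDict is HDF5-backed, so its leaves cannot be pinned in place; the file name below is hypothetical and a CUDA device is assumed:

    from tensordict.persistent import PersistentTensorDict

    data = PersistentTensorDict.from_h5("data.h5")  # hypothetical HDF5 file

    try:
        data.to("cuda", non_blocking_pin=True)  # raises RuntimeError per the guard above
    except RuntimeError:
        # Materialize the data in memory first, then pin and transfer.
        data_cuda = data.to_tensordict().to("cuda", non_blocking_pin=True)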
43 changes: 2 additions & 41 deletions tensordict/utils.py
@@ -1214,45 +1214,6 @@ def new_func(self, *args, **kwargs):
     return new_func
 
 
-# class as_decorator:
-#     """Converts a method to a decorator.
-#
-#     Examples:
-#         >>> from tensordict import TensorDict
-#         >>> data = TensorDict({}, [])
-#         >>> with data.lock_():  # lock_ is decorated
-#         ...     assert data.is_locked
-#         >>> assert not data.is_locked
-#     """
-#
-#     def __init__(self, attr=None):
-#         self.attr = attr
-#
-#     def __call__(self, func):
-#         if self.attr is not None:
-#
-#             @wraps(func)
-#             def new_func(_self, *args, **kwargs):
-#                 _attr_pre = getattr(_self, self.attr)
-#                 out = func(_self, *args, **kwargs)
-#                 _attr_post = getattr(_self, self.attr)
-#                 if out is not None:
-#                     if _attr_post is not _attr_pre:
-#                         out._last_op = (new_func.__name__, (args, kwargs, _self))
-#                     else:
-#                         out._last_op = None
-#                 return out
-#
-#         else:
-#
-#             @wraps(func)
-#             def new_func(_self, *args, **kwargs):
-#                 out = func(_self, *args, **kwargs)
-#                 if out is not None:
-#                     out._last_op = (new_func.__name__, (args, kwargs, _self))
-#                 return out
-#
-#         return new_func
 def _as_context_manager(attr=None):
     """Converts a method to a decorator.
 
@@ -1401,7 +1362,7 @@ def _split_generator():
 
 def _parse_to(*args, **kwargs):
     batch_size = kwargs.pop("batch_size", None)
-    pin_memory = kwargs.pop("pin_memory", False)
+    non_blocking_pin = kwargs.pop("non_blocking_pin", False)
     num_threads = kwargs.pop("num_threads", None)
     other = kwargs.pop("other", None)
     if not torch.compiler.is_dynamo_compiling():
@@ -1440,7 +1401,7 @@ def _parse_to(*args, **kwargs):
         non_blocking,
         convert_to_format,
         batch_size,
-        pin_memory,
+        non_blocking_pin,
         num_threads,
     )
 
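For readers unfamiliar with _parse_to, a simplified, hypothetical stand-in (not the library's actual implementation) showing the kwargs-popping pattern above; the real helper also handles other, convert_to_format and dynamo compilation:

    import torch

    def parse_to_sketch(*args, **kwargs):
        # Peel off tensordict-specific options before the Tensor.to-style arguments.
        batch_size = kwargs.pop("batch_size", None)
        non_blocking_pin = kwargs.pop("non_blocking_pin", False)  # renamed from pin_memory
        num_threads = kwargs.pop("num_threads", None)
        non_blocking = kwargs.pop("non_blocking", False)
        device = dtype = None
        for arg in args:
            if isinstance(arg, torch.dtype):
                dtype = arg
            elif arg is not None:
                device = torch.device(arg)
        return device, dtype, non_blocking, batch_size, non_blocking_pin, num_threads

    # Example: parse_to_sketch("cuda", torch.float16, non_blocking_pin=True, num_threads=4)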
31 changes: 20 additions & 11 deletions test/test_tensordict.py
@@ -3035,30 +3035,39 @@ def test_cache(self, td_name, device, op):
     )
     @pytest.mark.parametrize("device_cast", get_available_devices())
     @pytest.mark.parametrize(
-        "pin_memory", [False] if not torch.cuda.is_available() else [False, True]
+        "non_blocking_pin", [False] if not torch.cuda.is_available() else [False, True]
     )
     @pytest.mark.parametrize("num_threads", [0, 1, 4, None])
-    def test_cast_device(self, td_name, device, device_cast, pin_memory, num_threads):
+    def test_cast_device(
+        self, td_name, device, device_cast, non_blocking_pin, num_threads
+    ):
         torch.manual_seed(1)
         td = getattr(self, td_name)(device)
-        if pin_memory and td_name == "td_h5":
+        if non_blocking_pin and td_name == "td_h5":
             with pytest.raises(
-                RuntimeError, match="Cannot call pin_memory PersistentTensorDict.to()"
+                RuntimeError,
+                match="Cannot use non_blocking_pin=True PersistentTensorDict.to()",
             ):
                 td_device = td.to(
-                    device_cast, pin_memory=pin_memory, num_threads=num_threads
+                    device_cast,
+                    non_blocking_pin=non_blocking_pin,
+                    num_threads=num_threads,
                 )
             return
 
-        if device.type == "cuda" and device_cast.type == "cpu" and pin_memory:
+        if device.type == "cuda" and device_cast.type == "cpu" and non_blocking_pin:
             with pytest.raises(
                 RuntimeError, match="only dense CPU tensors can be pinned"
             ):
                 td_device = td.to(
-                    device_cast, pin_memory=pin_memory, num_threads=num_threads
+                    device_cast,
+                    non_blocking_pin=non_blocking_pin,
+                    num_threads=num_threads,
                 )
             return
-        td_device = td.to(device_cast, pin_memory=pin_memory, num_threads=num_threads)
+        td_device = td.to(
+            device_cast, non_blocking_pin=non_blocking_pin, num_threads=num_threads
+        )
 
         for item in td_device.values():
             assert item.device == device_cast
@@ -8606,16 +8615,16 @@ def test_subtd(self):
 
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize(
-        "pin_memory", [False] if not torch.cuda.is_available() else [False, True]
+        "non_blocking_pin", [False] if not torch.cuda.is_available() else [False, True]
    )
     @pytest.mark.parametrize("num_threads", [0, 1, 4, None])
-    def test_to(self, device, pin_memory, num_threads):
+    def test_to(self, device, non_blocking_pin, num_threads):
         td = TensorDict(
             {"": TensorDict({}, [3, 4, 1, 6])},
             batch_size=[3, 4, 1, 6],
             names=["a", "b", "c", "d"],
         )
-        tdt = td.to(device, pin_memory=pin_memory, num_threads=num_threads)
+        tdt = td.to(device, non_blocking_pin=non_blocking_pin, num_threads=num_threads)
         assert tdt.names == ["a", "b", "c", "d"]
 
     def test_unbind(self):
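For completeness (not part of the diff), the behavior checked by test_to above can be reproduced standalone; a minimal sketch that only needs a CPU device:

    from tensordict import TensorDict

    td = TensorDict(
        {"": TensorDict({}, [3, 4, 1, 6])},
        batch_size=[3, 4, 1, 6],
        names=["a", "b", "c", "d"],
    )
    # The renamed keyword is accepted with the default False and dimension names survive .to().
    tdt = td.to("cpu", non_blocking_pin=False, num_threads=None)
    assert tdt.names == ["a", "b", "c", "d"]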