Fix pin_memory_fn for NamedTuple (#1086)
Summary:
Fixes #1085

Per the title. Even though I can add a test, it won't be executed in CI because we don't have a GPU CI machine yet.

I have tested the change on my local machine, though.
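
For illustration, a minimal sketch of what the one-line fix changes in the Sequence branch of `pin_memory_fn`: the rebuilt NamedTuple was previously constructed but never returned, so callers did not get the pinned copy back. `Pair` and `pin_sequence` are hypothetical names used only for this sketch, and actually pinning host memory requires a CUDA-enabled build.

from typing import NamedTuple

import torch


class Pair(NamedTuple):  # hypothetical NamedTuple, used only for this sketch
    x: torch.Tensor
    y: torch.Tensor


def pin_sequence(data):
    # Sketch of the fixed Sequence branch: pin each element, then rebuild the original type.
    pinned = [v.pin_memory() for v in data]
    try:
        # A NamedTuple constructor accepts its fields positionally, so this succeeds;
        # before the fix the rebuilt tuple was constructed but never returned.
        return type(data)(*pinned)
    except TypeError:
        # Sequence types that cannot be rebuilt from positional elements fall back to a list.
        return pinned


if __name__ == "__main__" and torch.cuda.is_available():  # pinning requires a CUDA build
    out = pin_sequence(Pair(torch.tensor(1), torch.tensor(2)))
    print(type(out).__name__, out.x.is_pinned(), out.y.is_pinned())  # expected: Pair True True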

Pull Request resolved: #1086

Reviewed By: NivekT

Differential Revision: D44094225

Pulled By: ejguan

fbshipit-source-id: 9c8414c31b76c93cee7e31c4e2da14076e9792bf
ejguan authored and NivekT committed Apr 19, 2023
1 parent 4dee790 commit d8250dc
Showing 2 changed files with 22 additions and 2 deletions.
test/test_iterdatapipe.py (21 additions, 1 deletion)
@@ -4,14 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import asyncio
 import io
 import itertools
 import pickle
 import unittest
 import warnings
 
 from collections import defaultdict
-from typing import Dict
+from typing import Dict, NamedTuple
 
 import expecttest
 import torch
@@ -78,6 +79,21 @@ def _convert_to_tensor(data):
     return torch.tensor(data)
 
 
+async def _async_mul_ten(x):
+    await asyncio.sleep(0.1)
+    return x * 10
+
+
+async def _async_x_mul_y(x, y):
+    await asyncio.sleep(0.1)
+    return x * y
+
+
+class NamedTensors(NamedTuple):
+    x: torch.Tensor
+    y: torch.Tensor
+
+
 class TestIterDataPipe(expecttest.TestCase):
     def test_in_memory_cache_holder_iterdatapipe(self) -> None:
         source_dp = IterableWrapper(range(10))
@@ -1508,6 +1524,10 @@ def test_pin_memory(self):
         dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).pin_memory()
         self.assertTrue(all(v.is_pinned() for d in dp for v in d.values()))
 
+        # NamedTuple
+        dp = IterableWrapper([NamedTensors(torch.tensor(i), torch.tensor(i + 1)) for i in range(10)]).pin_memory()
+        self.assertTrue(all(v.is_pinned() for d in dp for v in d))
+
         # Dict of List of Tensors
         dp = (
             IterableWrapper([{str(i): [(i - 1, i), (i, i + 1)]} for i in range(10)])
torchdata/datapipes/utils/pin_memory.py (1 addition, 1 deletion)
@@ -27,7 +27,7 @@ def pin_memory_fn(data, device=None):
     elif isinstance(data, collections.abc.Sequence):
         pinned_data = [pin_memory_fn(sample, device) for sample in data]  # type: ignore[assignment]
         try:
-            type(data)(*pinned_data)
+            return type(data)(*pinned_data)
         except TypeError:
             # The sequence type may not support `__init__(iterable)` (e.g., `range`).
             return pinned_data
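
As a usage sketch (not part of the change), the fixed code path can be exercised locally the same way the new test does; this assumes a machine with CUDA available, which is exactly what the CI currently lacks:

from typing import NamedTuple

import torch
from torchdata.datapipes.iter import IterableWrapper


class NamedTensors(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor


if torch.cuda.is_available():  # pinning host memory requires a CUDA build
    dp = IterableWrapper(
        [NamedTensors(torch.tensor(i), torch.tensor(i + 1)) for i in range(10)]
    ).pin_memory()
    items = list(dp)
    # With the fix, each item should come back as the original NamedTuple type...
    assert all(isinstance(item, NamedTensors) for item in items)
    # ...and every field should be pinned.
    assert all(t.is_pinned() for item in items for t in item)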
