From ad3474e653a3860f95afbf52740b19af0e500275 Mon Sep 17 00:00:00 2001
From: Erjia Guan
Date: Mon, 27 Jun 2022 09:17:32 -0700
Subject: [PATCH] Raise warning for unpickable local function (#80232)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/80232

Pull Request resolved: https://github.com/pytorch/data/pull/547

Fixes https://github.com/pytorch/data/issues/538

- Improve the validation function to raise a warning about an unpickable function when either a lambda or a local function is provided to a DataPipe.
- The inner function of a `functools.partial` object is extracted for validation as well.
- Mimic the behavior of the pickle module for a local lambda function: pickle only raises an Error about the local function, not about the lambda, so the warning is raised about the local function rather than the lambda function.

```py
>>> import pickle
>>> def fn():
...     lf = lambda x: x
...     pickle.dumps(lf)
>>> fn()
AttributeError: Can't pickle local object 'fn.<locals>.<lambda>'
```

A minimal illustrative sketch of this kind of validation is included after the diff below.

This diff also fixes the error introduced by https://github.com/pytorch/pytorch/pull/79344

Reviewed By: NivekT

Differential Revision: D37417556

fbshipit-source-id: 6babb73a7daad470ec93c3bd0a0c08d71849d3c8
---
 test/test_serialization.py                           | 12 +++++-------
 torchdata/datapipes/iter/transform/callable.py       |  6 +++---
 torchdata/datapipes/iter/util/cacheholder.py          |  5 +++--
 torchdata/datapipes/iter/util/combining.py            | 12 ++++++------
 torchdata/datapipes/iter/util/converter.py            |  5 +++--
 torchdata/datapipes/iter/util/paragraphaggregator.py  |  4 ++--
 6 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/test/test_serialization.py b/test/test_serialization.py
index f867c1acf..4138f54e4 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -374,15 +374,13 @@ def test_serializable_with_dill(self):
             else:
                 dp_no_attribute_error = (iterdp.OnDiskCacheHolder,)
                 try:
-                    with warnings.catch_warnings(record=True) as wa:
+                    with self.assertWarnsRegex(UserWarning, r"^Local function is not supported by pickle"):
                         datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
-                        self.assertEqual(len(wa), 1)
-                        self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
-                        if isinstance(datapipe, dp_no_attribute_error):
+                    if isinstance(datapipe, dp_no_attribute_error):
+                        _ = pickle.dumps(datapipe)
+                    else:
+                        with self.assertRaises(AttributeError):
                             _ = pickle.dumps(datapipe)
-                        else:
-                            with self.assertRaises(AttributeError):
-                                _ = pickle.dumps(datapipe)
                 except Exception as e:
                     print(f"{dpipe} is failing.")
                     raise e
diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py
index c1187c451..3762cd5a1 100644
--- a/torchdata/datapipes/iter/transform/callable.py
+++ b/torchdata/datapipes/iter/transform/callable.py
@@ -7,7 +7,7 @@
 from typing import Callable, Iterator, List, TypeVar
 
 from torch.utils.data import functional_datapipe, IterDataPipe
-from torch.utils.data.datapipes.utils.common import _check_lambda_fn
+from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
 T_co = TypeVar("T_co", covariant=True)
 
@@ -59,7 +59,7 @@ def __init__(
     ) -> None:
         self.datapipe = datapipe
 
-        _check_lambda_fn(fn)
+        _check_unpickable_fn(fn)
         self.fn = fn  # type: ignore[assignment]
 
         assert batch_size > 0, "Batch size is required to be larger than 0!"
@@ -118,7 +118,7 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
     def __init__(self, datapipe: IterDataPipe, fn: Callable, input_col=None) -> None:
         self.datapipe = datapipe
 
-        _check_lambda_fn(fn)
+        _check_unpickable_fn(fn)
         self.fn = fn  # type: ignore[assignment]
         self.input_col = input_col
 
diff --git a/torchdata/datapipes/iter/util/cacheholder.py b/torchdata/datapipes/iter/util/cacheholder.py
index ce78e70b2..b3723a368 100644
--- a/torchdata/datapipes/iter/util/cacheholder.py
+++ b/torchdata/datapipes/iter/util/cacheholder.py
@@ -27,7 +27,7 @@
     raise
 
 
-from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE
+from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE
 from torch.utils.data.graph import traverse
 
 from torchdata.datapipes import functional_datapipe
@@ -184,7 +184,8 @@ def __init__(
     ):
         self.source_datapipe = source_datapipe
 
-        _check_lambda_fn(filepath_fn)
+        if filepath_fn is not None:
+            _check_unpickable_fn(filepath_fn)
         filepath_fn = _generator_to_list(filepath_fn) if inspect.isgeneratorfunction(filepath_fn) else filepath_fn
 
         if hash_dict is not None and hash_type not in ("sha256", "md5"):
diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py
index 7c094bc57..a7148d26f 100644
--- a/torchdata/datapipes/iter/util/combining.py
+++ b/torchdata/datapipes/iter/util/combining.py
@@ -9,7 +9,7 @@
 from typing import Callable, Iterator, Optional, TypeVar
 
 from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
-from torch.utils.data.datapipes.utils.common import _check_lambda_fn
+from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
 T_co = TypeVar("T_co", covariant=True)
 
@@ -64,14 +64,14 @@ def __init__(
             raise TypeError(f"ref_datapipe must be a IterDataPipe, but its type is {type(ref_datapipe)} instead.")
         self.source_datapipe = source_datapipe
         self.ref_datapipe = ref_datapipe
-        _check_lambda_fn(key_fn)
+        _check_unpickable_fn(key_fn)
         self.key_fn = key_fn
         if ref_key_fn is not None:
-            _check_lambda_fn(ref_key_fn)
+            _check_unpickable_fn(ref_key_fn)
         self.ref_key_fn = key_fn if ref_key_fn is None else ref_key_fn
         self.keep_key = keep_key
         if merge_fn is not None:
-            _check_lambda_fn(merge_fn)
+            _check_unpickable_fn(merge_fn)
         self.merge_fn = merge_fn
         if buffer_size is not None and buffer_size <= 0:
             raise ValueError("'buffer_size' is required to be either None or a positive integer.")
@@ -185,10 +185,10 @@ def __init__(
             raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.")
         self.source_iterdatapipe: IterDataPipe = source_iterdatapipe
         self.map_datapipe: MapDataPipe = map_datapipe
-        _check_lambda_fn(key_fn)
+        _check_unpickable_fn(key_fn)
         self.key_fn: Callable = key_fn
         if merge_fn is not None:
-            _check_lambda_fn(merge_fn)
+            _check_unpickable_fn(merge_fn)
         self.merge_fn: Optional[Callable] = merge_fn
         self.length: int = -1
 
diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py
index 7c6d62751..e4adaafdb 100644
--- a/torchdata/datapipes/iter/util/converter.py
+++ b/torchdata/datapipes/iter/util/converter.py
@@ -9,7 +9,7 @@
 from typing import Callable, Dict, Optional
 
 from torch.utils.data import IterDataPipe, MapDataPipe
-from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE
+from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE
 
 if DILL_AVAILABLE:
     import dill
@@ -52,7 +52,8 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
         if not isinstance(datapipe, IterDataPipe):
             raise TypeError(f"IterToMapConverter can only apply on IterDataPipe, but found {type(datapipe)}")
         self.datapipe = datapipe
-        _check_lambda_fn(key_value_fn)
+        if key_value_fn is not None:
+            _check_unpickable_fn(key_value_fn)
         self.key_value_fn = key_value_fn  # type: ignore[assignment]
         self._map = None
         self._length = -1
diff --git a/torchdata/datapipes/iter/util/paragraphaggregator.py b/torchdata/datapipes/iter/util/paragraphaggregator.py
index 696ba33fa..be7c21daf 100644
--- a/torchdata/datapipes/iter/util/paragraphaggregator.py
+++ b/torchdata/datapipes/iter/util/paragraphaggregator.py
@@ -6,7 +6,7 @@
 
 from typing import Callable, Iterator, List, Tuple, TypeVar
 
-from torch.utils.data.datapipes.utils.common import _check_lambda_fn
+from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
 from torchdata.datapipes import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe
@@ -44,7 +44,7 @@ class ParagraphAggregatorIterDataPipe(IterDataPipe[Tuple[str, str]]):
 
     def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Callable = _default_line_join) -> None:
         self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe
-        _check_lambda_fn(joiner)
+        _check_unpickable_fn(joiner)
         self.joiner: Callable = joiner
         self.buffer: List = []
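
Note on the validation referenced in the summary: the patch only shows the call sites of `_check_unpickable_fn`, so below is a minimal, hypothetical sketch of the kind of check it describes. It is not the actual implementation in `torch.utils.data.datapipes.utils.common`; the name `check_unpickable_fn` and the exact warning text are assumptions, chosen so the warning prefix matches the regex asserted in `test/test_serialization.py`.

```py
import functools
import warnings


def check_unpickable_fn(fn):
    """Hypothetical sketch: warn when a callable cannot be pickled by reference."""
    if not callable(fn):
        raise TypeError(f"A callable is expected, but {type(fn)} is provided.")

    # Extract the inner function from a functools.partial object so it is validated as well.
    if isinstance(fn, functools.partial):
        fn = fn.func

    # A function defined inside another function (including a local lambda) carries
    # "<locals>" in its __qualname__ and cannot be pickled by reference, so warn about
    # the local function, mirroring pickle's own "Can't pickle local object ..." error.
    if "<locals>" in getattr(fn, "__qualname__", ""):
        warnings.warn("Local function is not supported by pickle.", UserWarning)


def make_joiner(sep):
    return lambda lines: sep.join(lines)  # local lambda -> "<locals>" in __qualname__


check_unpickable_fn(make_joiner(" "))  # UserWarning: Local function is not supported by pickle.
check_unpickable_fn(functools.partial(make_joiner(" "), ["a", "b"]))  # inner lambda is unwrapped and warned about
```

Emitting a warning instead of raising keeps such DataPipes usable as long as they are never serialized; the updated test still expects `pickle.dumps` on them to raise `AttributeError` afterwards.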