Raise warning for unpicklable local function (pytorch#80232)
Summary:
Pull Request resolved: pytorch#80232

X-link: pytorch/data#547

Fixes pytorch/data#538
- Improve the validation function to raise a warning about an unpicklable function when either a lambda or a local function is provided to a DataPipe.
- The inner function of a `functools.partial` object is extracted for validation as well.
- Mimic the behavior of the `pickle` module for a lambda defined inside a local scope: `pickle` reports the error in terms of the enclosing local object rather than the lambda itself, so we raise a warning about the local function, not the lambda. For example:
```py

>>> import pickle
>>> def fn():
...     lf = lambda x: x
...     pickle.dumps(lf)
>>> fn()
AttributeError: Can't pickle local object 'fn.<locals>.<lambda>'
```
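
For reference, below is a minimal sketch of what the renamed validation helper might look like. The real implementation lives in `torch/utils/data/datapipes/utils/common.py`, whose diff is not rendered on this page, so everything beyond the helper's name, the warning prefixes asserted in the tests, and the `functools.partial` unwrapping is an assumption here.

```py
# Hypothetical sketch only: the name `_check_unpickable_fn`, the warning
# prefixes, and the functools.partial unwrapping come from this Diff; the rest
# of the body is assumed (e.g. the real helper may special-case environments
# where dill is installed).
import functools
import warnings


def _check_unpickable_fn(fn):
    # Unwrap functools.partial so the inner callable is the one validated.
    if isinstance(fn, functools.partial):
        fn = fn.func

    qualname = getattr(fn, "__qualname__", "")

    # Mirror pickle: a lambda defined inside another function is reported as a
    # local object, so the local-function check takes priority over the lambda check.
    if "<locals>" in qualname:
        warnings.warn("Local function is not supported by pickle, please use "
                      "a regular global function or functools.partial instead.")
    elif getattr(fn, "__name__", "") == "<lambda>":
        warnings.warn("Lambda function is not supported by pickle, please use "
                      "a regular global function or functools.partial instead.")
```

With this change, constructing e.g. `dp.iter.Filter(input_dp, lambda x: x >= 5)` emits the lambda warning, while passing a function defined inside another function emits the local-function warning, which is what the updated `assertWarnsRegex` checks in `test_datapipe.py` assert.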

This Diff also fixes the error introduced by pytorch#79344.

Test Plan:
```
buck test //caffe2/test:datapipe
buck test //pytorch/data/test:tests
```
Tested in OSS
```
# PT
pytest test/test_datapipe.py -v
# TD
pytest test/test_iterdatapipe.py -v
pytest test/test_mapdatapipe.py -v
pytest test/test_serialization.py -v
# TV
pytest test/test_prototype_builtin_datasets.py -v
```

Reviewed By: NivekT

Differential Revision: D37417556

fbshipit-source-id: 6fae4059285b8c742feda739cc5fe590b2e20c5e
ejguan authored and facebook-github-bot committed Jun 27, 2022
1 parent 590d3e5 commit 373109c
Showing 7 changed files with 163 additions and 71 deletions.
164 changes: 108 additions & 56 deletions test/test_datapipe.py
@@ -603,6 +603,11 @@ def _worker_init_fn(worker_id):
torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)


lambda_fn1 = lambda x: x # noqa: E731
lambda_fn2 = lambda x: x % 2 # noqa: E731
lambda_fn3 = lambda x: x >= 5 # noqa: E731


class TestFunctionalIterDataPipe(TestCase):

def _serialization_test_helper(self, datapipe, use_dill):
@@ -702,30 +707,58 @@ def test_serializable(self):
def test_serializable_with_dill(self):
"""Only for DataPipes that take in a function as argument"""
input_dp = dp.iter.IterableWrapper(range(10))
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Collator, (lambda x: x,), {}),
(dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}),
(dp.iter.Filter, (lambda x: x >= 5,), {}),
(dp.iter.Grouper, (lambda x: x >= 5,), {}),
(dp.iter.Mapper, (lambda x: x,), {}),

datapipes_with_lambda_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Collator, (lambda_fn1,), {}),
(dp.iter.Demultiplexer, (2, lambda_fn2,), {}),
(dp.iter.Filter, (lambda_fn3,), {}),
(dp.iter.Grouper, (lambda_fn3,), {}),
(dp.iter.Mapper, (lambda_fn1,), {}),
]

def _local_fns():
def _fn1(x):
return x

def _fn2(x):
return x % 2

def _fn3(x):
return x >= 5

return _fn1, _fn2, _fn3

fn1, fn2, fn3 = _local_fns()

datapipes_with_local_fn: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Collator, (fn1,), {}),
(dp.iter.Demultiplexer, (2, fn2,), {}),
(dp.iter.Filter, (fn3,), {}),
(dp.iter.Grouper, (fn3,), {}),
(dp.iter.Mapper, (fn1,), {}),
]

dp_compare_children = {dp.iter.Demultiplexer}

if HAS_DILL:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn:
if dpipe in dp_compare_children:
dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self._serialization_test_for_dp_with_children(dp1, dp2, use_dill=True)
else:
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self._serialization_test_for_single_dp(datapipe, use_dill=True)
else:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa:
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
with self.assertRaises(AttributeError):
p = pickle.dumps(datapipe)
msgs = (
r"^Lambda function is not supported by pickle",
r"^Local function is not supported by pickle"
)
for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs):
for dpipe, dp_args, dp_kwargs in dps:
with self.assertWarnsRegex(UserWarning, msg):
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
with self.assertRaises((pickle.PicklingError, AttributeError)):
pickle.dumps(datapipe)

def test_iterable_wrapper_datapipe(self):

@@ -1145,42 +1178,43 @@ def fn_n1(d0, d1):
def fn_nn(d0, d1):
return -d0, -d1, d0 + d1

def _helper(ref_fn, fn, input_col=None, output_col=None):
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
for constr in (list, tuple):
datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
if ref_fn is None:
with self.assertRaises(error):
res_dp = datapipe.map(fn, input_col, output_col)
list(res_dp)
else:
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))

# Replacing with one input column and default output column
_helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
_helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
# The index of input column is out of range
with self.assertRaises(IndexError):
_helper(None, fn_1n, 3)
_helper(None, fn_1n, 3, error=IndexError)
# Unmatched input columns with fn arguments
with self.assertRaises(TypeError):
_helper(None, fn_n1, 1)
_helper(None, fn_n1, 1, error=TypeError)

# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
_helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])

# output_col can only be specified when input_col is not None
with self.assertRaises(ValueError):
_helper(None, fn_n1, None, 1)
_helper(None, fn_n1, None, 1, error=ValueError)
# output_col can only be single-element list or tuple
with self.assertRaises(ValueError):
_helper(None, fn_n1, None, [0, 1])
_helper(None, fn_n1, None, [0, 1], error=ValueError)
# Single-element list as output_col
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
# Replacing with one input column and single specified output column
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
_helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
# The index of output column is out of range
with self.assertRaises(IndexError):
_helper(None, fn_1n, 1, 3)
_helper(None, fn_1n, 1, 3, error=IndexError)
_helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
_helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)

@@ -1213,38 +1247,39 @@ def _dict_update(data, newdata, remove_idx=None):
del _data[idx]
return _data

def _helper(ref_fn, fn, input_col=None, output_col=None):
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
datapipe = dp.iter.IterableWrapper(
[{"x": 0, "y": 1, "z": 2},
{"x": 3, "y": 4, "z": 5},
{"x": 6, "y": 7, "z": 8}]
)
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
if ref_fn is None:
with self.assertRaises(error):
res_dp = datapipe.map(fn, input_col, output_col)
list(res_dp)
else:
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))

# Replacing with one input column and default output column
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
# The key of input column is not in dict
with self.assertRaises(KeyError):
_helper(None, fn_1n, "a")
_helper(None, fn_1n, "a", error=KeyError)
# Unmatched input columns with fn arguments
with self.assertRaises(TypeError):
_helper(None, fn_n1, "y")
_helper(None, fn_n1, "y", error=TypeError)
# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
_helper(lambda data: _dict_update(
data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]), fn_nn, ["z", "y"])

# output_col can only be specified when input_col is not None
with self.assertRaises(ValueError):
_helper(None, fn_n1, None, "x")
_helper(None, fn_n1, None, "x", error=ValueError)
# output_col can only be single-element list or tuple
with self.assertRaises(ValueError):
_helper(None, fn_n1, None, ["x", "y"])
_helper(None, fn_n1, None, ["x", "y"], error=ValueError)
# Single-element list as output_col
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
# Replacing with one input column and single specified output column
@@ -1617,24 +1652,41 @@ def test_serializable(self):
def test_serializable_with_dill(self):
"""Only for DataPipes that take in a function as argument"""
input_dp = dp.map.SequenceWrapper(range(10))
unpicklable_datapipes: List[

datapipes_with_lambda_fn: List[
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
] = [
(dp.map.Mapper, (lambda x: x,), {}),
(dp.map.Mapper, (lambda_fn1,), {}),
]

def _local_fns():
def _fn1(x):
return x

return _fn1

fn1 = _local_fns()

datapipes_with_local_fn: List[
Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
] = [
(dp.map.Mapper, (fn1,), {}),
]

if HAS_DILL:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
for dpipe, dp_args, dp_kwargs in datapipes_with_lambda_fn + datapipes_with_local_fn:
_ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]
else:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
with warnings.catch_warnings(record=True) as wa:
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self.assertEqual(len(wa), 1)
self.assertRegex(
str(wa[0].message), r"^Lambda function is not supported for pickle"
)
with self.assertRaises(AttributeError):
p = pickle.dumps(datapipe)
msgs = (
r"^Lambda function is not supported by pickle",
r"^Local function is not supported by pickle"
)
for dps, msg in zip((datapipes_with_lambda_fn, datapipes_with_local_fn), msgs):
for dpipe, dp_args, dp_kwargs in dps:
with self.assertWarnsRegex(UserWarning, msg):
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
with self.assertRaises((pickle.PicklingError, AttributeError)):
pickle.dumps(datapipe)

def test_sequence_wrapper_datapipe(self):
seq = list(range(10))
4 changes: 2 additions & 2 deletions torch/utils/data/datapipes/iter/callable.py
@@ -3,7 +3,7 @@
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

__all__ = [
"CollatorIterDataPipe",
@@ -64,7 +64,7 @@ def __init__(
super().__init__()
self.datapipe = datapipe

_check_lambda_fn(fn)
_check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment]

self.input_col = input_col
4 changes: 2 additions & 2 deletions torch/utils/data/datapipes/iter/combining.py
@@ -5,7 +5,7 @@

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

__all__ = [
"ConcaterIterDataPipe",
@@ -300,7 +300,7 @@ def __new__(cls, datapipe: IterDataPipe, num_instances: int,
if num_instances < 1:
raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found")

_check_lambda_fn(classifier_fn)
_check_unpickable_fn(classifier_fn)

# When num_instances == 1, demux can be replaced by filter,
# but keep it as Demultiplexer for the sake of consistency
4 changes: 2 additions & 2 deletions torch/utils/data/datapipes/iter/grouping.py
@@ -2,7 +2,7 @@

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe, DataChunk
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar

__all__ = [
@@ -215,7 +215,7 @@ def __init__(self,
group_size: Optional[int] = None,
guaranteed_group_size: Optional[int] = None,
drop_remaining: bool = False):
_check_lambda_fn(group_key_fn)
_check_unpickable_fn(group_key_fn)
self.datapipe = datapipe
self.group_key_fn = group_key_fn

7 changes: 5 additions & 2 deletions torch/utils/data/datapipes/iter/selecting.py
@@ -3,7 +3,10 @@
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.utils.common import _check_lambda_fn, _deprecation_warning
from torch.utils.data.datapipes.utils.common import (
_check_unpickable_fn,
_deprecation_warning,
)

__all__ = ["FilterIterDataPipe", ]

@@ -48,7 +51,7 @@ def __init__(
super().__init__()
self.datapipe = datapipe

_check_lambda_fn(filter_fn)
_check_unpickable_fn(filter_fn)
self.filter_fn = filter_fn # type: ignore[assignment]

if drop_empty_batches is None:
4 changes: 2 additions & 2 deletions torch/utils/data/datapipes/map/callable.py
@@ -1,4 +1,4 @@
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
from typing import Callable, TypeVar
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe
@@ -48,7 +48,7 @@ def __init__(
) -> None:
super().__init__()
self.datapipe = datapipe
_check_lambda_fn(fn)
_check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment]

def __len__(self) -> int:
(Diff for the remaining changed file was not loaded on this page.)
