postprocessing_pipeline.py
import os
from functools import wraps
from typing import Any, Callable, Mapping, Optional, Sequence

import torch

from flash.core.model import Task


class PostProcessingPipeline:

    def __init__(self, save_path: Optional[str] = None):
        self._saved_samples = 0
        self._save_path = save_path

    def pre_uncollate(self, batch: Any) -> Any:
        """Transforms to apply to a whole batch before uncollating it into single samples.

        Can involve both CPU and device transforms, as this is not applied in separate workers.
        """
        return batch

    def post_uncollate(self, sample: Any) -> Any:
        """Transforms to apply to a single sample after splitting up the batch.

        Can involve both CPU and device transforms, as this is not applied in separate workers.
        """
        return sample

    def uncollate(self, batch: Any) -> Any:
        """Uncollates a batch into single samples.

        Tries to preserve the type wherever possible.
        """
        return default_uncollate(batch)

    def save_data(self, data: Any, path: str) -> None:
        """Saves all data together to a single path."""
        torch.save(data, path)

    def save_sample(self, sample: Any, path: str) -> None:
        """Saves each sample individually to a given path."""
        torch.save(sample, path)

    def format_sample_save_path(self, path: str) -> str:
        path = os.path.join(path, f'sample_{self._saved_samples}.ptl')
        self._saved_samples += 1
        return path

    def _save_data(self, data: Any) -> None:
        self.save_data(data, self._save_path)

    def _save_sample(self, sample: Any) -> None:
        self.save_sample(sample, self.format_sample_save_path(self._save_path))

    def _is_overriden(self, method_name: str) -> bool:
        """Cropped version of https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py"""
        super_obj = PostProcessingPipeline

        if not hasattr(self, method_name) or not hasattr(super_obj, method_name):
            return False

        # Compare the code objects: a subclass override has a different
        # ``__code__`` than the base-class implementation.
        return getattr(self, method_name).__code__ is not getattr(super_obj, method_name).__code__

    @staticmethod
    def _model_predict_wrapper(func: Callable, uncollater: 'UnCollater') -> Callable:
        @wraps(func)
        def new_func(*args, **kwargs):
            predicted = func(*args, **kwargs)
            return uncollater(predicted)

        return new_func

    def _attach_to_model(self, model: Task) -> Task:
        if self._save_path is None:
            save_per_sample = None
            save_fn = None
        else:
            save_per_sample = self._is_overriden('save_sample')
            if save_per_sample:
                save_fn = self._save_sample
            else:
                save_fn = self._save_data

        # TODO: move this to on_predict_end?
        model.predict_step = self._model_predict_wrapper(
            model.predict_step,
            UnCollater(
                self.uncollate,
                self.pre_uncollate,
                self.post_uncollate,
                save_fn=save_fn,
                save_per_sample=save_per_sample
            )
        )
        return model
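
# --- Usage sketch (illustrative only, not part of the original API surface).
# Overriding ``save_sample`` flips ``_is_overriden('save_sample')`` to True, so
# ``_attach_to_model`` writes every prediction to its own file instead of one
# blob. ``MyTask`` is a hypothetical ``Task`` subclass assumed for the example.
#
#     class PerSamplePipeline(PostProcessingPipeline):
#         def save_sample(self, sample: Any, path: str) -> None:
#             torch.save(sample, path)
#
#     model = PerSamplePipeline(save_path='./predictions')._attach_to_model(MyTask())
#     # model.predict_step now uncollates its outputs and writes
#     # ./predictions/sample_0.ptl, ./predictions/sample_1.ptl, ...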
class UnCollater:

    def __init__(
        self,
        uncollate_fn: Callable,
        pre_uncollate: Callable,
        post_uncollate: Callable,
        save_fn: Optional[Callable] = None,
        save_per_sample: bool = False
    ):
        self.uncollate_fn = uncollate_fn
        self.pre_uncollate = pre_uncollate
        self.post_uncollate = post_uncollate

        self.save_fn = save_fn
        self.save_per_sample = save_per_sample

    def __call__(self, batch: Sequence[Any]):
        uncollated = self.uncollate_fn(self.pre_uncollate(batch))

        final_preds = type(uncollated)([self.post_uncollate(sample) for sample in uncollated])

        if self.save_fn is not None:
            if self.save_per_sample:
                for pred in final_preds:
                    self.save_fn(pred)
            else:
                self.save_fn(final_preds)

        # Return the predictions so that the wrapped ``predict_step``
        # still yields its outputs to the caller.
        return final_preds

    def __repr__(self) -> str:
        repr_str = (
            f'UnCollater:\n\t(pre_uncollate): {repr(self.pre_uncollate)}'
            f'\n\t(uncollate_fn): {repr(self.uncollate_fn)}'
            f'\n\t(post_uncollate): {repr(self.post_uncollate)}'
        )
        return repr_str
def default_uncollate(batch: Any):
    batch_type = type(batch)

    if isinstance(batch, torch.Tensor):
        return list(torch.unbind(batch, 0))

    elif isinstance(batch, Mapping):
        return [batch_type(dict(zip(batch, default_uncollate(t)))) for t in zip(*batch.values())]

    elif isinstance(batch, tuple) and hasattr(batch, '_fields'):  # namedtuple
        return [batch_type(*default_uncollate(sample)) for sample in zip(*batch)]

    elif isinstance(batch, Sequence) and not isinstance(batch, str):
        return [default_uncollate(sample) for sample in batch]

    return batch
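
# --- Minimal smoke test (an illustrative sketch, not shipped with the module;
# assumes ``torch`` and ``flash`` are importable). It exercises the uncollation
# behavior defined above.
if __name__ == '__main__':
    # A (4, 3) tensor batch unbinds into four (3,) sample tensors.
    samples = default_uncollate(torch.zeros(4, 3))
    assert len(samples) == 4 and samples[0].shape == (3,)

    # A mapping of batched tensors becomes one dict per zipped sample.
    dict_samples = default_uncollate({'preds': torch.zeros(4, 10)})
    assert len(dict_samples) == 4 and list(dict_samples[0]) == ['preds']

    # Driving an UnCollater by hand with the default pipeline hooks,
    # mirroring what _attach_to_model wires up (no save_fn here).
    pipeline = PostProcessingPipeline()
    uncollater = UnCollater(pipeline.uncollate, pipeline.pre_uncollate, pipeline.post_uncollate)
    assert len(uncollater(torch.zeros(2, 5))) == 2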