Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

DataPipeline PoC #141

Merged
merged 191 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
191 commits
Select commit Hold shift + click to select a range
45691cd
add prototype of DataPipeline
justusschock Feb 18, 2021
135eb17
Add Prototype of PostProcessingPipeline
justusschock Feb 18, 2021
535353c
isort + pep8
justusschock Feb 18, 2021
f66f223
update post_processing_pipeline
justusschock Feb 20, 2021
67de76f
update data pipline
justusschock Feb 20, 2021
3be12a3
add new prediction part
justusschock Feb 20, 2021
17cecb8
change loader name
justusschock Feb 22, 2021
be4f505
update
tchaton Feb 22, 2021
2e2fa54
uypdate new datapipeline
justusschock Feb 23, 2021
fc34775
update model with new pipeline
justusschock Feb 23, 2021
b417683
update
tchaton Feb 23, 2021
307b210
update gitignore
justusschock Feb 23, 2021
9dc842a
add autodataset
justusschock Feb 23, 2021
77f935c
add batch processing
justusschock Feb 23, 2021
7606f93
Merge branch 'datapipeline_poc_1' of github.com:PyTorchLightning/ligh…
justusschock Feb 23, 2021
dd68bf3
update
tchaton Feb 24, 2021
69c5bc1
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Feb 24, 2021
b5b3ad0
update
tchaton Feb 24, 2021
040c3be
update
tchaton Feb 25, 2021
4edee9c
add process file
justusschock Feb 27, 2021
327f19c
make datapipeline attaching and detaching more robust
justusschock Feb 27, 2021
95e809c
resolve flake8
tchaton Feb 28, 2021
3780522
update
tchaton Feb 28, 2021
966b1a9
push curr state
justusschock Mar 2, 2021
41a5e71
Update flash/data/batch.py
justusschock Mar 4, 2021
d6b7347
resolve some bugs
tchaton Mar 8, 2021
8e96e7e
tests
justusschock Mar 8, 2021
4067356
tests
justusschock Mar 8, 2021
d40b8c9
update
tchaton Mar 8, 2021
4cbcff8
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 8, 2021
f8a3580
make everything nn.Module and check serialization
justusschock Mar 8, 2021
e388659
resolve kornia example
tchaton Mar 10, 2021
92df45a
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 10, 2021
38d2574
add prototype of DataPipeline
justusschock Feb 18, 2021
6ff8c1c
Add Prototype of PostProcessingPipeline
justusschock Feb 18, 2021
f73f45b
isort + pep8
justusschock Feb 18, 2021
66f6562
update post_processing_pipeline
justusschock Feb 20, 2021
07ab337
update data pipline
justusschock Feb 20, 2021
f92b3cb
add new prediction part
justusschock Feb 20, 2021
0e7ee40
change loader name
justusschock Feb 22, 2021
e03b0b1
update
tchaton Feb 22, 2021
e3c1582
update
tchaton Feb 23, 2021
1c915ca
update
tchaton Feb 24, 2021
9906ad4
uypdate new datapipeline
justusschock Feb 23, 2021
fba408e
update model with new pipeline
justusschock Feb 23, 2021
99ebec8
update gitignore
justusschock Feb 23, 2021
a68419c
add autodataset
justusschock Feb 23, 2021
ac25999
add batch processing
justusschock Feb 23, 2021
70ba492
update
tchaton Feb 24, 2021
0bb5fdc
update
tchaton Feb 25, 2021
ac50910
add process file
justusschock Feb 27, 2021
97a8e4e
make datapipeline attaching and detaching more robust
justusschock Feb 27, 2021
5ef3f7d
resolve flake8
tchaton Feb 28, 2021
b7275de
update
tchaton Feb 28, 2021
f7a1966
push curr state
justusschock Mar 2, 2021
d84e53f
Update flash/data/batch.py
justusschock Mar 4, 2021
31d9b6d
resolve some bugs
tchaton Mar 8, 2021
d3a7cd7
update
tchaton Mar 8, 2021
eaee810
tests
justusschock Mar 8, 2021
f7f8642
resolve kornia example
tchaton Mar 10, 2021
e290159
make everything nn.Module and check serialization
justusschock Mar 8, 2021
6133ef8
rebase_fixes
justusschock Mar 10, 2021
a5289d4
add more tests
tchaton Mar 10, 2021
9f144f4
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 10, 2021
f962401
update tabular
tchaton Mar 11, 2021
59bbe09
add new hooks
tchaton Mar 11, 2021
5604042
update tabular
tchaton Mar 12, 2021
fba7e96
update
tchaton Mar 12, 2021
ef6a9fd
Move func to data module
justusschock Mar 13, 2021
08bab33
fix vision to current version
justusschock Mar 13, 2021
07fb5e6
transfer text classification to new API
justusschock Mar 13, 2021
b744741
add more tests
tchaton Mar 13, 2021
7b782e1
update
justusschock Mar 13, 2021
1abee8a
resolve most bugs
tchaton Mar 14, 2021
0b00b22
address most comments
tchaton Mar 14, 2021
4d15e94
remove kornia example
tchaton Mar 14, 2021
a598e99
add support for summurization example
tchaton Mar 14, 2021
e8968a7
work with ObjectDetection
tchaton Mar 14, 2021
1ea587c
Update gitignore
kaushikb11 Mar 14, 2021
0ae1729
updates
tchaton Mar 14, 2021
d1da35f
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 14, 2021
3c2c08b
resolve bug
tchaton Mar 14, 2021
ab887be
update
justusschock Mar 15, 2021
ef91f81
resolve image embedder
tchaton Mar 15, 2021
c5bfc41
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 15, 2021
b2b6b54
Update Image Classifer
kaushikb11 Mar 15, 2021
382feb5
Renaming
kaushikb11 Mar 15, 2021
59365f4
fix recursion
justusschock Mar 16, 2021
b1951c8
resolve bug
tchaton Mar 16, 2021
214dd58
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 16, 2021
e8a7842
Merge branch 'datapipeline_poc_1' of github.com:PyTorchLightning/ligh…
justusschock Mar 16, 2021
a18d745
Fix DataPipeline function resolution
justusschock Mar 16, 2021
1903fa7
put back properties instead of attributes
justusschock Mar 16, 2021
832663e
fix import
justusschock Mar 16, 2021
0187b13
fix examples
justusschock Mar 16, 2021
cc4b0d5
add checks for loading
justusschock Mar 16, 2021
e55899f
fix recursion
justusschock Mar 16, 2021
2443720
fix seq2seq dataset
justusschock Mar 16, 2021
f67b209
fix dm init in tests
justusschock Mar 16, 2021
27aa8b4
fix data parts
justusschock Mar 16, 2021
3969aa0
resolve tests and flake8
tchaton Mar 17, 2021
af7beef
Merge branch 'master' into datapipeline_poc_1
tchaton Mar 17, 2021
8b73caa
update on comments
tchaton Mar 18, 2021
23ba639
update notebooks
tchaton Mar 18, 2021
214ada8
devel
Borda Mar 18, 2021
a656fc9
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
Borda Mar 18, 2021
ca017da
update
tchaton Mar 18, 2021
8f276b2
update
tchaton Mar 18, 2021
44ffd16
update
tchaton Mar 18, 2021
3c1e433
resolve the doc
tchaton Mar 18, 2021
1471fb0
update
tchaton Mar 18, 2021
d0e599c
don't apply flake8 on notebook
tchaton Mar 18, 2021
9c24add
resolve tests
tchaton Mar 18, 2021
d16b9fd
comment a notebook
tchaton Mar 18, 2021
7baf1cb
update
tchaton Mar 18, 2021
7bd5700
update ci
tchaton Mar 18, 2021
c324c34
add fixes
tchaton Mar 18, 2021
6eabc01
updaet
tchaton Mar 18, 2021
dd35707
Merge branch 'master' into datapipeline_poc_1
tchaton Mar 18, 2021
e4917ed
update with lightning
tchaton Mar 22, 2021
0b170b3
add a test for flash_special_arguments
tchaton Mar 22, 2021
7083679
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 22, 2021
2f381ef
add data_pipeline
tchaton Mar 22, 2021
465522d
update ci
tchaton Mar 22, 2021
819c018
delete generate .py file
tchaton Mar 22, 2021
2b4756d
update bolts
tchaton Mar 22, 2021
d291f12
udpate ci
tchaton Mar 22, 2021
ffdd258
update
tchaton Mar 22, 2021
f183f19
Merge branch 'master' into data_pipeline_1_n
tchaton Mar 22, 2021
2e7bc4b
Update flash/data/auto_dataset.py
tchaton Mar 22, 2021
2c1e412
update
tchaton Mar 22, 2021
b8d2abc
Merge branch 'data_pipeline_1_n' of https://github.com/PyTorchLightni…
tchaton Mar 22, 2021
d278382
Update tests/data/test_data_pipeline.py
tchaton Mar 22, 2021
0e32fa1
update
tchaton Mar 22, 2021
eba35f6
Merge branch 'data_pipeline_1_n' of https://github.com/PyTorchLightni…
tchaton Mar 22, 2021
8bea3dd
update
tchaton Mar 22, 2021
2990b0b
add some docstring
tchaton Mar 23, 2021
276cf40
update docstring
tchaton Mar 23, 2021
06e5a09
update on comments
tchaton Mar 23, 2021
913bb45
Fixes
carmocca Mar 24, 2021
98aa56d
Docs
carmocca Mar 24, 2021
58c147f
Docs
carmocca Mar 24, 2021
98f75c4
Merge branch 'master' into data_pipeline_1_n
carmocca Mar 24, 2021
84ce3b1
update ci
tchaton Mar 24, 2021
86669c6
update on comments
tchaton Mar 25, 2021
54d0fc3
Update flash/data/batch.py
tchaton Mar 25, 2021
637ff25
Update flash/data/data_module.py
kaushikb11 Mar 25, 2021
dd3dfdb
Update flash/data/process.py
kaushikb11 Mar 25, 2021
4c487a9
Apply suggestions from code review
Borda Mar 25, 2021
ab96ac7
cleaning
Borda Mar 25, 2021
fddd6b1
Merge branch 'data_pipeline_1_n' into datapipeline_poc_1
tchaton Mar 25, 2021
51ea5d9
add pip install
tchaton Mar 25, 2021
0f1f15f
switch back to master
tchaton Mar 25, 2021
23aaebf
update requierements
tchaton Mar 25, 2021
41dd86c
try
Borda Mar 25, 2021
7d8c955
try
Borda Mar 25, 2021
8451011
try
Borda Mar 25, 2021
40a6b33
update
tchaton Mar 25, 2021
82fcace
prune legacy
tchaton Mar 25, 2021
fc045b1
Merge branch 'data_pipeline_1_n' into datapipeline_poc_1
tchaton Mar 25, 2021
13db92e
Merge branch 'master' into datapipeline_poc_1
tchaton Mar 25, 2021
176f14b
update
tchaton Mar 25, 2021
f4410de
Merge branch 'master' into datapipeline_poc_1
tchaton Mar 26, 2021
d8367a2
update
tchaton Mar 26, 2021
e89037f
update to latest
tchaton Mar 26, 2021
095bdbe
delete extra files
tchaton Mar 26, 2021
a659c3d
updates to Task class
kaushikb11 Mar 26, 2021
b17b413
Update Datamodule
kaushikb11 Mar 26, 2021
1877733
resolve comments
tchaton Mar 26, 2021
a434a68
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 26, 2021
bda0ca3
update
tchaton Mar 26, 2021
b81a204
update
tchaton Mar 26, 2021
8043153
update
tchaton Mar 26, 2021
9297c5b
update
tchaton Mar 26, 2021
4f1e87d
try
tchaton Mar 26, 2021
d00382e
update
tchaton Mar 26, 2021
f11002c
udpate
tchaton Mar 26, 2021
6e05051
update
tchaton Mar 26, 2021
e72c1c3
update
tchaton Mar 26, 2021
9289632
update
tchaton Mar 26, 2021
77e3e0e
formatting
Borda Mar 28, 2021
d72d1b7
update on comments
tchaton Mar 29, 2021
503b62a
update on comments
tchaton Mar 29, 2021
25bd225
Merge branch 'master' into datapipeline_poc_1
tchaton Mar 29, 2021
2c58006
General changes
carmocca Mar 29, 2021
816c010
General changes
carmocca Mar 29, 2021
fbfb71f
update
tchaton Mar 29, 2021
d3d4c78
Merge branch 'datapipeline_poc_1' of https://github.com/PyTorchLightn…
tchaton Mar 29, 2021
0245d17
update
tchaton Mar 29, 2021
e2f24dc
add _data_pipeline back
tchaton Mar 29, 2021
de3327b
update
tchaton Mar 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,4 @@ titanic.csv
data_folder
*.pt
*.zip
data
18 changes: 11 additions & 7 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,27 @@

import torch

from flash.core.data import TaskDataPipeline
from flash.core.model import Task
from flash.data.data_pipeline import Postprocess


class ClassificationDataPipeline(TaskDataPipeline):
class ClassificationDataPipeline:
pass
carmocca marked this conversation as resolved.
Show resolved Hide resolved

def before_uncollate(self, batch: Union[torch.Tensor, tuple]) -> torch.Tensor:

class ClassificationPostprocess(Postprocess):

def pre_uncollate(self, batch: Union[torch.Tensor, tuple]) -> torch.Tensor:
if isinstance(batch, tuple):
batch = batch[0]
return torch.softmax(batch, -1)

def after_uncollate(self, samples: Any) -> Any:
def post_uncollate(self, samples: Any) -> Any:
return torch.argmax(samples, -1).tolist()
tchaton marked this conversation as resolved.
Show resolved Hide resolved


class ClassificationTask(Task):

@staticmethod
def default_pipeline() -> ClassificationDataPipeline:
return ClassificationDataPipeline()
@property
def postprocess(self):
return ClassificationPostprocess()
50 changes: 41 additions & 9 deletions flash/core/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

from flash.core.data.datapipeline import DataPipeline
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess


class TaskDataPipeline(DataPipeline):
Expand All @@ -44,13 +44,15 @@ def __init__(
train_ds: Optional[Dataset] = None,
valid_ds: Optional[Dataset] = None,
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
test_ds: Optional[Dataset] = None,
predict_ds: Optional[Dataset] = None,
batch_size: int = 1,
num_workers: Optional[int] = None,
):
super().__init__()
self._train_ds = train_ds
self._valid_ds = valid_ds
self._test_ds = test_ds
self._predict_ds = predict_ds

if self._train_ds is not None:
self.train_dataloader = self._train_dataloader
Expand All @@ -61,6 +63,9 @@ def __init__(
if self._test_ds is not None:
self.test_dataloader = self._test_dataloader

if self._predict_ds is not None:
self.predict_dataloader = self._predict_dataloader

self.batch_size = batch_size

# TODO: figure out best solution for setting num_workers
Expand All @@ -72,6 +77,8 @@ def __init__(
self.num_workers = num_workers

self._data_pipeline = None
self._preprocess = None
self._postprocess = None

def _train_dataloader(self) -> DataLoader:
return DataLoader(
Expand All @@ -80,7 +87,7 @@ def _train_dataloader(self) -> DataLoader:
shuffle=True,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self.data_pipeline.collate_fn,
collate_fn=self.data_pipeline.worker_collate_fn,
drop_last=True,
)

Expand All @@ -90,7 +97,7 @@ def _val_dataloader(self) -> DataLoader:
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self.data_pipeline.collate_fn,
collate_fn=self.data_pipeline.worker_collate_fn,
)

def _test_dataloader(self) -> DataLoader:
Expand All @@ -99,19 +106,44 @@ def _test_dataloader(self) -> DataLoader:
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self.data_pipeline.collate_fn,
collate_fn=self.data_pipeline.worker_collate_fn,
)

def _predict_dataloader(self) -> DataLoader:
return DataLoader(
self._predict_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self.data_pipeline.worker_collate_fn,
)

@property
def preprocess(self):
return self._preprocess

@preprocess.setter
def preprocess(self, preprocess: Preprocess) -> None:
self._preprocess = preprocess

@property
def postprocess(self):
return self._postprocess

@postprocess.setter
def postprocess(self, postprocess: Postprocess) -> None:
self._postprocess = postprocess

@property
def data_pipeline(self) -> DataPipeline:
if self._data_pipeline is None:
self._data_pipeline = self.default_pipeline()
preprocess = self._preprocess
postprocess = self._postprocess
if preprocess is None and postprocess is None:
self._data_pipeline = self.default_pipeline()
return DataPipeline(preprocess, postprocess)
return self._data_pipeline

@data_pipeline.setter
def data_pipeline(self, data_pipeline) -> None:
self._data_pipeline = data_pipeline

@staticmethod
def default_pipeline() -> DataPipeline:
return TaskDataPipeline()
14 changes: 7 additions & 7 deletions flash/core/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class NoFreeze(BaseFinetuning):
def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
pass

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand All @@ -42,7 +42,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo

FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback.

Override ``finetunning_function`` to put your unfreeze logic.
Override ``finetune_function`` to put your unfreeze logic.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

Args:
attr_names: Name(s) of the module attributes of the model to be frozen.
Expand All @@ -62,15 +62,15 @@ def freeze_using_attr_names(self, pl_module, attr_names: List[str], train_bn: bo
attr = getattr(pl_module, attr_name, None)
if attr is None or not isinstance(attr, nn.Module):
MisconfigurationException(f"Your model must have a {attr} attribute")
self.freeze(module=attr, train_bn=train_bn)
self.freeze(modules=attr, train_bn=train_bn)

def finetunning_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
pass


class Freeze(FlashBaseFinetuning):

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand All @@ -86,7 +86,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo
super().__init__(attr_names, train_bn)
self.unfreeze_epoch = unfreeze_epoch

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(

super().__init__(attr_names, train_bn)

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand Down
101 changes: 68 additions & 33 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from torch import nn

from flash.core.data import DataModule, DataPipeline
from flash.core.data import DataModule
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess


def predict_context(func: Callable) -> Callable:
Expand All @@ -31,13 +33,16 @@ def predict_context(func: Callable) -> Callable:

@functools.wraps(func)
def wrapper(self, *args, **kwargs) -> Any:
grad_enabled = torch.is_grad_enabled()
is_training = self.training
self.eval()
torch.set_grad_enabled(False)

result = func(self, *args, **kwargs)

self.train()
torch.set_grad_enabled(True)
if is_training:
self.train()
torch.set_grad_enabled(grad_enabled)
return result

return wrapper
Expand All @@ -63,6 +68,8 @@ def __init__(
learning_rate: float = 5e-5,
):
super().__init__()
self._last_trainer_kwargs = {}

tchaton marked this conversation as resolved.
Show resolved Hide resolved
if model is not None:
self.model = model
self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
Expand All @@ -71,15 +78,18 @@ def __init__(
self.learning_rate = learning_rate
# TODO: should we save more? Bug on some regarding yaml if we save metrics
self.save_hyperparameters("learning_rate", "optimizer")

self._data_pipeline = None
self._preprocess = None
self._postprocess = None

def step(self, batch: Any, batch_idx: int) -> Any:
"""
The training/validation/test step. Override for custom behavior.
"""
x, y = batch
y_hat = self.forward(x)
output = {"y_hat": self.data_pipeline.before_uncollate(y_hat)}
output = {"y_hat": self.data_pipeline.pre_uncollate(y_hat)}
losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
logs = {}
for name, metric in self.metrics.items():
Expand Down Expand Up @@ -143,48 +153,73 @@ def predict(
The post-processed model predictions

"""
# enable x to be a path to a folder
if isinstance(x, str):
files = os.listdir(x)
files = [os.path.join(x, y) for y in files]
x = files

data_pipeline = data_pipeline or self.data_pipeline
batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
batch_x, batch_y = batch if len(batch) == 2 and isinstance(batch, (list, tuple)) else (batch, None)
predictions = self.forward(batch_x)
output = data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x
return output
x = [x for x in data_pipeline._generate_auto_dataset(x)]
x = self.data_pipeline.worker_collate_fn(x)
#x = self.data_pipeline.device_collate_fn(x)
predictions = self.predict_step(x, batch_idx)
return data_pipeline.uncollate_fn(predictions)

def predict_step(self, batch, batch_idx):
return self(batch)

def configure_optimizers(self) -> torch.optim.Optimizer:
return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.data_pipeline = checkpoint["pipeline"]

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint["pipeline"] = self.data_pipeline

def configure_finetune_callback(self):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
return []

def predict_step(self, batch, batch_idx):
return self(batch)

@property
def preprocess(self):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
return self._preprocess

@preprocess.setter
def preprocess(self, preprocess: Preprocess) -> None:
data_pipeline = self.data_pipeline
self.data_pipeline = DataPipeline(preprocess, data_pipeline.postprocess)

@property
def postprocess(self):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
return self._postprocess

@postprocess.setter
def postprocess(self, postprocess: Postprocess) -> None:
data_pipeline = self.data_pipeline
self.data_pipeline = DataPipeline(data_pipeline.preprocess, postprocess)

@property
def data_pipeline(self) -> DataPipeline:
def data_pipeline(self) -> Optional[DataPipeline]:
# we need to save the pipeline in case this class
# is loaded from checkpoint and used to predict
if not self._data_pipeline:
try:
# datamodule pipeline takes priority
self._data_pipeline = self.trainer.datamodule.data_pipeline
except AttributeError:
self._data_pipeline = self.default_pipeline()
return self._data_pipeline
return self._get_pipeline("data_pipeline")

@data_pipeline.setter
def data_pipeline(self, data_pipeline: DataPipeline) -> None:
self._data_pipeline = data_pipeline
if isinstance(data_pipeline, DataPipeline):
self._data_pipeline._attach_to_model(self)

@staticmethod
def default_pipeline() -> DataPipeline:
"""Pipeline to use when there is no datamodule or it has not defined its pipeline"""
return DataModule.default_pipeline()
def _get_pipeline(self, pipeline_attr_name: str):

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.data_pipeline = checkpoint["pipeline"]
if getattr(self, '_' + pipeline_attr_name) is not None:
return getattr(self, '_' + pipeline_attr_name)

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint["pipeline"] = self.data_pipeline
if self.datamodule is not None and hasattr(self, pipeline_attr_name):
return getattr(self.datamodule, pipeline_attr_name)

def configure_finetune_callback(self):
return []
if self.trainer is not None and hasattr(self.trainer, 'datamodule') and self.trainer.datamodule is not None:
if hasattr(self.trainer.datamodule,
pipeline_attr_name) and getattr(self.trainer.datamodule, pipeline_attr_name):
data_pipeline = getattr(self.trainer.datamodule, pipeline_attr_name)
return DataPipeline(data_pipeline.preprocess, self.postprocess)

return None
Loading