Skip to content

Commit

Permalink
pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
arie-matsliah committed Feb 2, 2021
1 parent acaa740 commit 9408f9d
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion sleap/nn/data/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
RandomCropper,
)
from sleap.nn.data.normalization import Normalizer
from sleap.nn.data.resizing import Resizer, PointsRescaler
from sleap.nn.data.resizing import Resizer, PointsRescaler, SizeMatcher
from sleap.nn.data.instance_centroids import InstanceCentroidFinder
from sleap.nn.data.instance_cropping import InstanceCropper, PredictedInstanceCropper
from sleap.nn.data.confidence_maps import (
Expand Down Expand Up @@ -68,6 +68,7 @@
RandomCropper,
Normalizer,
Resizer,
SizeMatcher,
InstanceCentroidFinder,
InstanceCropper,
MultiConfidenceMapGenerator,
Expand Down Expand Up @@ -353,6 +354,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
pipeline = Pipeline(providers=data_provider)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
if self.optimization_config.augmentation_config.random_crop:
pipeline += RandomCropper(
crop_height=self.optimization_config.augmentation_config.random_crop_height,
Expand Down Expand Up @@ -393,6 +399,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)

pipeline += SingleInstanceConfidenceMapGenerator(
sigma=self.single_instance_confmap_head.sigma,
Expand Down Expand Up @@ -483,6 +494,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
pipeline = Pipeline(providers=data_provider)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
if self.optimization_config.augmentation_config.random_crop:
pipeline += RandomCropper(
crop_height=self.optimization_config.augmentation_config.random_crop_height,
Expand Down Expand Up @@ -529,6 +545,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)

pipeline += InstanceCentroidFinder.from_config(
self.data_config.instance_cropping,
Expand Down Expand Up @@ -635,6 +656,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
pipeline = Pipeline(providers=data_provider)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
pipeline += InstanceCentroidFinder.from_config(
self.data_config.instance_cropping,
skeletons=self.data_config.labels.skeletons,
Expand Down Expand Up @@ -672,6 +698,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)

pipeline += InstanceCentroidFinder.from_config(
self.data_config.instance_cropping,
Expand Down Expand Up @@ -764,6 +795,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline:
pipeline = Pipeline(providers=data_provider)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)
if self.optimization_config.augmentation_config.random_crop:
pipeline += RandomCropper(
crop_height=self.optimization_config.augmentation_config.random_crop_height,
Expand Down Expand Up @@ -805,6 +841,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
pipeline += Resizer.from_config(self.data_config.preprocessing)
if self.data_config.preprocessing.resize_and_pad_to_target:
pipeline += SizeMatcher.from_config(
config=self.data_config.preprocessing,
provider=data_provider,
)

pipeline += MultiConfidenceMapGenerator(
sigma=self.confmaps_head.sigma,
Expand Down

0 comments on commit 9408f9d

Please sign in to comment.