From 6c624ad847dc07c73eedd2143522eec263634a6f Mon Sep 17 00:00:00 2001
From: Ananya Harsh Jha
Date: Fri, 24 Sep 2021 07:34:34 -0400
Subject: [PATCH] working

---
 flash/image/__init__.py                |   2 +-
 flash/image/face_detection/__init__.py |   3 +-
 flash/image/face_detection/data.py     | 118 ++++++++++++++++++++----
 flash/image/face_detection/model.py    | 119 +++++++++----------------
 flash_examples/face_detection.py       |  20 +++--
 5 files changed, 156 insertions(+), 106 deletions(-)

diff --git a/flash/image/__init__.py b/flash/image/__init__.py
index d881de0e70..36257c77ce 100644
--- a/flash/image/__init__.py
+++ b/flash/image/__init__.py
@@ -6,7 +6,7 @@
 from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES  # noqa: F401
 from flash.image.detection import ObjectDetectionData, ObjectDetector  # noqa: F401
 from flash.image.embedding import ImageEmbedder  # noqa: F401
-from flash.image.face_detection import FaceDetector  # noqa: F401
+from flash.image.face_detection import FaceDetectionData, FaceDetector  # noqa: F401
 from flash.image.instance_segmentation import InstanceSegmentation, InstanceSegmentationData  # noqa: F401
 from flash.image.keypoint_detection import KeypointDetectionData, KeypointDetector  # noqa: F401
 from flash.image.segmentation import (  # noqa: F401
diff --git a/flash/image/face_detection/__init__.py b/flash/image/face_detection/__init__.py
index c642f1c2ba..7d14bba121 100644
--- a/flash/image/face_detection/__init__.py
+++ b/flash/image/face_detection/__init__.py
@@ -1 +1,2 @@
-from flash.image.face_detection.model import FaceDetector  # noqa: F401
+from flash.image.face_detection.data import FaceDetectionData  # noqa: F401
+from flash.image.face_detection.model import FaceDetector  # noqa: F401
diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py
index be8e8e9cd8..8d196819e8 100644
--- a/flash/image/face_detection/data.py
+++ b/flash/image/face_detection/data.py
@@ -11,45 +11,90 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+import torchvision
 
 from torch.utils.data import Dataset
 
-from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
-from flash.core.data.process import Preprocess
-from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
+from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys, DefaultDataSources
+from flash.core.data.process import Postprocess, Preprocess
+from flash.core.data.transforms import ApplyToKeys
+from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE
 from flash.image.data import ImagePathsDataSource
-from flash.image.detection.transforms import default_transforms
+from flash.image.detection import ObjectDetectionData
 
 if _TORCHVISION_AVAILABLE:
     from torchvision.datasets.folder import default_loader
 
+if _FASTFACE_AVAILABLE:
+    import fastface as ff
+
+
+def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence[Any]]:
+    """Collate samples into a batch: stacks the images with fastface's ``prepare_batch`` and maps the target
+    boxes into the resized, padded batch coordinate space."""
+    samples = {key: [sample[key] for sample in samples] for key in samples[0]}
+
+    images, scales, paddings = ff.utils.preprocess.prepare_batch(
+        samples[DefaultDataKeys.INPUT], None, adaptive_batch=True
+    )
+
+    samples["scales"] = scales
+    samples["paddings"] = paddings
+
+    if DefaultDataKeys.TARGET in samples:
+        targets = samples[DefaultDataKeys.TARGET]
+        targets = [{"target_boxes": target["boxes"]} for target in targets]
+
+        for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)):
+            target["target_boxes"] *= scale
+            target["target_boxes"][:, [0, 2]] += padding[0]
+            target["target_boxes"][:, [1, 3]] += padding[1]
+            targets[i]["target_boxes"] = target["target_boxes"]
+
+        samples[DefaultDataKeys.TARGET] = targets
+
+    samples[DefaultDataKeys.INPUT] = images
+
+    return samples
+
 
-class FastFaceDataSource(DataSource[Tuple[str, str]]):
-
-    def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
+class FastFaceDataSource(DatasetDataSource):
+    def load_data(self, data: Dataset, dataset: Any = None) -> Sequence[Mapping[str, Any]]:
         new_data = []
         for img_file_path, targets in zip(data.ids, data.targets):
             new_data.append(
-                dict(
-                    input=img_file_path,
-                    target=dict(
-                        boxes=targets["target_boxes"],
-                        labels=[1 for _ in range(targets["target_boxes"].shape[0])],
+                super().load_sample(
+                    (
+                        img_file_path,
+                        dict(
+                            boxes=targets["target_boxes"],
+                            labels=[1 for _ in range(targets["target_boxes"].shape[0])],
+                        )
                     )
                 )
             )
+
         return new_data
 
-    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+    def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]:
         filepath = sample[DefaultDataKeys.INPUT]
         img = default_loader(filepath)
         sample[DefaultDataKeys.INPUT] = img
+
         w, h = img.size  # WxH
         sample[DefaultDataKeys.METADATA] = {
             "filepath": filepath,
             "size": (h, w),
         }
+
         return sample

@@ -60,8 +105,11 @@ def __init__(
         train_transform: Optional[Dict[str, Callable]] = None,
         val_transform: Optional[Dict[str, Callable]] = None,
         test_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        image_size: Tuple[int, int] = (128, 128),
     ):
+        self.image_size = image_size
+
         super().__init__(
             train_transform=train_transform,
             val_transform=val_transform,
@@ -70,7 +118,7 @@ def __init__(
             data_sources={
                 DefaultDataSources.FILES: ImagePathsDataSource(),
                 DefaultDataSources.FOLDERS: ImagePathsDataSource(),
-                "fastface": FastFaceDataSource()
+                DefaultDataSources.DATASETS: FastFaceDataSource(),
             },
             default_data_source=DefaultDataSources.FILES,
         )
@@ -82,5 +130,41 @@ def get_state_dict(self) -> Dict[str, Any]:
     def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
         return cls(**state_dict)
 
-    def default_transforms(self) -> Optional[Dict[str, Callable]]:
-        return default_transforms()
+    def default_transforms(self) -> Dict[str, Callable]:
+        return {
+            "to_tensor_transform": nn.Sequential(
+                ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
+                ApplyToKeys(
+                    DefaultDataKeys.TARGET,
+                    nn.Sequential(
+                        ApplyToKeys("boxes", torch.as_tensor),
+                        ApplyToKeys("labels", torch.as_tensor),
+                    ),
+                ),
+            ),
+            "collate": fastface_collate_fn,
+        }
+
+
+class FaceDetectionPostProcess(Postprocess):
+    """Undoes the resizing and padding applied by ``fastface_collate_fn`` so predictions are reported in
+    original image coordinates."""
+
+    @staticmethod
+    def per_batch_transform(batch: Any) -> Any:
+        scales = batch["scales"]
+        paddings = batch["paddings"]
+
+        batch.pop("scales", None)
+        batch.pop("paddings", None)
+
+        preds = batch[DefaultDataKeys.PREDS]
+
+        # preds: torch.Tensor(M, 6) as x1, y1, x2, y2, score, batch_idx.
+        # Split into one (N, 5) tensor per image, then map back to original image coordinates.
+        preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(scales))]
+        preds = ff.utils.preprocess.adjust_results(preds, scales, paddings)
+        batch[DefaultDataKeys.PREDS] = preds
+
+        return batch
+
+
+class FaceDetectionData(ObjectDetectionData):
+    preprocess_cls = FaceDetectionPreprocess
+    postprocess_cls = FaceDetectionPostProcess
diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py
index a36d52de3c..0679247935 100644
--- a/flash/image/face_detection/model.py
+++ b/flash/image/face_detection/model.py
@@ -11,24 +11,42 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
 
 import torch
+import pytorch_lightning as pl
 from torch import nn
 from torch.optim import Optimizer
 
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.data.process import Preprocess, Serializer
+from flash.core.finetuning import FlashBaseFinetuning
 from flash.core.model import Task
 from flash.core.utilities.imports import _FASTFACE_AVAILABLE
-from flash.image.detection.finetuning import ObjectDetectionFineTuning
-from flash.image.detection.serialization import DetectionLabels
 from flash.image.face_detection.data import FaceDetectionPreprocess
 
 if _FASTFACE_AVAILABLE:
     import fastface as ff
 
 
+class FaceDetectionFineTuning(FlashBaseFinetuning):
+    """Finetuning strategy that freezes the backbone of the fastface model and trains the remaining layers."""
+
+    def __init__(self, train_bn: bool = True) -> None:
+        super().__init__(train_bn=train_bn)
+
+    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
+        self.freeze(modules=pl_module.model.backbone, train_bn=self.train_bn)
+
+
+class DetectionLabels(Serializer):
+    """A :class:`.Serializer` which extracts predictions from a sample dict."""
+
+    def serialize(self, sample: Any) -> Dict[str, Any]:
+        sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
+        return sample
+
+
 class FaceDetector(Task):
     """The ``FaceDetector`` is a :class:`~flash.Task` for detecting faces in images.
 
     For more details, see :ref:`face_detection`.
@@ -100,54 +118,19 @@ def get_model(
         return model
 
     def forward(self, x: List[torch.Tensor]) -> Any:
+        images = self._prepare_batch(x)
+        logits = self.model(images)
 
-        batch, scales, paddings = ff.utils.preprocess.prepare_batch(x, None, adaptive_batch=True)
-        # batch: torch.Tensor(B,C,T,T)
-        # scales: torch.Tensor(B,)
-        # paddings: torch.Tensor(B,4) as pad (left, top, right, bottom)
-
-        # apply preprocess
-        batch = (((batch * 255) / self.model.normalizer) - self.model.mean) / self.model.std
-
-        # get logits
-        logits = self.model(batch)
-        # logits, any
-
-        preds = self.model.logits_to_preds(logits)
-        # preds: torch.Tensor(B, N, 5)
-
-        preds = self.model._postprocess(preds)
-        # preds: torch.Tensor(N, 6) as x1,y1,x2,y2,score,batch_idx
-
-        preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(batch.size(0))]
-        # preds: list of torch.Tensor(N, 5) as x1,y1,x2,y2,score
-
-        preds = ff.utils.preprocess.adjust_results(preds, scales, paddings)
-        # preds: list of torch.Tensor(N, 5) as x1,y1,x2,y2,score
+        preds = self.model.logits_to_preds(logits)
+        preds = self.model._postprocess(preds)
+        # preds: torch.Tensor(N, 6) as x1, y1, x2, y2, score, batch_idx
 
         return preds
 
     def _prepare_batch(self, batch):
-        images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
-
-        targets = [{"target_boxes": target["boxes"]} for target in targets]
-
-        batch, scales, paddings = ff.utils.preprocess.prepare_batch(images, None, adaptive_batch=True)
-        # batch: torch.Tensor(B,C,T,T)
-        # scales: torch.Tensor(B,)
-        # paddings: torch.Tensor(B,4) as pad (left, top, right, bottom)
-
-        # apply preprocess
+        # Normalize the already-collated batch the way the pretrained fastface model expects.
         batch = (((batch * 255) / self.model.normalizer) - self.model.mean) / self.model.std
-
-        # adjust targets
-        for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)):
-            target["target_boxes"] *= scale
-            target["target_boxes"][:, [0, 2]] += padding[0]
-            target["target_boxes"][:, [1, 3]] += padding[1]
- targets[i]["target_boxes"] = target["target_boxes"] - - return batch, targets + return batch def _compute_metrics(self, logits, targets): preds = self.model.logits_to_preds(logits) @@ -162,19 +145,19 @@ def _compute_metrics(self, logits, targets): for metric in self.val_metrics.values(): metric.update(pred_boxes, target_boxes) - def training_step(self, batch, batch_idx) -> Any: - """The training step. Overrides ``Task.training_step`` - """ + def shared_step(self, batch, train=False) -> Any: + images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] + images = self._prepare_batch(images) + logits = self.model(images) + loss = self.model.compute_loss(logits, targets) - batch, targets = self._prepare_batch(batch) + if not train: + self._compute_metrics(logits, targets) - # get logits - logits = self.model(batch) - # logits, any + return loss, logits - # compute loss - loss = self.model.compute_loss(logits, targets) - # loss: dict of losses or loss + def training_step(self, batch, batch_idx) -> Any: + loss, _ = self.shared_step(batch) self.log_dict({f"train_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True) return loss @@ -184,17 +167,7 @@ def on_validation_epoch_start(self) -> None: metric.reset() def validation_step(self, batch, batch_idx): - batch, targets = self._prepare_batch(batch) - - # get logits - logits = self.model(batch) - # logits, any - - # compute loss - loss = self.model.compute_loss(logits, targets) - # loss: dict of losses or loss - - self._compute_metrics(logits, targets) + loss, logits = self.shared_step(batch) self.log_dict({f"val_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True) return loss @@ -208,17 +181,7 @@ def on_test_epoch_start(self) -> None: metric.reset() def test_step(self, batch, batch_idx): - batch, targets = self._prepare_batch(batch) - - # get logits - logits = self.model(batch) - # logits, any - - # compute loss - loss = self.model.compute_loss(logits, targets) - # loss: dict of losses or loss - - self._compute_metrics(logits, targets) + loss, logits = self.shared_step(batch) self.log_dict({f"test_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True) return loss @@ -233,4 +196,4 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A return batch def configure_finetune_callback(self): - return [ObjectDetectionFineTuning(train_bn=True)] + return [FaceDetectionFineTuning()] diff --git a/flash_examples/face_detection.py b/flash_examples/face_detection.py index e8a14f7096..efe00aa4c3 100644 --- a/flash_examples/face_detection.py +++ b/flash_examples/face_detection.py @@ -12,30 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. import flash + +from flash.core.data.utils import download_data from flash.core.data.data_module import DataModule from flash.core.utilities.imports import _FASTFACE_AVAILABLE from flash.image import FaceDetector -from flash.image.face_detection.data import FaceDetectionPreprocess +from flash.image.face_detection.data import FaceDetectionPreprocess, FaceDetectionPostProcess +from flash.image import FaceDetectionData if _FASTFACE_AVAILABLE: import fastface as ff else: - raise ModuleNotFoundError("Please, pip install -e '.[image]'") + raise ModuleNotFoundError("Please, pip install --upgrade 'lightning-flash[image_extras]'") -# 1. Create the DataModule +# # 1. 
 train_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="train")
 val_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="val")
 
-datamodule = DataModule.from_data_source(
-    "fastface", train_data=train_dataset, val_data=val_dataset, preprocess=FaceDetectionPreprocess()
+datamodule = FaceDetectionData.from_datasets(
+    train_dataset=train_dataset, val_dataset=val_dataset, batch_size=2
 )
 
 # 2. Build the task
 model = FaceDetector(model="lffd_slim")
 
 # 3. Create the trainer and finetune the model
 trainer = flash.Trainer(max_epochs=3, limit_train_batches=0.1, limit_val_batches=0.1)
-
 trainer.finetune(model, datamodule=datamodule, strategy="freeze")
 
 # 4. Detect faces in a few images!
@@ -46,5 +48,5 @@
 ])
 print(predictions)
 
 # 5. Save the model!
 trainer.save_checkpoint("face_detection_model.pt")
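
Reviewer note: the subtle part of this patch is the coordinate bookkeeping. fastface_collate_fn maps ground-truth boxes into the resized, padded batch space, and FaceDetectionPostProcess.per_batch_transform relies on ff.utils.preprocess.adjust_results to undo that mapping on predictions. Below is a minimal, self-contained sketch of that round trip in plain PyTorch; the concrete sizes and the padding placement are illustrative assumptions, not fastface's exact behavior:

    import torch

    # Hypothetical numbers: a 100x80 image resized to fit a 128x128 batch slot.
    orig_w, orig_h, target = 100, 80, 128
    scale = target / max(orig_w, orig_h)  # 1.28
    pad_left, pad_top = 0, int((target - orig_h * scale) / 2)

    # One ground-truth box in original image coordinates (x1, y1, x2, y2).
    boxes = torch.tensor([[10.0, 20.0, 50.0, 60.0]])

    # Forward mapping, as applied per sample in fastface_collate_fn.
    batch_boxes = boxes * scale
    batch_boxes[:, [0, 2]] += pad_left
    batch_boxes[:, [1, 3]] += pad_top

    # Inverse mapping, which adjust_results is assumed to apply to predictions.
    recovered = batch_boxes.clone()
    recovered[:, [0, 2]] -= pad_left
    recovered[:, [1, 3]] -= pad_top
    recovered /= scale

    assert torch.allclose(recovered, boxes)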
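
Usage note: once the example has saved face_detection_model.pt, inference follows the standard Flash pattern. A sketch (the image paths are hypothetical; load_from_checkpoint comes from PyTorch Lightning):

    from flash.image import FaceDetector

    # Restore the finetuned task and detect faces in new images.
    model = FaceDetector.load_from_checkpoint("face_detection_model.pt")
    model.eval()
    predictions = model.predict(["path/to/face_1.jpg", "path/to/face_2.jpg"])
    print(predictions)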