Commit: working
ananyahjha93 committed Sep 24, 2021
1 parent 2b22d87 commit 6c624ad
Showing 5 changed files with 156 additions and 106 deletions.
2 changes: 1 addition & 1 deletion flash/image/__init__.py
@@ -6,7 +6,7 @@
 from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES  # noqa: F401
 from flash.image.detection import ObjectDetectionData, ObjectDetector  # noqa: F401
 from flash.image.embedding import ImageEmbedder  # noqa: F401
-from flash.image.face_detection import FaceDetector  # noqa: F401
+from flash.image.face_detection import FaceDetector, FaceDetectionData  # noqa: F401
 from flash.image.instance_segmentation import InstanceSegmentation, InstanceSegmentationData  # noqa: F401
 from flash.image.keypoint_detection import KeypointDetectionData, KeypointDetector  # noqa: F401
 from flash.image.segmentation import (  # noqa: F401
3 changes: 2 additions & 1 deletion flash/image/face_detection/__init__.py
@@ -1 +1,2 @@
-from flash.image.face_detection.model import FaceDetector  # noqa: F401
+from flash.image.face_detection.model import FaceDetector  # noqa: F401
+from flash.image.face_detection.data import FaceDetectionData  # noqa: F401
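With these two changes, FaceDetectionData is re-exported next to FaceDetector, so both symbols now resolve from the top-level image package as well as from the face_detection subpackage. A quick sanity check of the new import surface (paths taken directly from the diffs above):

    # both import paths should now work
    from flash.image import FaceDetector, FaceDetectionData
    from flash.image.face_detection import FaceDetector, FaceDetectionData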
118 changes: 101 additions & 17 deletions flash/image/face_detection/data.py
@@ -11,45 +11,90 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Mapping

 import torch
+import torchvision
+import torch.nn as nn

 from torch.utils.data import Dataset

-from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
-from flash.core.data.process import Preprocess
-from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
+from flash.core.data.transforms import ApplyToKeys
+from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys, DefaultDataSources
+from flash.core.data.process import Preprocess, Postprocess
+from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _ICEVISION_AVAILABLE, _FASTFACE_AVAILABLE
 from flash.image.data import ImagePathsDataSource
-from flash.image.detection.transforms import default_transforms
+from flash.core.integrations.icevision.data import IceVisionParserDataSource
+from flash.core.integrations.icevision.transforms import default_transforms
+from flash.image.detection import ObjectDetectionData

 if _TORCHVISION_AVAILABLE:
     from torchvision.datasets.folder import default_loader

+if _ICEVISION_AVAILABLE:
+    from icevision.parsers import COCOBBoxParser
+else:
+    COCOBBoxParser = object
+
+if _FASTFACE_AVAILABLE:
+    import fastface as ff
+
+
+def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence[Any]]:
+    samples = {key: [sample[key] for sample in samples] for key in samples[0]}
+
+    images, scales, paddings = ff.utils.preprocess.prepare_batch(
+        samples[DefaultDataKeys.INPUT], None, adaptive_batch=True
+    )
+
+    samples["scales"] = scales
+    samples["paddings"] = paddings
+
+    if DefaultDataKeys.TARGET in samples.keys():
+        targets = samples[DefaultDataKeys.TARGET]
+        targets = [{"target_boxes": target["boxes"]} for target in targets]
+
+        for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)):
+            target["target_boxes"] *= scale
+            target["target_boxes"][:, [0, 2]] += padding[0]
+            target["target_boxes"][:, [1, 3]] += padding[1]
+            targets[i]["target_boxes"] = target["target_boxes"]
+
+        samples[DefaultDataKeys.TARGET] = targets
+
+    samples[DefaultDataKeys.INPUT] = images
+
+    return samples


-class FastFaceDataSource(DataSource[Tuple[str, str]]):
-    def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
+class FastFaceDataSource(DatasetDataSource):
+    def load_data(self, data: Dataset, dataset: Any = None) -> Dataset:
         new_data = []
         for img_file_path, targets in zip(data.ids, data.targets):
             new_data.append(
-                dict(
-                    input=img_file_path,
-                    target=dict(
-                        boxes=targets["target_boxes"],
-                        labels=[1 for _ in range(targets["target_boxes"].shape[0])],
+                super().load_sample(
+                    (
+                        img_file_path,
+                        dict(
+                            boxes=targets["target_boxes"],
+                            labels=[1 for _ in range(targets["target_boxes"].shape[0])],
+                        )
                     )
                 )
             )

         return new_data

-    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+    def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]:
         filepath = sample[DefaultDataKeys.INPUT]
         img = default_loader(filepath)
         sample[DefaultDataKeys.INPUT] = img

         w, h = img.size  # WxH
         sample[DefaultDataKeys.METADATA] = {
             "filepath": filepath,
             "size": (h, w),
         }

         return sample
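The new collate function is where fastface's batching convention enters the pipeline: prepare_batch rescales and pads every image in the batch to a shared square size and returns the per-image scale factors and paddings, and the target boxes are mapped into that padded coordinate frame so they stay aligned with the pixels. A minimal sketch of the same transformation, using only the calls and tensor layouts documented in this diff (the image sizes and box values are made up for illustration):

    import torch
    import fastface as ff

    # two differently sized C x H x W images, as produced by to_tensor_transform
    images = [torch.rand(3, 100, 80), torch.rand(3, 60, 120)]

    # batch: (B, C, T, T); scales: (B,); paddings: (B, 4) as (left, top, right, bottom)
    batch, scales, paddings = ff.utils.preprocess.prepare_batch(images, None, adaptive_batch=True)

    # move a box for image 0 into the padded frame, exactly as fastface_collate_fn does
    boxes = torch.tensor([[10.0, 20.0, 50.0, 70.0]])  # x1, y1, x2, y2
    boxes = boxes * scales[0]
    boxes[:, [0, 2]] += paddings[0][0]  # shift x coordinates by the left pad
    boxes[:, [1, 3]] += paddings[0][1]  # shift y coordinates by the top pad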


@@ -60,8 +105,11 @@ def __init__(
         train_transform: Optional[Dict[str, Callable]] = None,
         val_transform: Optional[Dict[str, Callable]] = None,
         test_transform: Optional[Dict[str, Callable]] = None,
-        predict_transform: Optional[Dict[str, Callable]] = None
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        image_size: Tuple[int, int] = (128, 128),
     ):
+        self.image_size = image_size
+
         super().__init__(
             train_transform=train_transform,
             val_transform=val_transform,
@@ -70,7 +118,7 @@
             data_sources={
                 DefaultDataSources.FILES: ImagePathsDataSource(),
                 DefaultDataSources.FOLDERS: ImagePathsDataSource(),
-                "fastface": FastFaceDataSource()
+                DefaultDataSources.DATASETS: FastFaceDataSource(),
             },
             default_data_source=DefaultDataSources.FILES,
         )
@@ -82,5 +130,41 @@ def get_state_dict(self) -> Dict[str, Any]:
     def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
         return cls(**state_dict)

-    def default_transforms(self) -> Optional[Dict[str, Callable]]:
-        return default_transforms()
+    def default_transforms(self) -> Dict[str, Callable]:
+        return {
+            "to_tensor_transform": nn.Sequential(
+                ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
+                ApplyToKeys(
+                    DefaultDataKeys.TARGET,
+                    nn.Sequential(
+                        ApplyToKeys('boxes', torch.as_tensor),
+                        ApplyToKeys('labels', torch.as_tensor),
+                    )
+                ),
+            ),
+            "collate": fastface_collate_fn,
+        }
+
+
+class FaceDetectionPostProcess(Postprocess):
+    @staticmethod
+    def per_batch_transform(batch: Any) -> Any:
+        scales = batch['scales']
+        paddings = batch['paddings']
+
+        batch.pop('scales', None)
+        batch.pop('paddings', None)
+
+        preds = batch[DefaultDataKeys.PREDS]
+
+        # preds: torch.Tensor(N, 6) as x1, y1, x2, y2, score, batch_idx; split into
+        # one (M, 5) box/score tensor per image, then undo the collate-time
+        # scaling and padding (one scale/padding entry per image in the batch)
+        preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(scales))]
+        preds = ff.utils.preprocess.adjust_results(preds, scales, paddings)
+        batch[DefaultDataKeys.PREDS] = preds
+
+        return batch
+
+
+class FaceDetectionData(ObjectDetectionData):
+    preprocess_cls = FaceDetectionPreprocess
+    postprocess_cls = FaceDetectionPostProcess
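With the DATASETS source now mapped to FastFaceDataSource, the intended entry point is a fastface dataset object whose .ids and .targets attributes load_data iterates over. A hedged usage sketch; the FDDB dataset class and its constructor arguments come from the fastface library rather than this diff, so treat them as assumptions:

    import fastface as ff
    from flash.image import FaceDetectionData

    # fastface datasets expose .ids (file paths) and .targets (dicts with "target_boxes")
    train_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="train")
    val_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="val")

    datamodule = FaceDetectionData.from_datasets(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        batch_size=2,
    )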
119 changes: 41 additions & 78 deletions flash/image/face_detection/model.py
@@ -11,24 +11,42 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union, Dict

 import torch
+import pytorch_lightning as pl

 from torch import nn
 from torch.optim import Optimizer

-from flash.core.model import Task
 from flash.core.data.data_source import DefaultDataKeys
-from flash.core.data.process import Postprocess
+from flash.core.data.process import Preprocess, Serializer
+from flash.core.model import Task
 from flash.core.utilities.imports import _FASTFACE_AVAILABLE
-from flash.image.detection.finetuning import ObjectDetectionFineTuning
-from flash.image.detection.serialization import DetectionLabels
-from flash.image.face_detection.data import FaceDetectionPreprocess
+from flash.core.finetuning import FlashBaseFinetuning
+from flash.image.face_detection.data import FaceDetectionPreprocess, FaceDetectionPostProcess

 if _FASTFACE_AVAILABLE:
     import fastface as ff


+class FaceDetectionFineTuning(FlashBaseFinetuning):
+    def __init__(self, train_bn: bool = True) -> None:
+        super().__init__(train_bn=train_bn)
+
+    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
+        self.freeze(modules=pl_module.model.backbone, train_bn=self.train_bn)
+
+
+class DetectionLabels(Serializer):
+    """A :class:`.Serializer` which extracts predictions from sample dict."""
+
+    def serialize(self, sample: Any) -> Dict[str, Any]:
+        sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
+        return sample
+
+
 class FaceDetector(Task):
     """The ``FaceDetector`` is a :class:`~flash.Task` for detecting faces in images. For more details, see
     :ref:`face_detection`.
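The two new helper classes are small: FaceDetectionFineTuning freezes the fastface backbone before training (keeping batch norm layers trainable by default), and DetectionLabels unwraps the PREDS key when a prediction sample arrives as a dict, passing anything else through untouched. A short behavioral sketch of the serializer (the box values are invented):

    from flash.core.data.data_source import DefaultDataKeys

    serializer = DetectionLabels()  # as defined in model.py above

    boxes = [[0.0, 0.0, 10.0, 10.0, 0.9]]  # x1, y1, x2, y2, score
    assert serializer.serialize({DefaultDataKeys.PREDS: boxes}) == boxes  # dict: unwrapped
    assert serializer.serialize(boxes) == boxes  # non-dict: passed through unchanged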
@@ -100,54 +118,19 @@ def get_model(
         return model

     def forward(self, x: List[torch.Tensor]) -> Any:
-        batch, scales, paddings = ff.utils.preprocess.prepare_batch(x, None, adaptive_batch=True)
-        # batch: torch.Tensor(B,C,T,T)
-        # scales: torch.Tensor(B,)
-        # paddings: torch.Tensor(B,4) as pad (left, top, right, bottom)
-
-        # apply preprocess
-        batch = (((batch * 255) / self.model.normalizer) - self.model.mean) / self.model.std
-
-        # get logits
-        logits = self.model(batch)
-        # logits, any
-
-        preds = self.model.logits_to_preds(logits)
-        # preds: torch.Tensor(B, N, 5)
-
-        preds = self.model._postprocess(preds)
-        # preds: torch.Tensor(N, 6) as x1,y1,x2,y2,score,batch_idx
-
-        preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(batch.size(0))]
-        # preds: list of torch.Tensor(N, 5) as x1,y1,x2,y2,score
-
-        preds = ff.utils.preprocess.adjust_results(preds, scales, paddings)
-        # preds: list of torch.Tensor(N, 5) as x1,y1,x2,y2,score
+        images = self._prepare_batch(x)
+        logits = self.model(images)
+        preds = self.model.logits_to_preds(logits)
+        preds = self.model._postprocess(preds)

         return preds

     def _prepare_batch(self, batch):
-        images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
-
-        targets = [{"target_boxes": target["boxes"]} for target in targets]
-
-        batch, scales, paddings = ff.utils.preprocess.prepare_batch(images, None, adaptive_batch=True)
-        # batch: torch.Tensor(B,C,T,T)
-        # scales: torch.Tensor(B,)
-        # paddings: torch.Tensor(B,4) as pad (left, top, right, bottom)
-
         # apply preprocess
         batch = (((batch * 255) / self.model.normalizer) - self.model.mean) / self.model.std
-
-        # adjust targets
-        for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)):
-            target["target_boxes"] *= scale
-            target["target_boxes"][:, [0, 2]] += padding[0]
-            target["target_boxes"][:, [1, 3]] += padding[1]
-            targets[i]["target_boxes"] = target["target_boxes"]
-
-        return batch, targets
+        return batch

     def _compute_metrics(self, logits, targets):
         preds = self.model.logits_to_preds(logits)
@@ -162,19 +145,19 @@ def _compute_metrics(self, logits, targets):
         for metric in self.val_metrics.values():
             metric.update(pred_boxes, target_boxes)

-    def training_step(self, batch, batch_idx) -> Any:
-        """The training step. Overrides ``Task.training_step``
-        """
-        batch, targets = self._prepare_batch(batch)
-
-        # get logits
-        logits = self.model(batch)
-        # logits, any
-
-        # compute loss
-        loss = self.model.compute_loss(logits, targets)
-        # loss: dict of losses or loss
+    def shared_step(self, batch, train=False) -> Any:
+        images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
+        images = self._prepare_batch(images)
+        logits = self.model(images)
+        loss = self.model.compute_loss(logits, targets)
+
+        if not train:
+            self._compute_metrics(logits, targets)
+
+        return loss, logits
+
+    def training_step(self, batch, batch_idx) -> Any:
+        loss, _ = self.shared_step(batch, train=True)

         self.log_dict({f"train_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True)
         return loss
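The shared_step refactor collapses three near-identical step bodies into a single prepare/forward/loss path, with the train flag gating the metric updates so that self.val_metrics is only touched during evaluation. The resulting call pattern inside the task (training_step passes train=True; the evaluation steps rely on the default):

    # training: skip metric updates, log the loss dict with a train_ prefix
    loss, _ = self.shared_step(batch, train=True)

    # validation/test: shared_step also updates self.val_metrics as a side effect
    loss, logits = self.shared_step(batch)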
@@ -184,17 +167,7 @@ def on_validation_epoch_start(self) -> None:
             metric.reset()

     def validation_step(self, batch, batch_idx):
-        batch, targets = self._prepare_batch(batch)
-
-        # get logits
-        logits = self.model(batch)
-        # logits, any
-
-        # compute loss
-        loss = self.model.compute_loss(logits, targets)
-        # loss: dict of losses or loss
-
-        self._compute_metrics(logits, targets)
+        loss, logits = self.shared_step(batch)

         self.log_dict({f"val_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True)
         return loss
@@ -208,17 +181,7 @@ def on_test_epoch_start(self) -> None:
             metric.reset()

     def test_step(self, batch, batch_idx):
-        batch, targets = self._prepare_batch(batch)
-
-        # get logits
-        logits = self.model(batch)
-        # logits, any
-
-        # compute loss
-        loss = self.model.compute_loss(logits, targets)
-        # loss: dict of losses or loss
-
-        self._compute_metrics(logits, targets)
+        loss, logits = self.shared_step(batch)

         self.log_dict({f"test_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True)
         return loss
@@ -233,4 +196,4 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
         return batch

     def configure_finetune_callback(self):
-        return [ObjectDetectionFineTuning(train_bn=True)]
+        return [FaceDetectionFineTuning()]
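End to end, the new callback is what trainer.finetune picks up through configure_finetune_callback. Continuing the datamodule sketch from the data.py section above, under the same assumptions (the model name and Trainer arguments follow fastface/Flash conventions and are not part of this diff):

    import flash
    from flash.image import FaceDetector

    # "lffd_slim" is one of the architectures fastface ships; FaceDetector
    # instantiates the underlying fastface module from this name.
    model = FaceDetector(model="lffd_slim")

    trainer = flash.Trainer(max_epochs=1)
    # finetune() uses the FaceDetectionFineTuning callback returned above,
    # freezing model.model.backbone while leaving batch norm layers trainable.
    trainer.finetune(model, datamodule=datamodule)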