Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
borhanMorphy committed Jul 16, 2021
1 parent 57226b6 commit a3a82bb
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.data.transforms import ApplyToKeys
import os
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TYPE_CHECKING

Expand Down Expand Up @@ -41,6 +42,7 @@

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader
from torchvision.transforms import Normalize

if _FASTFACE_AVAILABLE:
import fastface as ff
Expand Down Expand Up @@ -356,7 +358,9 @@ def __init__(
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={ # TODO add default detection sources
data_sources={
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
"fddb": FDDBDataSource()
},
default_data_source=DefaultDataSources.FILES,
Expand All @@ -370,20 +374,29 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
return cls(**state_dict)

def default_transforms(self) -> Optional[Dict[str, Callable]]:
return default_transforms()

face_det_transforms = default_transforms()

face_det_transforms["to_tensor_transform"].add_module(
str(len(face_det_transforms["to_tensor_transform"])),
# Re scale image to [0, 255]
ApplyToKeys("input", Normalize(mean=0, std=1/255))
)

return face_det_transforms


class FaceDetectionData(DataModule):

preprocess_cls = FaceDetectionPreprocess # TODO
preprocess_cls = FaceDetectionPreprocess

@classmethod
def from_fddb(
cls,
data_folder: Optional[str] = None,
train_data: bool = False,
test_data: bool = False,
val_data: bool = False,
test_data: bool = False,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
Expand All @@ -399,7 +412,9 @@ def from_fddb(
Args:
data_folder: The folder containing the `2002`, `2003` ad `FDDB-folds` folders.
val_ann_file: The COCO format annotation file.
train_data: set True to load `train phase` dataset, Default False.
val_data: set True to load `val phase` dataset, Default False.
test_data: set True to load `test phase` dataset, Default False.
train_transform: The dictionary of transforms to use during training which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
val_transform: The dictionary of transforms to use during validation which maps
Expand All @@ -424,10 +439,12 @@ def from_fddb(
data_module = FaceDetectionData.from_fddb(
data_folder="data/",
train_data=True,
val_data=True,
)
"""

return cls.from_data_source( # TODO
return cls.from_data_source(
"fddb",
(data_folder, "train") if train_data else None,
(data_folder, "val") if val_data else None,
Expand Down

0 comments on commit a3a82bb

Please sign in to comment.