Skip to content

Commit

Permalink
Merge pull request #1112 from openvinotoolkit/vsaltykovx/add_mmdetect…
Browse files Browse the repository at this point in the history
…ion_input_parameters_validation_2

Vsaltykovx/add mmdetection input parameters validation 2
  • Loading branch information
goodsong81 authored May 25, 2022
2 parents 38a4d88 + bdb7599 commit f003a20
Show file tree
Hide file tree
Showing 24 changed files with 3,300 additions and 100 deletions.
44 changes: 33 additions & 11 deletions external/mmdetection/detection_tasks/apis/detection/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@
import os
import tempfile
from collections import defaultdict
from typing import List, Optional
from typing import List, Optional, Union

import torch
from mmcv import Config, ConfigDict
from ote_sdk.entities.datasets import DatasetEntity
from ote_sdk.entities.label import LabelEntity, Domain
from ote_sdk.usecases.reporting.time_monitor_callback import TimeMonitorCallback
from ote_sdk.utils.argument_checks import (
DatasetParamTypeCheck,
DirectoryPathCheck,
check_input_parameters_type
)

from detection_tasks.extension.datasets.data_utils import get_anchor_boxes, \
get_sizes_from_dataset_entity, format_list_to_str
Expand All @@ -43,14 +48,16 @@
logger = get_root_logger()


@check_input_parameters_type()
def is_epoch_based_runner(runner_config: ConfigDict):
return 'Epoch' in runner_config.type


@check_input_parameters_type({"work_dir": DirectoryPathCheck})
def patch_config(config: Config, work_dir: str, labels: List[LabelEntity], domain: Domain, random_seed: Optional[int] = None):
# Set runner if not defined.
if 'runner' not in config:
config.runner = {'type': 'EpochBasedRunner'}
config.runner = ConfigDict({'type': 'EpochBasedRunner'})

# Check that there is no conflict in specification of number of training epochs.
# Move global definition of epochs inside runner config.
Expand Down Expand Up @@ -112,6 +119,7 @@ def patch_config(config: Config, work_dir: str, labels: List[LabelEntity], domai
config.seed = random_seed


@check_input_parameters_type()
def set_hyperparams(config: Config, hyperparams: OTEDetectionConfig):
config.optimizer.lr = float(hyperparams.learning_parameters.learning_rate)
config.lr_config.warmup_iters = int(hyperparams.learning_parameters.learning_rate_warmup_iters)
Expand All @@ -126,7 +134,8 @@ def set_hyperparams(config: Config, hyperparams: OTEDetectionConfig):
config.runner.max_iters = total_iterations


def patch_adaptive_repeat_dataset(config: Config, num_samples: int,
@check_input_parameters_type()
def patch_adaptive_repeat_dataset(config: Union[Config, ConfigDict], num_samples: int,
decay: float = -0.002, factor: float = 30):
""" Patch the repeat times and training epochs adatively
Expand Down Expand Up @@ -155,14 +164,17 @@ def patch_adaptive_repeat_dataset(config: Config, num_samples: int,
data_train.times = new_repeat


def prepare_for_testing(config: Config, dataset: DatasetEntity) -> Config:
@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def prepare_for_testing(config: Union[Config, ConfigDict], dataset: DatasetEntity) -> Config:
config = copy.deepcopy(config)
# FIXME. Should working directories be modified here?
config.data.test.ote_dataset = dataset
return config


def prepare_for_training(config: Config, train_dataset: DatasetEntity, val_dataset: DatasetEntity,
@check_input_parameters_type({"train_dataset": DatasetParamTypeCheck,
"val_dataset": DatasetParamTypeCheck})
def prepare_for_training(config: Union[Config, ConfigDict], train_dataset: DatasetEntity, val_dataset: DatasetEntity,
time_monitor: TimeMonitorCallback, learning_curves: defaultdict) -> Config:
config = copy.deepcopy(config)
prepare_work_dir(config)
Expand All @@ -175,7 +187,8 @@ def prepare_for_training(config: Config, train_dataset: DatasetEntity, val_datas
return config


def config_to_string(config: Config) -> str:
@check_input_parameters_type()
def config_to_string(config: Union[Config, ConfigDict]) -> str:
"""
Convert a full mmdetection config to a string.
Expand All @@ -194,6 +207,7 @@ def config_to_string(config: Config) -> str:
return Config(config_copy).pretty_text


@check_input_parameters_type()
def config_from_string(config_string: str) -> Config:
"""
Generate an mmdetection config dict object from a string.
Expand All @@ -207,6 +221,7 @@ def config_from_string(config_string: str) -> Config:
return Config.fromfile(temp_file.name)


@check_input_parameters_type()
def save_config_to_file(config: Config):
""" Dump the full config to a file. Filename is 'config.py', it is saved in the current work_dir. """
filepath = os.path.join(config.work_dir, 'config.py')
Expand All @@ -215,7 +230,8 @@ def save_config_to_file(config: Config):
f.write(config_string)


def prepare_work_dir(config: Config) -> str:
@check_input_parameters_type()
def prepare_work_dir(config: Union[Config, ConfigDict]) -> str:
base_work_dir = config.work_dir
checkpoint_dirs = glob.glob(os.path.join(base_work_dir, "checkpoints_round_*"))
train_round_checkpoint_dir = os.path.join(base_work_dir, f"checkpoints_round_{len(checkpoint_dirs)}")
Expand All @@ -230,6 +246,7 @@ def prepare_work_dir(config: Config) -> str:
return train_round_checkpoint_dir


@check_input_parameters_type()
def set_data_classes(config: Config, labels: List[LabelEntity]):
# Save labels in data configs.
for subset in ('train', 'val', 'test'):
Expand All @@ -256,7 +273,8 @@ def set_data_classes(config: Config, labels: List[LabelEntity]):
# self.config.model.CLASSES = label_names


def patch_datasets(config: Config, domain):
@check_input_parameters_type()
def patch_datasets(config: Config, domain: Domain):

def patch_color_conversion(pipeline):
# Default data format for OTE is RGB, while mmdet uses BGR, so negate the color conversion flag.
Expand Down Expand Up @@ -289,7 +307,8 @@ def patch_color_conversion(pipeline):
patch_color_conversion(cfg.pipeline)


def remove_from_config(config, key: str):
@check_input_parameters_type()
def remove_from_config(config: Union[Config, ConfigDict], key: str):
if key in config:
if isinstance(config, Config):
del config._cfg_dict[key]
Expand All @@ -298,6 +317,8 @@ def remove_from_config(config, key: str):
else:
raise ValueError(f'Unknown config type {type(config)}')


@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector):
if not kmeans_import:
raise ImportError('Sklearn package is not installed. To enable anchor boxes clustering, please install '
Expand All @@ -308,7 +329,7 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
if transforms.type == 'MultiScaleFlipAug']
prev_generator = config.model.bbox_head.anchor_generator
group_as = [len(width) for width in prev_generator.widths]
wh_stats = get_sizes_from_dataset_entity(dataset, target_wh)
wh_stats = get_sizes_from_dataset_entity(dataset, list(target_wh))

if len(wh_stats) < sum(group_as):
logger.warning(f'There are not enough objects to cluster: {len(wh_stats)} were detected, while it should be '
Expand All @@ -332,7 +353,8 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
return config, model


def get_data_cfg(config: Config, subset: str = 'train') -> Config:
@check_input_parameters_type()
def get_data_cfg(config: Union[Config, ConfigDict], subset: str = 'train') -> Config:
data_cfg = config.data[subset]
while 'dataset' in data_cfg:
data_cfg = data_cfg.dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
from ote_sdk.usecases.tasks.interfaces.inference_interface import IInferenceTask
from ote_sdk.usecases.tasks.interfaces.unload_interface import IUnload
from ote_sdk.serialization.label_mapper import label_schema_to_bytes
from ote_sdk.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)

from mmdet.apis import export_model
from detection_tasks.apis.detection.config_utils import patch_config, prepare_for_testing, set_hyperparams
Expand All @@ -63,6 +67,7 @@ class OTEDetectionInferenceTask(IInferenceTask, IExportTask, IEvaluationTask, IU

_task_environment: TaskEnvironment

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
""""
Task for inference object detection models using OTEDetection.
Expand Down Expand Up @@ -239,6 +244,7 @@ def _add_predictions_to_dataset(self, prediction_results, dataset, confidence_th
dataset_item.append_metadata_item(active_score, model=self._task_environment.model)


@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None) -> DatasetEntity:
""" Analyzes a dataset using the latest inference model. """

Expand Down Expand Up @@ -330,7 +336,7 @@ def dummy_dump_features_hook(mod, inp, out):
eval_predictions = zip(eval_predictions, feature_vectors)
return eval_predictions, metric


@check_input_parameters_type()
def evaluate(self,
output_result_set: ResultSetEntity,
evaluation_metric: Optional[str] = None):
Expand Down Expand Up @@ -375,6 +381,7 @@ def unload(self):
logger.warning(f"Done unloading. "
f"Torch is still occupying {torch.cuda.memory_allocated()} bytes of GPU memory")

@check_input_parameters_type()
def export(self,
export_type: ExportType,
output_model: ModelEntity):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
from ote_sdk.usecases.tasks.interfaces.export_interface import ExportType
from ote_sdk.usecases.tasks.interfaces.optimization_interface import IOptimizationTask
from ote_sdk.usecases.tasks.interfaces.optimization_interface import OptimizationType
from ote_sdk.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)

from mmdet.apis import train_detector
from mmdet.apis.fake_input import get_fake_input
Expand All @@ -59,6 +63,7 @@

class OTEDetectionNNCFTask(OTEDetectionInferenceTask, IOptimizationTask):

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
""""
Task for compressing object detection models using NNCF.
Expand Down Expand Up @@ -177,12 +182,13 @@ def _create_compressed_model(self, dataset, config):
get_fake_input_func=get_fake_input,
is_accuracy_aware=is_acc_aware_training_set)

@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def optimize(
self,
optimization_type: OptimizationType,
dataset: DatasetEntity,
output_model: ModelEntity,
optimization_parameters: Optional[OptimizationParameters],
optimization_parameters: Optional[OptimizationParameters] = None,
):
if optimization_type is not OptimizationType.NNCF:
raise RuntimeError("NNCF is the only supported optimization")
Expand Down Expand Up @@ -247,6 +253,7 @@ def optimize(

self._is_training = False

@check_input_parameters_type()
def export(self, export_type: ExportType, output_model: ModelEntity):
if self._compression_ctrl is None:
super().export(export_type, output_model)
Expand All @@ -256,6 +263,7 @@ def export(self, export_type: ExportType, output_model: ModelEntity):
super().export(export_type, output_model)
self._model.enable_dynamic_graph_building()

@check_input_parameters_type()
def save_model(self, output_model: ModelEntity):
buffer = io.BytesIO()
hyperparams = self._task_environment.get_hyper_parameters(OTEDetectionConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,19 @@
from ote_sdk.usecases.exportable_code.inference import BaseInferencer
from ote_sdk.usecases.exportable_code.prediction_to_annotation_converter import (
DetectionBoxToAnnotationConverter,
IPredictionToAnnotationConverter,
MaskToAnnotationConverter,
RotatedRectToAnnotationConverter,
)
from ote_sdk.usecases.tasks.interfaces.deployment_interface import IDeploymentTask
from ote_sdk.usecases.tasks.interfaces.evaluate_interface import IEvaluationTask
from ote_sdk.usecases.tasks.interfaces.inference_interface import IInferenceTask
from ote_sdk.usecases.tasks.interfaces.optimization_interface import IOptimizationTask, OptimizationType
from ote_sdk.utils.argument_checks import (
DatasetParamTypeCheck,
check_input_parameters_type,
)
from shutil import copyfile, copytree
from typing import Any, Dict, List, Optional, Tuple, Union
from zipfile import ZipFile

Expand All @@ -66,24 +72,29 @@

class BaseInferencerWithConverter(BaseInferencer):

def __init__(self, configuration, model, converter) -> None:
@check_input_parameters_type()
def __init__(self, configuration: dict, model: Model, converter: IPredictionToAnnotationConverter) -> None:
self.configuration = configuration
self.model = model
self.converter = converter

@check_input_parameters_type()
def pre_process(self, image: np.ndarray) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
return self.model.preprocess(image)

@check_input_parameters_type()
def post_process(self, prediction: Dict[str, np.ndarray], metadata: Dict[str, Any]) -> AnnotationSceneEntity:
detections = self.model.postprocess(prediction, metadata)

return self.converter.convert_to_annotation(detections, metadata)

@check_input_parameters_type()
def forward(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
return self.model.infer_sync(inputs)


class OpenVINODetectionInferencer(BaseInferencerWithConverter):
@check_input_parameters_type()
def __init__(
self,
hparams: OTEDetectionConfig,
Expand Down Expand Up @@ -115,6 +126,7 @@ def __init__(


class OpenVINOMaskInferencer(BaseInferencerWithConverter):
@check_input_parameters_type()
def __init__(
self,
hparams: OTEDetectionConfig,
Expand Down Expand Up @@ -149,6 +161,7 @@ def __init__(


class OpenVINORotatedRectInferencer(BaseInferencerWithConverter):
@check_input_parameters_type()
def __init__(
self,
hparams: OTEDetectionConfig,
Expand Down Expand Up @@ -183,11 +196,13 @@ def __init__(


class OTEOpenVinoDataLoader(DataLoader):
@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def __init__(self, dataset: DatasetEntity, inferencer: BaseInferencer):
self.dataset = dataset
self.inferencer = inferencer

def __getitem__(self, index):
@check_input_parameters_type()
def __getitem__(self, index: int):
image = self.dataset[index].numpy
annotation = self.dataset[index].annotation_scene
inputs, metadata = self.inferencer.pre_process(image)
Expand All @@ -199,6 +214,7 @@ def __len__(self):


class OpenVINODetectionTask(IDeploymentTask, IInferenceTask, IEvaluationTask, IOptimizationTask):
@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
logger.info('Loading OpenVINO OTEDetectionTask')
self.task_environment = task_environment
Expand Down Expand Up @@ -230,6 +246,7 @@ def load_inferencer(self) -> Union[OpenVINODetectionInferencer, OpenVINOMaskInfe
return OpenVINORotatedRectInferencer(*args)
raise RuntimeError(f"Unknown OpenVINO Inferencer TaskType: {self.task_type}")

@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None) -> DatasetEntity:
logger.info('Start OpenVINO inference')
update_progress_callback = default_progress_callback
Expand All @@ -243,6 +260,7 @@ def infer(self, dataset: DatasetEntity, inference_parameters: Optional[Inference
logger.info('OpenVINO inference completed')
return dataset

@check_input_parameters_type()
def evaluate(self,
output_result_set: ResultSetEntity,
evaluation_metric: Optional[str] = None):
Expand All @@ -252,6 +270,7 @@ def evaluate(self,
output_result_set.performance = MetricsHelper.compute_f_measure(output_result_set).get_performance()
logger.info('OpenVINO metric evaluation completed')

@check_input_parameters_type()
def deploy(self,
output_model: ModelEntity) -> None:
logger.info('Deploying the model')
Expand Down Expand Up @@ -279,11 +298,12 @@ def deploy(self,
output_model.exportable_code = zip_buffer.getvalue()
logger.info('Deploying completed')

@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def optimize(self,
optimization_type: OptimizationType,
dataset: DatasetEntity,
output_model: ModelEntity,
optimization_parameters: Optional[OptimizationParameters]):
optimization_parameters: Optional[OptimizationParameters] = None):
logger.info('Start POT optimization')

if optimization_type is not OptimizationType.POT:
Expand Down
Loading

0 comments on commit f003a20

Please sign in to comment.