Skip to content

Commit

Permalink
fix mmcls get classes (#215)
Browse files Browse the repository at this point in the history
* fix mmcls get classes

* resolve comment

* resolve comment
  • Loading branch information
RunningLeon authored Mar 9, 2022
1 parent 937985e commit 120f4ac
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions mmdeploy/codebase/mmcls/deploy/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import mmcv
import numpy as np
import torch
from mmcls.datasets import DATASETS
from mmcls.models.classifiers.base import BaseClassifier
from mmcv.utils import Registry

from mmdeploy.codebase.base import BaseBackendModel
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
load_config)
get_root_logger, load_config)


def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
Expand Down Expand Up @@ -150,20 +149,35 @@ def get_classes_from_config(model_cfg: Union[str, mmcv.Config]):
Returns:
list[str]: A list of string specifying names of different class.
"""
model_cfg = load_config(model_cfg)[0]
from mmcls.datasets import DATASETS

module_dict = DATASETS.module_dict
model_cfg = load_config(model_cfg)[0]
data_cfg = model_cfg.data

if 'train' in data_cfg:
module = module_dict[data_cfg.train.type]
elif 'val' in data_cfg:
module = module_dict[data_cfg.val.type]
elif 'test' in data_cfg:
module = module_dict[data_cfg.test.type]
else:
raise RuntimeError(f'No dataset config found in: {model_cfg}')

return module.CLASSES
def _get_class_names(dataset_type: str):
dataset = data_cfg.get(dataset_type, None)
if (not dataset) or (dataset.type not in module_dict):
return None

module = module_dict[dataset.type]
if module.CLASSES is not None:
return module.CLASSES
return module.get_classes(dataset.get('classes', None))

class_names = None
for dataset_type in ['val', 'test', 'train']:
class_names = _get_class_names(dataset_type)
if class_names is not None:
break

if class_names is None:
logger = get_root_logger()
logger.warning(f'Use generated class names, because \
it failed to parse CLASSES from config: {data_cfg}')
num_classes = model_cfg.model.head.num_classes
class_names = [str(i) for i in range(num_classes)]
return class_names


def build_classification_model(model_files: Sequence[str],
Expand Down

0 comments on commit 120f4ac

Please sign in to comment.