diff --git a/docs/source/tutorials/data.rst b/docs/source/tutorials/data.rst index 9e58c4c..0d16b83 100644 --- a/docs/source/tutorials/data.rst +++ b/docs/source/tutorials/data.rst @@ -12,7 +12,7 @@ For accurate model evaluation, it is necessary to split the data into subsets su validation, and test sets for training and evaluating the model. The **cvtk** package provides a convenient command for splitting a single dataset into multiple subsets. -If the dataset is saved in a text file, use the ``cvtk split`` command. +If the dataset is saved in a text file, use the ``cvtk text-split`` command. For example, suppose you have a tab-delimited text file :file:`data.txt` with the image file paths in the first column and the label names in the second column: @@ -36,7 +36,7 @@ Note that, by adding the ``--shuffle`` argument, the data is shuffled before spl .. code-block:: sh - cvtk split --input data.txt --output data_subset.txt --ratios 6:2:2 --shuffle + cvtk text-split --input data.txt --output data_subset.txt --ratios 6:2:2 --shuffle The command generates the files @@ -66,7 +66,7 @@ to ensure that each subset has a uniform class distribution. .. code-block:: sh - cvtk split --input all.txt --output data_subset.txt --ratios 6:2:2 --shuffle --stratify + cvtk text-split --input all.txt --output data_subset.txt --ratios 6:2:2 --shuffle --stratify diff --git a/src/cvtk/__init__.py b/src/cvtk/__init__.py index d8dc6ea..18b9f68 100644 --- a/src/cvtk/__init__.py +++ b/src/cvtk/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.2.15' +__version__ = '0.2.16' from ._base import imread, imconvert, imwrite, imshow, imlist, imresize from ._base import Annotation, Image, ImageDeck, JsonComplexEncoder diff --git a/src/cvtk/ls/_base.py b/src/cvtk/ls/_base.py index acdda99..520ea88 100644 --- a/src/cvtk/ls/_base.py +++ b/src/cvtk/ls/_base.py @@ -8,7 +8,7 @@ import importlib from ..ml.torchutils import __generate_source as generate_source_cls from ..ml.mmdetutils import __generate_source as generate_source_det -from ..ml._subutils import __estimate_task_from_source, __generate_app_html_tmpl +from ..ml._subutils import __estimate_source_task, __estimate_source_vanilla, __generate_app_html_tmpl import label_studio_sdk @@ -99,7 +99,6 @@ def export(project: int, img['file_name'] = os.path.join(ls_data_root, img['file_name']) img['file_name'] = urllib.parse.unquote(img['file_name']) - print(img['file_name']) with open(output, 'w') as f: json.dump(exported_data, f, indent=indent, ensure_ascii=ensure_ascii) @@ -116,7 +115,6 @@ def generate_app(project: str, source: str, label: str, model: str, weights: str if not os.path.exists(project): os.makedirs(project) - coremodule = os.path.splitext(os.path.basename(source))[0] data_label = os.path.basename(label) model_cfg = os.path.basename(model) @@ -128,28 +126,25 @@ def generate_app(project: str, source: str, label: str, model: str, weights: str shutil.copy2(model, os.path.join(project, model_cfg)) shutil.copy2(weights, os.path.join(project, model_weights)) - source_task_type, source_is_vanilla = __estimate_task_from_source(source) + source_task_type = __estimate_source_task(source) + source_is_vanilla = __estimate_source_vanilla(source) # FastAPI script - tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/_ls_backend.py'), source_task_type) + tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/_ls_backend.py'), + source_task_type) + if vanilla: if source_is_vanilla: for i in range(len(tmpl)): - if tmpl[i][:9] == 'from cvtk': - if source_task_type == 'cls': - tmpl[i] = f'from {coremodule} import CLSCORE as MODULECORE' - elif source_task_type == 'det': - tmpl[i] = f'from {coremodule} import MMDETCORE as MODULECORE' - else: - raise ValueError('Unsupport Type.') + if (tmpl[i][:9] == 'from cvtk') and ('import ModuleCore' in tmpl[i]): + tmpl[i] = f'from {coremodule} import ModuleCore' else: - print('The CLSCORE or MMDETCORE class definition is not found in the source code. The script will be generated with importation of cvtk.') + # user specified vanilla, but the source code for CV task is not vanilla + print('The `ModuleCore` class definition is not found in the source code. `ModuleCore` will be generated with importation of cvtk regardless vanilla is specified.') + tmpl = ''.join(tmpl) tmpl = tmpl.replace('__DATALABEL__', data_label) tmpl = tmpl.replace('__MODELCFG__', model_cfg) tmpl = tmpl.replace('__MODELWEIGHT__', model_weights) with open(os.path.join(project, 'main.py'), 'w') as fh: fh.write(tmpl) - - - \ No newline at end of file diff --git a/src/cvtk/ml/__init__.py b/src/cvtk/ml/__init__.py index 4e634ca..348e6d6 100644 --- a/src/cvtk/ml/__init__.py +++ b/src/cvtk/ml/__init__.py @@ -1 +1 @@ -from ._base import split_dataset, generate_source, generate_app +from ._base import split_dataset, generate_source, generate_demoapp diff --git a/src/cvtk/ml/_base.py b/src/cvtk/ml/_base.py index 7e4d213..2cc9046 100644 --- a/src/cvtk/ml/_base.py +++ b/src/cvtk/ml/_base.py @@ -6,7 +6,7 @@ import re from .torchutils import __generate_source as generate_source_cls from .mmdetutils import __generate_source as generate_source_det -from ._subutils import __estimate_task_from_source, __generate_app_html_tmpl +from ._subutils import __estimate_source_task, __estimate_source_vanilla, __generate_app_html_tmpl def split_dataset(data: str|list[str, str]|tuple[str, str], @@ -24,9 +24,9 @@ def split_dataset(data: str|list[str, str]|tuple[str, str], or a path to a text file. If list is given, each element of the list is treated as a sample. output: The output file name will be appended with the index of the split subset. ratios: The ratios to split the dataset. The sum of the ratios should be 1. - shuffle (bool): Shuffle the dataset before splitting. - stratify (bool): Split the dataset with a balanced class distribution if `label` is given. - random_seed (int|none): Random seed for shuffling the dataset. + shuffle: Shuffle the dataset before splitting. + stratify: Split the dataset with a balanced class distribution if `label` is given. + random_seed: Random seed for shuffling the dataset. Returns: A list of the split datasets. The length of the list is the same as the length of `ratios`. @@ -113,8 +113,8 @@ def split_dataset(data: str|list[str, str]|tuple[str, str], -def generate_source(project: str, task: str='cls', vanilla=False) -> None: - """Generate source code for training and inference of a classification model using PyTorch +def generate_source(project: str, task: str='cls', vanilla: bool=False) -> None: + """Generate source code for classification or detection tasks This function generates a Python script for training and inference of a model using PyTorch (for classification task) or MMDetection (for object detection and instance segmentation tasks). @@ -128,9 +128,9 @@ def generate_source(project: str, task: str='cls', vanilla=False) -> None: since all functions is implemented directly in torch and torchvision. Args: - project (str): A file path to save the script. - task (str): The task type of project. Three types of tasks can be specified ('cls', 'det', 'segm'). The default is 'cls'. - vanilla (bool): Generate a script without importation of cvtk. The default is False. + project: A file path to save the script. + task: The task type of project. Three types of tasks can be specified ('cls', 'det', 'segm'). The default is 'cls'. + vanilla: Generate a script without importation of cvtk. The default is False. """ if task.lower() in ['cls', 'classification']: @@ -141,18 +141,18 @@ def generate_source(project: str, task: str='cls', vanilla=False) -> None: raise ValueError('The current version only support classification (`cls`), detection (`det`), and segmentation (`segm`) tasks.') -def generate_app(project: str, source: str, label: str, model: str, weights: str, vanilla=False) -> None: +def generate_demoapp(project: str, source: str, label: str, model: str, weights: str, vanilla: bool=False) -> None: """Generate a FastAPI application for inference of a classification or detection model This function generates a FastAPI application for inference of a classification or detection model. Args: - project (str): A file path to save the FastAPI application. - source (str): The source code of the model. - label (str): The label file of the dataset. - model (str): The configuration file of the model. - weights (str): The weights file of the model. - module (str): Script with importation of cvtk ('cvtk') or not ('fastapi'). + project: A file path to save the FastAPI application. + source: The source code of the model. + label: The label file of the dataset. + model: The configuration file of the model. + weights: The weights file of the model. + module: Script with importation of cvtk ('cvtk') or not ('fastapi'). Examples: >>> from cvtk.ml import generate_app @@ -173,22 +173,20 @@ def generate_app(project: str, source: str, label: str, model: str, weights: str shutil.copy2(model, os.path.join(project, model_cfg)) shutil.copy2(weights, os.path.join(project, model_weights)) - source_task_type, source_is_vanilla = __estimate_task_from_source(source) + source_task_type = __estimate_source_task(source) + source_is_vanilla = __estimate_source_vanilla(source) # FastAPI script - tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/_flask.py'), source_task_type) + tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/_flask.py'), + source_task_type) if vanilla: if source_is_vanilla: for i in range(len(tmpl)): - if tmpl[i][:9] == 'from cvtk': - if source_task_type == 'cls': - tmpl[i] = f'from {coremodule} import CLSCORE as MODULECORE' - elif source_task_type == 'det': - tmpl[i] = f'from {coremodule} import MMDETCORE as MODULECORE' - else: - raise ValueError('Unsupport Type.') + if (tmpl[i][:9] == 'from cvtk') and ('import ModuleCore' in tmpl[i]): + tmpl[i] = f'from {coremodule} import ModuleCore' else: - print('The CLSCORE or MMDETCORE class definition is not found in the source code. The script will be generated with importation of cvtk.') + # user specified vanilla, but the source code for CV task is not vanilla + print('The `ModuleCore` class definition is not found in the source code. `ModuleCore` will be generated with importation of cvtk regardless vanilla is specified.') tmpl = ''.join(tmpl) tmpl = tmpl.replace('__DATALABEL__', data_label) tmpl = tmpl.replace('__MODELCFG__', model_cfg) @@ -199,7 +197,6 @@ def generate_app(project: str, source: str, label: str, model: str, weights: str # HTML template if not os.path.exists(os.path.join(project, 'templates')): os.makedirs(os.path.join(project, 'templates')) - tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/html/fastapi_.html'), source_task_type) + tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/html/_flask.html'), source_task_type) with open(os.path.join(project, 'templates', 'index.html'), 'w') as fh: fh.write(''.join(tmpl)) - \ No newline at end of file diff --git a/src/cvtk/ml/_subutils.py b/src/cvtk/ml/_subutils.py index d0cb778..89d8d05 100644 --- a/src/cvtk/ml/_subutils.py +++ b/src/cvtk/ml/_subutils.py @@ -7,66 +7,44 @@ def __get_imports(code_file: str) -> list[str]: - """Find lines containing import statements from a file. - - Args: - code_file (str): Path to a python file. - """ imports = [] - with open(code_file, 'r') as codefh: for codeline in codefh: if codeline[0:6] == 'import': imports.append(codeline) + if codeline[0:6] == 'class ' or codeline[0:4] == 'def ': + break return imports def __insert_imports(tmpl: list[str], modules: list[str]) -> list[str]: - """Insert import statements to a template. - - Insert import statements to a template (`tmpl`) at the end of the import statements in the template. - - Args: - tmpl (list[str]): A list of strings containing the template. - modules (list[str]): A list of strings containing import statements. - - Examples: - >>> tmpl = ['import os', - ... '', - ... 'print("Hello, World!")'], - >>> modules = ['import cvtk'] - >>> __insert_imports(tmpl, modules) - ['import os', - 'import cvtk', - '', - 'print("Hello, World!")'] - >>> - >>> - >>> tmpl = __insert_imports(tmpl, __get_imports(__file__) - >>> + """Insert imports into the top of the template code. + + This function deletes the original imports in the template (`tmpl`) + and then inserts modules listed in `modules` argument. """ extmpl = [] imported = False for codeline in tmpl: if codeline[0:6] == 'import': - # pass the imports in original file - pass + pass # delete the original imports else: if not imported: - # insert the imports + # insert the modules listed in `modules` argument for mod in modules: extmpl.append(mod) imported = True + # append the original code after the imports extmpl.append(codeline) return extmpl def __extend_cvtk_imports(tmpl, module_dicts): extmpl = [] - extended = False for codeline in tmpl: if codeline[0:9] == 'from cvtk': + # find the first cvtk import statement and replace it with the source code of the modules if not extended: for mod_dict in module_dicts: for mod_name, mod_funcs in mod_dict.items(): @@ -74,19 +52,12 @@ def __extend_cvtk_imports(tmpl, module_dicts): extmpl.append('\n\n\n' + inspect.getsource(mod_func)) extended = True else: + # append the original code after the extending of cvtk module source code extmpl.append(codeline) - return extmpl def __del_docstring(func_source: str) -> str: - """Delete docstring from source code. - - Delete docstring (strings between \"\"\" or ''' ) from source code. - - Args: - func_source (str): Source code of a function. - """ func_source_ = '' is_docstring = False omit = False @@ -97,6 +68,7 @@ def __del_docstring(func_source: str) -> str: is_docstring = not is_docstring else: if not is_docstring: + line = line.replace('\\\\', '\\') func_source_ += line + '\n' return func_source_ @@ -125,47 +97,22 @@ def __generate_app_html_tmpl(tmpl_fpath, task): return tmpl -def __estimate_task_from_source(source): - task_ = { - 'CLSCORE': {'classdef': 0, 'import': 0, 'call': 0}, - 'MMDETCORE': {'classdef': 0, 'import': 0, 'call': 0} - } +def __estimate_source_task(source): + spec = importlib.util.spec_from_file_location('ModuleCore', source) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + ModuleCoreClass = getattr(module, 'ModuleCore', None) + if ModuleCoreClass is None: + raise AttributeError("ModuleCore class not found in the specified source file.") + module_instance = ModuleCoreClass(None, None) + return module_instance.task_type + + +def __estimate_source_vanilla(source): + is_vanilla = True with open(source, 'r') as infh: for codeline in infh: - if 'class CLSCORE' in codeline: - task_['CLSCORE']['classdef'] += 1 - elif 'import cvtk.ml.torchuitls' in codeline: - task_['CLSCORE']['import'] += 1 - elif 'from cvtk.ml.torchutils import' in codeline and 'CLSCORE' in codeline: - task_['CLSCORE']['import'] += 1 - elif 'CLSCORE(' in codeline: - task_['CLSCORE']['call'] += 1 - elif 'class MMDETCORE' in codeline: - task_['MMDETCORE']['classdef'] += 1 - elif 'import cvtk.ml.mmdetutils' in codeline: - task_['MMDETCORE']['import'] += 1 - elif 'from cvtk.ml.mmdetutils import' in codeline and 'MMDETCORE' in codeline: - task_['MMDETCORE']['import'] += 1 - elif 'MMDETCORE(' in codeline: - task_['MMDETCORE']['call'] += 1 - is_cls_cvtk = (task_['CLSCORE']['import'] > 0) and (task_['CLSCORE']['call'] > 0) - is_cls_vanilla = (task_['CLSCORE']['classdef'] > 0) and (task_['CLSCORE']['call'] > 0) - is_cls = is_cls_cvtk or is_cls_vanilla - is_det_cvtk = (task_['MMDETCORE']['import'] > 0) and (task_['MMDETCORE']['call'] > 0) - is_det_vanilla = (task_['MMDETCORE']['classdef'] > 0) and (task_['MMDETCORE']['call'] > 0) - is_det = is_det_cvtk or is_det_vanilla - - if is_cls and (not is_det): - task = 'cls' - elif (not is_cls) and is_det: - task = 'det' - else: - raise ValueError('The task type cannot be determined from the source code. Make sure your source code contains CLSCORE or MMDETCORE class definition or importation, and call.') - if is_cls_cvtk or is_det_cvtk: - is_vanilla = False - elif is_cls_vanilla or is_det_vanilla: - is_vanilla = True - else: - raise ValueError('The source code cannot be determined from the source code. Make sure your source code contains importation of cvtk.ml.torchutils or cvtk.ml.mmdetutils.') - - return task, is_vanilla + if ('import cvtk' in codeline) or ('from cvtk' in codeline): + is_vanilla = False + break + return is_vanilla diff --git a/src/cvtk/ml/data.py b/src/cvtk/ml/data.py index ac3d8a8..b56b2ce 100644 --- a/src/cvtk/ml/data.py +++ b/src/cvtk/ml/data.py @@ -15,7 +15,7 @@ class DataLabel(): Methods implemented in the class provide a way to get the class index from the class name and vice versa. Args: - labels (tuple|list|str): A tuple or list, + labels: A tuple or list, or a path to a text file or coco format containing class labels. Text file should contain one class name per line. @@ -41,7 +41,7 @@ class DataLabel(): >>> print(DataLabel['flower']) 1 """ - def __init__(self, labels): + def __init__(self, labels: list|tuple|str): if isinstance(labels, list) or isinstance(labels, tuple): self.__labels = labels elif isinstance(labels, str): @@ -134,7 +134,7 @@ class SquareResize(): [0.229, 0.224, 0.225]) ]) """ - def __init__(self, shape=600, bg_color = None, resample=PIL.Image.BILINEAR): + def __init__(self, shape: int=600, bg_color: tuple[int, int, int]|None=None, resample: object=PIL.Image.BILINEAR): self.shape = shape self.bg_color = bg_color self.resample = resample diff --git a/src/cvtk/ml/mmdetutils.py b/src/cvtk/ml/mmdetutils.py index 8e46507..638d27b 100644 --- a/src/cvtk/ml/mmdetutils.py +++ b/src/cvtk/ml/mmdetutils.py @@ -276,25 +276,25 @@ def cfg(self): return self.__cfg -class MMDETCORE(): +class ModuleCore(): """A class for object detection and instance segmentation This class provides user-friendly APIs for object detection and instance segmentation using MMDetection. There are four main methods are implemented in this class: - :func:`train `, - :func:`test `, - :func:`save `, - :func:`inference `. - The :func:`train ` method is used for training the model + :func:`train `, + :func:`test `, + :func:`save `, + :func:`inference `. + The :func:`train ` method is used for training the model and perform validation and test if validation and test data are provided. - The :func:`test ` method is used for testing the model with test data. + The :func:`test ` method is used for testing the model with test data. In general, the performance test is performed automatically after the training, but user can also run the test independently from the training process with - the :func:`test ` method. - The :func:`save ` method is used for saving the model weights, + the :func:`test ` method. + The :func:`save ` method is used for saving the model weights, configuration (design of model architecture), training log (e.g., mAP and loss per epoch), and test results. - The :func:`inference ` method is used for running inference + The :func:`inference ` method is used for running inference with the trained model. The detailed usage of each method is described in the method documentation. @@ -321,14 +321,14 @@ class MMDETCORE(): Examples: >>> from cvtk.ml.data import DataLabel - >>> from cvtk.ml.mmdetutils import DataPipeline, Dataset, DataLoader, MMDETCORE + >>> from cvtk.ml.mmdetutils import DataPipeline, Dataset, DataLoader, ModuleCore >>> >>> datalabel = DataLabel(['leaf', 'flower', 'stem']) >>> cfg = 'faster_rcnn_r50_fpn_1x_coco' >>> weights = None # download from MMDetection repository >>> workspace = '/path/to/workspace' >>> - >>> model = MMDETCORE(datalabel, cfg, weights, workspace) + >>> model = ModuleCore(datalabel, cfg, weights, workspace) >>> >>> train = DataLoader('/path/to/train/coco.json', 'train') >>> model.train(train, epoch=10) @@ -340,23 +340,25 @@ def __init__(self, weights: str|None=None, workspace=None, seed=None): - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.datalabel = self.__init_datalabel(datalabel) - self.cfg = self.__init_cfg(cfg, weights, seed) - self.model = None - self.tempd, self.workspace = self.__init_tempdir(workspace) - self.mmdet_log_dpath = None - self.test_stats = None + self.task_type = 'det' + if not(datalabel is None and cfg is None): + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.datalabel = self.__init_datalabel(datalabel) + self.cfg = self.__init_cfg(cfg, weights, seed) + self.model = None + self.__tempdir_obj, self.workspace = self.__init_tempdir(workspace) + self.mmdet_log_dpath = None + self.test_stats = None def __del__(self): try: - if self.model is not None: + if hasattr(self, '__tempdir_obj') and (self.model is not None): del self.model torch.cuda.empty_cache() gc.collect() - if self.tempd is not None: - self.tempd.cleanup() + if hasattr(self, '__tempdir_obj') and (self.__tempdir_obj is not None): + self.__tempdir_obj.cleanup() except: logger.info(f'The temporary directory (`{self.workspace}`) created by cvtk ' f'cannot be removed automatically. Please remove it manually.') @@ -470,7 +472,7 @@ def train(self, the model will undergo a final evaluation at the end of training, and the test results will also be saved in the workspace. The test can also be performed independently from the training process, - seed the :func:`test ` method for more details. + seed the :func:`test ` method for more details. Args: train: A DataLoader class object. @@ -482,14 +484,14 @@ def train(self, Examples: >>> from cvtk.ml.data import DataLabel - >>> from cvtk.ml.mmdetutils import DataPipeline, Dataset, DataLoader, MMDETCORE + >>> from cvtk.ml.mmdetutils import DataPipeline, Dataset, DataLoader, ModuleCore >>> >>> datalabel = DataLabel(['leaf', 'flower', 'stem']) >>> cfg = 'faster_rcnn_r50_fpn_1x_coco' >>> weights = None # download from MMDetection repository >>> workspace = '/path/to/workspace' >>> - >>> model = MMDETCORE(datalabel, cfg, weights, workspace) + >>> model = ModuleCore(datalabel, cfg, weights, workspace) >>> >>> train = DataLoader(Dataset(datalabel, '/path/to/train/coco.json'), 'train') >>> model.train(train, epoch=10) @@ -578,13 +580,13 @@ def test(self, test:dict) -> dict: Examples: >>> from cvtk.ml.data import DataLabel - >>> from cvtk.ml.mmdetutils import DataPipeline, Dataset, DataLoader, MMDETCORE + >>> from cvtk.ml.mmdetutils import DataPipeline, Dataset, DataLoader, ModuleCore >>> >>> datalabel = DataLabel(['leaf', 'flower', 'stem']) >>> cfg = 'faster_rcnn_r50_fpn_1x_coco' >>> weights = '/path/to/model.pth' >>> - >>> model = MMDETCORE(datalabel, cfg, weights, workspace) + >>> model = ModuleCore(datalabel, cfg, weights, workspace) >>> >>> test = DataLoader('/path/to/test/coco.json', 'test') >>> metrics = model.test(test) @@ -937,7 +939,7 @@ def __generate_source(script_fpath, task, vanilla=False): {'cvtk': [JsonComplexEncoder, Annotation, Image, ImageDeck, imread]}, {'cvtk.format.coco': [calc_stats]}, {'cvtk.ml.data': [DataLabel]}, - {'cvtk.ml.mmdetutils': [DataPipeline, Dataset, DataLoader, MMDETCORE, plot_trainlog]} + {'cvtk.ml.mmdetutils': [DataPipeline, Dataset, DataLoader, ModuleCore, plot_trainlog]} ] tmpl = __insert_imports(tmpl, __get_imports(__file__)) tmpl = __extend_cvtk_imports(tmpl, cvtk_modules) diff --git a/src/cvtk/ml/torchutils.py b/src/cvtk/ml/torchutils.py index 173f8bb..2e5f72a 100644 --- a/src/cvtk/ml/torchutils.py +++ b/src/cvtk/ml/torchutils.py @@ -22,55 +22,58 @@ -def DataTransform(shape, is_train=False): - """Generate image preprocessing pipeline - - DataTransforms is a function to generate image preprocessing pipeline used in PyTorch. - By default (`is_train=False`), a pipeline for inference is generated, - which resizes images to a square shape and converts them to a tensor. - A pipeline for training is generated when `is_train=True`, - in which images are resized to a square shape with a specified resolution, - then processed with several fundamental augmentation processes, and finally converted to a tensor. - +class DataTransform(): + """Pipeline for preprocessing images + + The class composes several fundamental transforms for image preprocessing + and converts them to a `torchvision.transforms.Compose` instance. + It is intended for use by beginners. + If user wants to customize their own image preprocessing pipeline, + it is recommended to use `torchvision.transforms.Compose` directly. + Args: - shape (int): The resolution of the square image. - is_train (bool): If True, a pipeline for training is generated. Default is False. + shape: The resolution of preprocessed images. + is_train: Generate pipeline for trianing if True, otherwise for inference. + Pipeline for training includes random cropping, flipping, and rotation; + pipeline for inference only includes resizing and normalization. - Returns: - A torchvision.transforms.Compose instance containing the transform pipeline. - Examples: >>> from cvtk.ml.torchutils import DataTransform >>> >>> transform_train = DataTransform(224, is_train=True) - >>> print(transform_train) + >>> print(transform_train.pipeline) >>> >>> transform_inference = DataTransform(224) - >>> print(transforms_inference) + >>> print(transforms_inference.pipeline) """ - if is_train: - return torchvision.transforms.Compose([ - torchvision.transforms.v2.ToImage(), - torchvision.transforms.v2.Resize(size=(shape + 50, shape + 50), antialias=True), - torchvision.transforms.v2.RandomResizedCrop(size=(shape, shape), antialias=True), - torchvision.transforms.v2.RandomHorizontalFlip(0.5), - torchvision.transforms.v2.RandomAffine(45), - torchvision.transforms.v2.ToDtype(torch.float32, scale=True), - torchvision.transforms.v2.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - else: - return torchvision.transforms.Compose([ - torchvision.transforms.v2.ToImage(), - torchvision.transforms.v2.Resize(size=(shape, shape), antialias=True), - torchvision.transforms.v2.ToDtype(torch.float32, scale=True), - torchvision.transforms.v2.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - + def __init__(self, shape: int|tuple[int, int], is_train=False): + if isinstance(shape, int): + shape = (shape, shape) + elif isinstance(shape, list): + shape = tuple(shape) + + if is_train: + self.pipeline = torchvision.transforms.Compose([ + torchvision.transforms.v2.ToImage(), + torchvision.transforms.v2.Resize(size=(shape[0] + 50, shape[1] + 50), antialias=True), + torchvision.transforms.v2.RandomResizedCrop(size=shape, antialias=True), + torchvision.transforms.v2.RandomHorizontalFlip(0.5), + torchvision.transforms.v2.RandomAffine(45), + torchvision.transforms.v2.ToDtype(torch.float32, scale=True), + torchvision.transforms.v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) + else: + self.pipeline = torchvision.transforms.Compose([ + torchvision.transforms.v2.ToImage(), + torchvision.transforms.v2.Resize(size=shape, antialias=True), + torchvision.transforms.v2.ToDtype(torch.float32, scale=True), + torchvision.transforms.v2.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) class Dataset(torch.utils.data.Dataset): - """Generate dataset for training or testing with PyTorch - + """A class to manupulate image data for training and inference + Dataset is a class that generates a dataset for training or testing with PyTorch. It loads images from a directory (the subdirectories are recursively loaded), a list, a tuple, or a tab-separated (TSV) file. @@ -88,10 +91,10 @@ class Dataset(torch.utils.data.Dataset): In this class, upsampling is performed by specifying `upsampling=TRUE`. Args: - datalabel (DataLabel): A DataLabel instance. This datalabel is used to convert class labels to integers. - dataset (str|list|tuple): A path to a directory, a list, a tuple, or a TSV file. - transform (None|torchvision.transforms.Compose): A transform pipeline of image processing. - balance_train (bool): If True, the number of images in each class is balanced + datalabel: A DataLabel instance. This datalabel is used to convert class labels to integers. + dataset: A path to a directory, a list, a tuple, or a TSV file. + transform: A transform pipeline of image processing. + balance_train: If True, the number of images in each class is balanced Examples: >>> from cvtk.ml import DataLabel @@ -110,9 +113,11 @@ class Dataset(torch.utils.data.Dataset): """ def __init__(self, datalabel, - dataset, - transform=None, - upsampling=False): + dataset: str|list|tuple, + transform: torchvision.transforms.Compose|DataTransform|None=None, + upsampling: bool=False): + if transform is not None and isinstance(transform, DataTransform): + transform = transform.pipeline self.transform = transform self.upsampling = upsampling self.x , self.y = self.__load_images(dataset, datalabel) @@ -204,10 +209,10 @@ def __unbiased_classes(self, x, y): -def DataLoader(dataset, batch_size=32, num_workers=4, shuffle=False): +class DataLoader(torch.utils.data.DataLoader): """Create dataloader to manage data for training and inference - This function simply creates a torch.utils.data.DataLoader instance to manage data for training and inference. + This class simply creates a torch.utils.data.DataLoader instance to manage data for training and inference. Args: dataset (cvtk.ml.torchutils.DataSet): A dataset for training and inference. @@ -229,15 +234,15 @@ def DataLoader(dataset, batch_size=32, num_workers=4, shuffle=False): >>> dataloader = DataLoader(dataset, batch_size=32, num_workers=4) >>> """ - return torch.utils.data.DataLoader(dataset, - batch_size=batch_size, num_workers=num_workers, shuffle=shuffle) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) -class CLSCORE(): +class ModuleCore(): """A class provides training and inference functions for a classification model using PyTorch - CLSCORE is a class that provides training and inference functions for a classification model. + ModuleCore is a class that provides training and inference functions for a classification model. Args: datalabel (str|list|tuple|DataLabel): A DataLabel instance containing class labels. @@ -258,29 +263,29 @@ class CLSCORE(): Examples: >>> import torch >>> import torchvision - >>> from cvtk.ml.torchutils import CLSCORE + >>> from cvtk.ml.torchutils import ModuleCore >>> >>> datalabel = ['leaf', 'flower', 'root'] - >>> m = CLSCORE(datalabel, 'efficientnet_b7', 'EfficientNet_B7_Weights.DEFAULT') + >>> m = ModuleCore(datalabel, 'efficientnet_b7', 'EfficientNet_B7_Weights.DEFAULT') >>> >>> datalabel = 'class_label.txt' - >>> m = CLSCORE(datalabel, 'efficientnet_b7', 'EfficientNet_B7_Weights.DEFAULT') + >>> m = ModuleCore(datalabel, 'efficientnet_b7', 'EfficientNet_B7_Weights.DEFAULT') """ def __init__(self, datalabel, model, weights=None, workspace=None): - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.datalabel = self.__init_datalabel(datalabel) - self.model = self.__init_model(model, weights, len(self.datalabel.labels)) - self.workspace = self.__init_tempdir(workspace) - - self.model = self.model.to(self.device) - - self.train_stats = None - self.test_stats = None + self.task_type = 'cls' + if not(datalabel is None and model is None): + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.datalabel = self.__init_datalabel(datalabel) + self.model = self.__init_model(model, weights, len(self.datalabel.labels)) + self.workspace = self.__init_tempdir(workspace) + self.model = self.model.to(self.device) + self.train_stats = None + self.test_stats = None def __del__(self): try: - if self.model is not None: + if hasattr(self, '__tempdir_obj') and (self.model is not None): del self.model torch.cuda.empty_cache() gc.collect() @@ -305,6 +310,8 @@ def __init_model(self, model, weights, n_classes): else: if os.path.exists(weights): module = eval(f'torchvision.models.{model}(weights=None)') + elif weights == 'DEFAULT' or weights == 'IMAGENET1K_V1': + module = eval(f'torchvision.models.{model}(weights="{weights}")') else: module = eval(f'torchvision.models.{model}(weights=torchvision.models.{weights})') @@ -388,11 +395,11 @@ def train(self, train, valid=None, test=None, epoch=20, optimizer=None, criterio Examples: >>> import torch >>> from cvtk.ml import DataLabel - >>> from cvtk.ml.torchutils import DataTransform, Dataset, DataLoader, CLSCORE + >>> from cvtk.ml.torchutils import DataTransform, Dataset, DataLoader, ModuleCore >>> >>> datalabel = DataLabel(['leaf', 'flower', 'root']) >>> - >>> model = CLSCORE(datalabel, 'efficientnet_b7', 'EfficientNet_B7_Weights.DEFAULT') + >>> model = ModuleCore(datalabel, 'efficientnet_b7', 'EfficientNet_B7_Weights.DEFAULT') >>> >>> # train dataset >>> transforms_train = DataTransform(600, is_train=True) @@ -440,13 +447,7 @@ def train(self, train, valid=None, test=None, epoch=20, optimizer=None, criterio # test the model if dataset is provided at the last epoch if epoch_i == epoch and dataloaders['test'] is not None: - loss, acc, probs = self.__train(dataloaders['test'], phase, criterion, optimizer) - self.test_stats = { - 'dataset': dataloaders['test'].dataset, - 'loss': loss, - 'acc': acc, - 'probs': probs - } + self.test(dataloaders['test'], criterion) if self.workspace is not None: self.save(os.path.join(self.workspace, f'checkpoint_latest.pth')) @@ -500,7 +501,8 @@ def __train(self, dataloader, phase, criterion, optimizer): inputs = inputs.to(self.device) labels = labels.to(self.device) - optimizer.zero_grad() + if phase == 'train': + optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): outputs = self.model(inputs) _, preds = torch.max(outputs, 1) @@ -536,10 +538,10 @@ def save(self, output): Examples: >>> import torch >>> from cvtk.ml import DataLabel - >>> from cvtk.ml.torchutils import DataTransform, Dataset, DataLoader, CLSCORE + >>> from cvtk.ml.torchutils import DataTransform, Dataset, DataLoader, ModuleCore >>> >>> datalabel = DataLabel(['leaf', 'flower', 'root']) - >>> model = CLSCORE(datalabel, 'efficientnet_b7', 'EfficientNet_B7_Weights.DEFAULT') + >>> model = ModuleCore(datalabel, 'efficientnet_b7', 'EfficientNet_B7_Weights.DEFAULT') >>> >>> # training >>> # ... @@ -589,6 +591,28 @@ def __write_test_outputs(self, output_log_fpath): '\t'.join([str(_) for _ in p_]))) + def test(self, dataloader, criterion=None): + """Test the model with the provided dataloader + + Test the model with the provided dataloader. + + Args: + data (torch.utils.data.DataLoader): A dataloader for testing. + criterion (torch.nn.Module|None): A loss function for training. + Default is `None` and `torch.nn.CrossEntropyLoss` is used. + + """ + self.model.eval() + criterion = torch.nn.CrossEntropyLoss() if criterion is None else criterion + loss, acc, probs = self.__train(dataloader, 'test', criterion, None) + self.test_stats = { + 'dataset': dataloader.dataset, + 'loss': loss, + 'acc': acc, + 'probs': probs + } + return self.test_stats + def inference(self, data, value='prob+label', format='pandas', batch_size=32, num_workers=8): """Perform inference with the input images @@ -606,11 +630,11 @@ def inference(self, data, value='prob+label', format='pandas', batch_size=32, nu Examples: >>> import torch >>> from cvtk.ml import DataLabel - >>> from cvtk.ml.torchutils import DataTransform, Dataset, DataLoader, CLSCORE + >>> from cvtk.ml.torchutils import DataTransform, Dataset, DataLoader, ModuleCore >>> >>> datalabel = DataLabel(['leaf', 'flower', 'root']) >>> - >>> model = CLSCORE(datalabel, 'efficientnet_b7', 'plant_organs.pth') + >>> model = ModuleCore(datalabel, 'efficientnet_b7', 'plant_organs.pth') >>> >>> transform = DataTransform(600) >>> dataset = Dataset(datalabel, 'sample.jpg', transform) @@ -805,7 +829,7 @@ def __generate_source(script_fpath, vanilla=False): if vanilla is True: cvtk_modules = [ {'cvtk.ml.data': [DataLabel]}, - {'cvtk.ml.torchutils': [DataTransform, Dataset, DataLoader, CLSCORE, plot_trainlog, plot_cm]} + {'cvtk.ml.torchutils': [DataTransform, Dataset, DataLoader, ModuleCore, plot_trainlog, plot_cm]} ] tmpl = __insert_imports(tmpl, __get_imports(__file__)) tmpl = __extend_cvtk_imports(tmpl, cvtk_modules) diff --git a/src/cvtk/scripts/cvtk.py b/src/cvtk/scripts/cvtk.py index 269c64a..7db4ea6 100644 --- a/src/cvtk/scripts/cvtk.py +++ b/src/cvtk/scripts/cvtk.py @@ -7,17 +7,17 @@ -def create(args): +def generate_task_source(args): cvtk.ml.generate_source(args.script, task=args.task, vanilla=args.vanilla) -def app(args): - cvtk.ml.generate_app(args.project, - source=args.source, - label=args.label, - model=args.model, - weights=args.weights, - vanilla=args.vanilla) +def generate_demoapp(args): + cvtk.ml.generate_demoapp(args.project, + source=args.source, + label=args.label, + model=args.model, + weights=args.weights, + vanilla=args.vanilla) def split(args): @@ -83,7 +83,7 @@ def main(): parser_train.add_argument('--script', type=str, required=True) parser_train.add_argument('--task', type=str, choices=['cls', 'det', 'segm'], default='cls') parser_train.add_argument('--vanilla', action='store_true', default=False) - parser_train.set_defaults(func=create) + parser_train.set_defaults(func=generate_task_source) parser_train = subparsers.add_parser('app') parser_train.add_argument('--project', type=str, required=True) @@ -92,9 +92,9 @@ def main(): parser_train.add_argument('--model', type=str, default=True) parser_train.add_argument('--weights', type=str, required=True) parser_train.add_argument('--vanilla', action='store_true', default=False) - parser_train.set_defaults(func=app) + parser_train.set_defaults(func=generate_demoapp) - parser_split_text = subparsers.add_parser('split') + parser_split_text = subparsers.add_parser('text-split') parser_split_text.add_argument('--input', type=str, required=True) parser_split_text.add_argument('--output', type=str, required=True) parser_split_text.add_argument('--ratios', type=str, default='8:1:1') diff --git a/src/cvtk/tmpl/_flask.py b/src/cvtk/tmpl/_flask.py index fd8a46e..c532e30 100644 --- a/src/cvtk/tmpl/_flask.py +++ b/src/cvtk/tmpl/_flask.py @@ -7,17 +7,17 @@ import numpy as np import skimage.measure #%CVTK%# IF TASK=cls -from cvtk.ml.torchutils import CLSCORE as MODULECORE +from cvtk.ml.torchutils import ModuleCore #%CVTK%# ENDIF #%CVTK%# IF TASK=det,segm -from cvtk.ml.mmdetutils import MMDETCORE as MODULECORE +from cvtk.ml.mmdetutils import ModuleCore #%CVTK%# ENDIF # application variables APP_ROOT = pathlib.Path(__file__).resolve().parent APP_STORAGE = os.path.join(APP_ROOT, 'static', 'storage') APP_TEMP = os.path.join(APP_ROOT, 'tmp') -MODEL = MODULECORE('__DATALABEL__', '__MODELCFG__','__MODELWEIGHT__', workspace=APP_TEMP) +MODEL = ModuleCore('__DATALABEL__', '__MODELCFG__','__MODELWEIGHT__', workspace=APP_TEMP) if not os.path.exists(APP_STORAGE): os.makedirs(APP_STORAGE) if not os.path.exists(APP_TEMP): diff --git a/src/cvtk/tmpl/_ls_backend.py b/src/cvtk/tmpl/_ls_backend.py index 36224ec..1cfb91d 100644 --- a/src/cvtk/tmpl/_ls_backend.py +++ b/src/cvtk/tmpl/_ls_backend.py @@ -2,7 +2,7 @@ import tempfile import urllib from cvtk.ml.data import DataLabel -from cvtk.ml.mmdetutils import MMDETCORE +from cvtk.ml.mmdetutils import ModuleCore import label_studio_ml import label_studio_ml.model import label_studio_ml.api @@ -22,7 +22,7 @@ def __init__(self, *args, **kwargs): # model settings self.temp_dpath = tempfile.mkdtemp() self.datalabel = DataLabel("__DATALABEL__") - self.model = MMDETCORE(self.datalabel, "__MODELCFG__", "__MODELWEIGHT__", workspace=self.temp_dpath) + self.model = ModuleCore(self.datalabel, "__MODELCFG__", "__MODELWEIGHT__", workspace=self.temp_dpath) self.version = '0.0.0' diff --git a/src/cvtk/tmpl/_mmdet.py b/src/cvtk/tmpl/_mmdet.py index 07276c3..4a84d8b 100644 --- a/src/cvtk/tmpl/_mmdet.py +++ b/src/cvtk/tmpl/_mmdet.py @@ -2,14 +2,14 @@ import random from cvtk import ImageDeck from cvtk.ml.data import DataLabel -from cvtk.ml.mmdetutils import DataPipeline, Dataset, DataLoader, MMDETCORE, plot_trainlog +from cvtk.ml.mmdetutils import DataPipeline, Dataset, DataLoader, ModuleCore, plot_trainlog def train(label, train, valid, test, output_weights, batch_size=4, num_workers=8, epoch=10): temp_dpath = os.path.splitext(output_weights)[0] datalabel = DataLabel(label) - model = MMDETCORE(datalabel, "__TASKARCH__", None, workspace=temp_dpath) + model = ModuleCore(datalabel, "__TASKARCH__", None, workspace=temp_dpath) train = DataLoader( Dataset(datalabel, train, @@ -34,13 +34,14 @@ def train(label, train, valid, test, output_weights, batch_size=4, num_workers=8 output=os.path.splitext(output_weights)[0] + '.train_stats.train.png') if os.path.exists(os.path.splitext(output_weights)[0] + '.train_stats.valid.txt'): plot_trainlog(os.path.splitext(output_weights)[0] + '.train_stats.valid.txt', - output=os.path.splitext(output_weights)[0] + '.train_stats.valid.png') - + output=os.path.splitext(output_weights)[0] + '.train_stats.valid.png') + + def inference(label, data, model_weights, output, batch_size=4, num_workers=8): datalabel = DataLabel(label) - model = MMDETCORE(datalabel, os.path.splitext(model_weights)[0] + '.py', model_weights, workspace=output) + model = ModuleCore(datalabel, os.path.splitext(model_weights)[0] + '.py', model_weights, workspace=output) data = DataLoader( Dataset(datalabel, data, DataPipeline()), diff --git a/src/cvtk/tmpl/_torch.py b/src/cvtk/tmpl/_torch.py index 8160a76..af4a456 100644 --- a/src/cvtk/tmpl/_torch.py +++ b/src/cvtk/tmpl/_torch.py @@ -1,13 +1,13 @@ import os from cvtk.ml.data import DataLabel -from cvtk.ml.torchutils import DataTransform, Dataset, DataLoader, CLSCORE, plot_trainlog, plot_cm +from cvtk.ml.torchutils import DataTransform, Dataset, DataLoader, ModuleCore, plot_trainlog, plot_cm def train(label, train, valid, test, output_weights, batch_size=4, num_workers=8, epoch=10): temp_dpath = os.path.splitext(output_weights)[0] datalabel = DataLabel(label) - model = CLSCORE(datalabel, 'resnet18', 'ResNet18_Weights.DEFAULT', temp_dpath) + model = ModuleCore(datalabel, 'resnet18', 'ResNet18_Weights.DEFAULT', temp_dpath) train = DataLoader( Dataset(datalabel, train, transform=DataTransform(224, is_train=True)), @@ -30,11 +30,33 @@ def train(label, train, valid, test, output_weights, batch_size=4, num_workers=8 os.path.splitext(output_weights)[0] + '.test_outputs.cm.png') +def test(label, data, model_weights, output, batch_size=4, num_workers=8): + temp_dpath = os.path.splitext(model_weights)[0] + + datalabel = DataLabel(label) + model = ModuleCore(datalabel, 'resnet18', model_weights, temp_dpath) + + test = DataLoader( + Dataset(datalabel, data, transform=DataTransform(224, is_train=False)), + batch_size=batch_size, num_workers=num_workers) + + test_stats = model.test(test) + with open(output, 'w') as fh: + fh.write('# loss: {}\n'.format(test_stats['loss'])) + fh.write('# acc: {}\n'.format(test_stats['acc'])) + fh.write('\t'.join(['image', 'label'] + datalabel.labels) + '\n') + for x_, y_, p_ in zip(test_stats['dataset'].x, test_stats['dataset'].y, test_stats['probs']): + fh.write('{}\t{}\t{}\n'.format( + x_, + datalabel.labels[y_], + '\t'.join([str(_) for _ in p_]))) + + def inference(label, data, model_weights, output, batch_size=4, num_workers=8): temp_dpath = os.path.splitext(output)[0] datalabel = DataLabel(label) - model = CLSCORE(datalabel, 'resnet18', model_weights, temp_dpath) + model = ModuleCore(datalabel, 'resnet18', model_weights, temp_dpath) data = DataLoader( Dataset(datalabel, data, transform=DataTransform(224, is_train=False)), @@ -47,6 +69,8 @@ def inference(label, data, model_weights, output, batch_size=4, num_workers=8): def _train(args): train(args.label, args.train, args.valid, args.test, args.output_weights, args.batch_size, args.num_workers, args.epoch) +def _test(args): + test(args.label, args.data, args.model_weights, args.output, args.batch_size, args.num_workers) def _inference(args): inference(args.label, args.data, args.model_weights, args.output, args.batch_size, args.num_workers) @@ -69,6 +93,15 @@ def _inference(args): parser_train.add_argument('--epoch', type=int, default=10) parser_train.set_defaults(func=_train) + parser_test = subparsers.add_parser('test') + parser_test.add_argument('--label', type=str, required=True) + parser_test.add_argument('--data', type=str, required=True) + parser_test.add_argument('--model_weights', type=str, required=True) + parser_test.add_argument('--output', type=str, required=True) + parser_test.add_argument('--batch_size', type=int, default=2) + parser_test.add_argument('--num_workers', type=int, default=8) + parser_test.set_defaults(func=_test) + parser_inference = subparsers.add_parser('inference') parser_inference.add_argument('--label', type=str, required=True) parser_inference.add_argument('--data', type=str, required=True) @@ -93,7 +126,15 @@ def _inference(args): --test ./data/fruits/test.txt \\ --output_weights ./output/fruits.pth - + +python __SCRIPTNAME__ test \\ + --label ./data/fruits/label.txt \\ + --train ./data/fruits/train.txt \\ + --valid ./data/fruits/valid.txt \\ + --test ./data/fruits/test.txt \\ + --output ./output/test_results.txt + + python __SCRIPTNAME__ inference \\ --label ./data/fruits/label.txt \\ --data ./data/fruits/images \\ diff --git a/src/cvtk/tmpl/html/fastapi_.html b/src/cvtk/tmpl/html/_flask.html similarity index 100% rename from src/cvtk/tmpl/html/fastapi_.html rename to src/cvtk/tmpl/html/_flask.html diff --git a/src/cvtk/tmpl/html/flask_.html b/src/cvtk/tmpl/html/flask_.html deleted file mode 100644 index 0575a29..0000000 --- a/src/cvtk/tmpl/html/flask_.html +++ /dev/null @@ -1,172 +0,0 @@ - - - -Demo Application - - - - - - -
-
-
-

Demo Application

-
- -
-
-
- -
-
-
- -
-
-
-
- -
-

Inference Result

- - - -
-
- -
-
- - - \ No newline at end of file diff --git a/tests/test_demoapp.py b/tests/test_demoapp.py index c554ca6..be21e77 100644 --- a/tests/test_demoapp.py +++ b/tests/test_demoapp.py @@ -1,5 +1,5 @@ import os -from cvtk.ml import generate_source, generate_app +from cvtk.ml import generate_source, generate_demoapp import unittest import testutils @@ -37,7 +37,7 @@ def __run_proc(self, task, task_vanilla, api_vanilla, code_generator): if code_generator == 'source': - generate_app(app_project, + generate_demoapp(app_project, source=script, label=testutils.data[task]['label'], model=model_cfg, diff --git a/tests/test_ml.py b/tests/test_ml.py index 206a54c..766ebe6 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -69,19 +69,19 @@ def __init__(self, *args, **kwargs): def test_split_text(self): - testutils.run_cmd(['cvtk', 'split', + testutils.run_cmd(['cvtk', 'text-split', '--input', testutils.data['cls']['all'], '--output', os.path.join(self.ws, 'fruits_subset_1.txt'), '--ratios', '6:3:1', '--shuffle', '--stratify']) - testutils.run_cmd(['cvtk', 'split', + testutils.run_cmd(['cvtk', 'text-split', '--input', testutils.data['cls']['all'], '--output', os.path.join(self.ws, 'fruits_subset_2.txt'), '--ratios', '6:3:1', '--shuffle']) - testutils.run_cmd(['cvtk', 'split', + testutils.run_cmd(['cvtk', 'text-split', '--input', testutils.data['cls']['all'], '--output', os.path.join(self.ws, 'fruits_subset_3.txt'), '--ratios', '6:3:1']) diff --git a/tests/test_mmdet.py b/tests/test_mmdet.py index 1ffe495..5d82f39 100644 --- a/tests/test_mmdet.py +++ b/tests/test_mmdet.py @@ -2,12 +2,11 @@ from cvtk import imlist, ImageDeck from cvtk.ml import generate_source from cvtk.ml.data import DataLabel -from cvtk.ml.mmdetutils import MMDETCORE, DataLabel, DataLoader, Dataset, DataPipeline, plot_trainlog +from cvtk.ml.mmdetutils import ModuleCore, DataLabel, DataLoader, Dataset, DataPipeline, plot_trainlog import unittest import testutils - class TestScript(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -66,7 +65,6 @@ def test_segm_mmdet_cmd(self): self.__run_proc('segm', True, 'cmd') - class TestDataset(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -98,8 +96,6 @@ def test_dataset_dict(self): self.assertEqual(dataset.cfg, data_dict) - - class TestDataLoader(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -127,8 +123,6 @@ def test_dataloader_none(self): self.assertIsNotNone(dataloader.cfg) - - class TestMMDet(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -147,28 +141,31 @@ def __inference(self, model, datalabel, data, output_dpath): output=output_dpath + os.path.basename(im.source)) - - def __test_mmdetutils(self, label, train, valid=None, test=None, output_dpath=None, batch_size=4, num_workers=8): + def __test_mmdetutils(self, label, train, valid=None, test=None, output_dpath=None, task='det', batch_size=4, num_workers=8): output_pfx = os.path.join(output_dpath, 'sb') datalabel = DataLabel(label) - model = MMDETCORE(datalabel, "faster-rcnn_r101_fpn_1x_coco", None, workspace=output_dpath) + if task == 'det': + model = ModuleCore(datalabel, "faster-rcnn_r101_fpn_1x_coco", None, workspace=output_dpath) + else: + model = ModuleCore(datalabel, "mask-rcnn_r101_fpn_1x_coco", None, workspace=output_dpath) + with_mask = False if task == 'det' else True train = DataLoader( Dataset(datalabel, train, - DataPipeline(is_train=True, with_bbox=True, with_mask=False)), + DataPipeline(is_train=True, with_bbox=True, with_mask=with_mask)), phase='train', batch_size=batch_size, num_workers=num_workers) if valid is not None: valid = DataLoader( Dataset(datalabel, valid, - DataPipeline(is_train=False, with_bbox=True, with_mask=False)), + DataPipeline(is_train=False, with_bbox=True, with_mask=with_mask)), phase='valid', batch_size=batch_size, num_workers=num_workers) if test is not None: test = DataLoader( Dataset(datalabel, test, - DataPipeline(is_train=False, with_bbox=True, with_mask=False)), + DataPipeline(is_train=False, with_bbox=True, with_mask=with_mask)), phase='test', batch_size=batch_size, num_workers=num_workers) - model.train(train, valid, test, epoch=5) + model.train(train, valid, test, epoch=10) model.save(f'{output_pfx}.pth') if os.path.exists(f'{output_pfx}.train_stats.train.txt'): @@ -178,9 +175,8 @@ def __test_mmdetutils(self, label, train, valid=None, test=None, output_dpath=No plot_trainlog(f'{output_pfx}.train_stats.valid.txt', output=f'{output_pfx}.train_stats.valid.png') - # inference - model = MMDETCORE(datalabel, f'{output_pfx}.py', f'{output_pfx}.pth', + model = ModuleCore(datalabel, f'{output_pfx}.py', f'{output_pfx}.pth', workspace=output_dpath) # images from a folder @@ -189,14 +185,14 @@ def __test_mmdetutils(self, label, train, valid=None, test=None, output_dpath=No self.__inference(model, datalabel, imlist(self.sample)[0], os.path.join(output_dpath, 'f_')) - def test_det_t_t_t(self): self.__test_mmdetutils( testutils.data['det']['label'], testutils.data['det']['train'], testutils.data['det']['valid'], testutils.data['det']['test'], - os.path.join(self.ws, 'det_trainvalidtest')) + os.path.join(self.ws, 'det_trainvalidtest'), + 'det') def test_det_t_t_f(self): @@ -205,7 +201,8 @@ def test_det_t_t_f(self): testutils.data['det']['train'], testutils.data['det']['valid'], None, - os.path.join(self.ws, 'det_trainvalid')) + os.path.join(self.ws, 'det_trainvalid'), + 'det') def test_det_t_f_t(self): @@ -214,7 +211,8 @@ def test_det_t_f_t(self): testutils.data['det']['train'], None, testutils.data['det']['test'], - os.path.join(self.ws, 'det_traintest')) + os.path.join(self.ws, 'det_traintest'), + 'det') def test_det_t_f_f(self): @@ -223,7 +221,8 @@ def test_det_t_f_f(self): testutils.data['det']['train'], None, None, - os.path.join(self.ws, 'det_train')) + os.path.join(self.ws, 'det_train'), + 'det') def test_segm_t_t_t(self): @@ -232,7 +231,8 @@ def test_segm_t_t_t(self): testutils.data['segm']['train'], testutils.data['segm']['valid'], testutils.data['segm']['test'], - os.path.join(self.ws, 'segm_trainvalidtest')) + os.path.join(self.ws, 'segm_trainvalidtest'), + 'segm') def test_segm_t_t_f(self): @@ -241,7 +241,8 @@ def test_segm_t_t_f(self): testutils.data['segm']['train'], testutils.data['segm']['valid'], None, - os.path.join(self.ws, 'segm_trainvalid')) + os.path.join(self.ws, 'segm_trainvalid'), + 'segm') def test_segm_t_f_t(self): @@ -250,7 +251,8 @@ def test_segm_t_f_t(self): testutils.data['segm']['train'], None, testutils.data['segm']['test'], - os.path.join(self.ws, 'segm_traintest')) + os.path.join(self.ws, 'segm_traintest'), + 'segm') def test_segm_t_f_f(self): @@ -259,10 +261,9 @@ def test_segm_t_f_f(self): testutils.data['segm']['train'], None, None, - os.path.join(self.ws, 'segm_train')) - + os.path.join(self.ws, 'segm_train'), + 'segm') if __name__ == '__main__': unittest.main() - diff --git a/tests/test_torch.py b/tests/test_torch.py index c25c78d..6cba8b6 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -3,12 +3,12 @@ from cvtk import imlist from cvtk.ml import generate_source from cvtk.ml.data import DataLabel -from cvtk.ml.torchutils import DataLabel, CLSCORE, DataLoader, Dataset, DataTransform, plot_trainlog, plot_cm +from cvtk.ml.torchutils import DataLabel, ModuleCore, DataLoader, Dataset, DataTransform, plot_trainlog, plot_cm import unittest import testutils -class TestTorch(unittest.TestCase): +class TestScript(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -33,9 +33,14 @@ def __run_proc(self, vanilla, code_generator): '--test', testutils.data['cls']['test'], '--output_weights', os.path.join(dpath, 'fruits.pth')]) + testutils.run_cmd(['python', script, 'test', + '--label', testutils.data['cls']['label'], + '--data', testutils.data['cls']['test'], + '--model_weights', os.path.join(dpath, 'fruits.pth'), + '--output', os.path.join(dpath, 'test_results.txt')]) + testutils.run_cmd(['python', script, 'inference', '--label', testutils.data['cls']['label'], - #'--data', TU.data['cls']['test'], '--data', testutils.data['cls']['samples'], '--model_weights', os.path.join(dpath, 'fruits.pth'), '--output', os.path.join(dpath, 'inference_results.txt')]) @@ -58,7 +63,7 @@ def test_torch_cmd(self): -class TestTorchUtils(unittest.TestCase): +class TestTorch(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.ws = testutils.set_ws('torch_torchutils') @@ -84,7 +89,7 @@ def __test_torchutils(self, train, valid=None, test=None, output=None, batch_siz temp_dpath = os.path.splitext(output)[0] datalabel = DataLabel(self.label) - model = CLSCORE(datalabel, 'resnet18', 'ResNet18_Weights.DEFAULT', temp_dpath) + model = ModuleCore(datalabel, 'resnet18', 'ResNet18_Weights.DEFAULT', temp_dpath) train = DataLoader( Dataset(datalabel, train, transform=DataTransform(224, is_train=True)), @@ -110,7 +115,7 @@ def __test_torchutils(self, train, valid=None, test=None, output=None, batch_siz os.path.splitext(output)[0] + '.test_outputs.cm.png') - model = CLSCORE(datalabel, 'resnet18', output, temp_dpath) + model = ModuleCore(datalabel, 'resnet18', output, temp_dpath) self.__inference(model, datalabel, self.sample, os.path.splitext(output)[0] + '.inference_results.txt') self.__inference(model, datalabel, imlist(self.sample), os.path.splitext(output)[0] + '.inference_results.txt') self.__inference(model, datalabel, imlist(self.sample)[0], os.path.splitext(output)[0] + '.inference_results.txt')