Skip to content

Commit

Permalink
[clean] clean up code.
Browse files Browse the repository at this point in the history
  • Loading branch information
jsun committed Nov 1, 2024
1 parent 393ed0b commit 6109e13
Show file tree
Hide file tree
Showing 20 changed files with 316 additions and 475 deletions.
6 changes: 3 additions & 3 deletions docs/source/tutorials/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ For accurate model evaluation, it is necessary to split the data into subsets su
validation, and test sets for training and evaluating the model.
The **cvtk** package provides a convenient command for splitting a single dataset into multiple subsets.

If the dataset is saved in a text file, use the ``cvtk split`` command.
If the dataset is saved in a text file, use the ``cvtk text-split`` command.
For example, suppose you have a tab-delimited text file :file:`data.txt`
with the image file paths in the first column and the label names in the second column:

Expand All @@ -36,7 +36,7 @@ Note that, by adding the ``--shuffle`` argument, the data is shuffled before spl

.. code-block:: sh
cvtk split --input data.txt --output data_subset.txt --ratios 6:2:2 --shuffle
cvtk text-split --input data.txt --output data_subset.txt --ratios 6:2:2 --shuffle
The command generates the files
Expand Down Expand Up @@ -66,7 +66,7 @@ to ensure that each subset has a uniform class distribution.

.. code-block:: sh
cvtk split --input all.txt --output data_subset.txt --ratios 6:2:2 --shuffle --stratify
cvtk text-split --input all.txt --output data_subset.txt --ratios 6:2:2 --shuffle --stratify
Expand Down
2 changes: 1 addition & 1 deletion src/cvtk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.2.15'
__version__ = '0.2.16'

from ._base import imread, imconvert, imwrite, imshow, imlist, imresize
from ._base import Annotation, Image, ImageDeck, JsonComplexEncoder
27 changes: 11 additions & 16 deletions src/cvtk/ls/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import importlib
from ..ml.torchutils import __generate_source as generate_source_cls
from ..ml.mmdetutils import __generate_source as generate_source_det
from ..ml._subutils import __estimate_task_from_source, __generate_app_html_tmpl
from ..ml._subutils import __estimate_source_task, __estimate_source_vanilla, __generate_app_html_tmpl
import label_studio_sdk


Expand Down Expand Up @@ -99,7 +99,6 @@ def export(project: int,
img['file_name'] = os.path.join(ls_data_root, img['file_name'])
img['file_name'] = urllib.parse.unquote(img['file_name'])

print(img['file_name'])
with open(output, 'w') as f:
json.dump(exported_data, f, indent=indent, ensure_ascii=ensure_ascii)

Expand All @@ -116,7 +115,6 @@ def generate_app(project: str, source: str, label: str, model: str, weights: str
if not os.path.exists(project):
os.makedirs(project)


coremodule = os.path.splitext(os.path.basename(source))[0]
data_label = os.path.basename(label)
model_cfg = os.path.basename(model)
Expand All @@ -128,28 +126,25 @@ def generate_app(project: str, source: str, label: str, model: str, weights: str
shutil.copy2(model, os.path.join(project, model_cfg))
shutil.copy2(weights, os.path.join(project, model_weights))

source_task_type, source_is_vanilla = __estimate_task_from_source(source)
source_task_type = __estimate_source_task(source)
source_is_vanilla = __estimate_source_vanilla(source)

# FastAPI script
tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/_ls_backend.py'), source_task_type)
tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/_ls_backend.py'),
source_task_type)

if vanilla:
if source_is_vanilla:
for i in range(len(tmpl)):
if tmpl[i][:9] == 'from cvtk':
if source_task_type == 'cls':
tmpl[i] = f'from {coremodule} import CLSCORE as MODULECORE'
elif source_task_type == 'det':
tmpl[i] = f'from {coremodule} import MMDETCORE as MODULECORE'
else:
raise ValueError('Unsupport Type.')
if (tmpl[i][:9] == 'from cvtk') and ('import ModuleCore' in tmpl[i]):
tmpl[i] = f'from {coremodule} import ModuleCore'
else:
print('The CLSCORE or MMDETCORE class definition is not found in the source code. The script will be generated with importation of cvtk.')
# user specified vanilla, but the source code for CV task is not vanilla
print('The `ModuleCore` class definition is not found in the source code. `ModuleCore` will be generated with importation of cvtk regardless vanilla is specified.')

tmpl = ''.join(tmpl)
tmpl = tmpl.replace('__DATALABEL__', data_label)
tmpl = tmpl.replace('__MODELCFG__', model_cfg)
tmpl = tmpl.replace('__MODELWEIGHT__', model_weights)
with open(os.path.join(project, 'main.py'), 'w') as fh:
fh.write(tmpl)



2 changes: 1 addition & 1 deletion src/cvtk/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._base import split_dataset, generate_source, generate_app
from ._base import split_dataset, generate_source, generate_demoapp
53 changes: 25 additions & 28 deletions src/cvtk/ml/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from .torchutils import __generate_source as generate_source_cls
from .mmdetutils import __generate_source as generate_source_det
from ._subutils import __estimate_task_from_source, __generate_app_html_tmpl
from ._subutils import __estimate_source_task, __estimate_source_vanilla, __generate_app_html_tmpl


def split_dataset(data: str|list[str, str]|tuple[str, str],
Expand All @@ -24,9 +24,9 @@ def split_dataset(data: str|list[str, str]|tuple[str, str],
or a path to a text file. If list is given, each element of the list is treated as a sample.
output: The output file name will be appended with the index of the split subset.
ratios: The ratios to split the dataset. The sum of the ratios should be 1.
shuffle (bool): Shuffle the dataset before splitting.
stratify (bool): Split the dataset with a balanced class distribution if `label` is given.
random_seed (int|none): Random seed for shuffling the dataset.
shuffle: Shuffle the dataset before splitting.
stratify: Split the dataset with a balanced class distribution if `label` is given.
random_seed: Random seed for shuffling the dataset.
Returns:
A list of the split datasets. The length of the list is the same as the length of `ratios`.
Expand Down Expand Up @@ -113,8 +113,8 @@ def split_dataset(data: str|list[str, str]|tuple[str, str],



def generate_source(project: str, task: str='cls', vanilla=False) -> None:
"""Generate source code for training and inference of a classification model using PyTorch
def generate_source(project: str, task: str='cls', vanilla: bool=False) -> None:
"""Generate source code for classification or detection tasks
This function generates a Python script for training and inference of a model
using PyTorch (for classification task) or MMDetection (for object detection and instance segmentation tasks).
Expand All @@ -128,9 +128,9 @@ def generate_source(project: str, task: str='cls', vanilla=False) -> None:
since all functions is implemented directly in torch and torchvision.
Args:
project (str): A file path to save the script.
task (str): The task type of project. Three types of tasks can be specified ('cls', 'det', 'segm'). The default is 'cls'.
vanilla (bool): Generate a script without importation of cvtk. The default is False.
project: A file path to save the script.
task: The task type of project. Three types of tasks can be specified ('cls', 'det', 'segm'). The default is 'cls'.
vanilla: Generate a script without importation of cvtk. The default is False.
"""

if task.lower() in ['cls', 'classification']:
Expand All @@ -141,18 +141,18 @@ def generate_source(project: str, task: str='cls', vanilla=False) -> None:
raise ValueError('The current version only support classification (`cls`), detection (`det`), and segmentation (`segm`) tasks.')


def generate_app(project: str, source: str, label: str, model: str, weights: str, vanilla=False) -> None:
def generate_demoapp(project: str, source: str, label: str, model: str, weights: str, vanilla: bool=False) -> None:
"""Generate a FastAPI application for inference of a classification or detection model
This function generates a FastAPI application for inference of a classification or detection model.
Args:
project (str): A file path to save the FastAPI application.
source (str): The source code of the model.
label (str): The label file of the dataset.
model (str): The configuration file of the model.
weights (str): The weights file of the model.
module (str): Script with importation of cvtk ('cvtk') or not ('fastapi').
project: A file path to save the FastAPI application.
source: The source code of the model.
label: The label file of the dataset.
model: The configuration file of the model.
weights: The weights file of the model.
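vanilla: If True, generate the application without importing cvtk; `ModuleCore` is then imported from the given source module instead. This only takes effect when the source itself does not import cvtk. The default is False.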
Examples:
>>> from cvtk.ml import generate_app
Expand All @@ -173,22 +173,20 @@ def generate_app(project: str, source: str, label: str, model: str, weights: str
shutil.copy2(model, os.path.join(project, model_cfg))
shutil.copy2(weights, os.path.join(project, model_weights))

source_task_type, source_is_vanilla = __estimate_task_from_source(source)
source_task_type = __estimate_source_task(source)
source_is_vanilla = __estimate_source_vanilla(source)

# FastAPI script
tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/_flask.py'), source_task_type)
tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/_flask.py'),
source_task_type)
if vanilla:
if source_is_vanilla:
for i in range(len(tmpl)):
if tmpl[i][:9] == 'from cvtk':
if source_task_type == 'cls':
tmpl[i] = f'from {coremodule} import CLSCORE as MODULECORE'
elif source_task_type == 'det':
tmpl[i] = f'from {coremodule} import MMDETCORE as MODULECORE'
else:
raise ValueError('Unsupport Type.')
if (tmpl[i][:9] == 'from cvtk') and ('import ModuleCore' in tmpl[i]):
tmpl[i] = f'from {coremodule} import ModuleCore'
else:
print('The CLSCORE or MMDETCORE class definition is not found in the source code. The script will be generated with importation of cvtk.')
# user specified vanilla, but the source code for CV task is not vanilla
print('The `ModuleCore` class definition is not found in the source code. `ModuleCore` will be generated with importation of cvtk regardless vanilla is specified.')
tmpl = ''.join(tmpl)
tmpl = tmpl.replace('__DATALABEL__', data_label)
tmpl = tmpl.replace('__MODELCFG__', model_cfg)
Expand All @@ -199,7 +197,6 @@ def generate_app(project: str, source: str, label: str, model: str, weights: str
# HTML template
if not os.path.exists(os.path.join(project, 'templates')):
os.makedirs(os.path.join(project, 'templates'))
tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/html/fastapi_.html'), source_task_type)
tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/html/_flask.html'), source_task_type)
with open(os.path.join(project, 'templates', 'index.html'), 'w') as fh:
fh.write(''.join(tmpl))

111 changes: 29 additions & 82 deletions src/cvtk/ml/_subutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,86 +7,57 @@


def __get_imports(code_file: str) -> list[str]:
    """Collect the leading ``import ...`` lines of a Python file.

    The file is read line by line. Every line that starts with ``import``
    is kept verbatim (including its trailing newline). Scanning stops at the
    first line that starts with ``class `` or ``def ``, so imports placed
    after the first top-level definition are not collected. Lines of the
    form ``from ... import ...`` are not collected.

    Args:
        code_file: Path to a python file.

    Returns:
        The collected import lines, in file order.
    """
    collected = []
    with open(code_file, 'r') as source_fh:
        for line in source_fh:
            if line.startswith('import'):
                collected.append(line)
            # The first definition ends the header section of the file.
            if line.startswith(('class ', 'def ')):
                break
    return collected


def __insert_imports(tmpl: list[str], modules: list[str]) -> list[str]:
    """Replace the ``import ...`` lines of a template with `modules`.

    Every line of `tmpl` that starts with ``import`` is dropped, and the
    lines given in `modules` are placed at the top of the result, followed
    by the remaining template lines in their original order. Lines of the
    form ``from ... import ...`` are not treated as imports and are kept.

    If the template is empty or consists only of ``import`` lines, the
    result is just `modules`, so the requested imports are never lost.

    Args:
        tmpl: Lines of the template code.
        modules: Import statements to place at the top of the template.

    Returns:
        A new list of lines; neither `tmpl` nor `modules` is modified.

    Examples:
        >>> __insert_imports(['import os', '', 'print("Hello, World!")'],
        ...                  ['import cvtk'])
        ['import cvtk', '', 'print("Hello, World!")']
    """
    # Drop the template's own imports; `modules` takes their place.
    kept_lines = [line for line in tmpl if not line.startswith('import')]
    return list(modules) + kept_lines


def __extend_cvtk_imports(tmpl, module_dicts):
    """Inline source code in place of the template's ``from cvtk`` imports.

    The first template line that starts with ``from cvtk`` is replaced by
    the source text of every object listed in `module_dicts`, each prefixed
    with three newlines. Any later ``from cvtk`` lines are dropped. All other
    lines are copied through unchanged.

    Args:
        tmpl: Lines of the template code.
        module_dicts: A list of dictionaries whose values are iterables of
            objects accepted by ``inspect.getsource``. The dictionary keys
            are not used here.

    Returns:
        A new list of lines with the cvtk imports replaced by source code.
    """
    expanded = []
    sources_inlined = False
    for line in tmpl:
        if not line.startswith('from cvtk'):
            expanded.append(line)
            continue
        if sources_inlined:
            # Later cvtk imports are redundant once the sources are inlined.
            continue
        for module_map in module_dicts:
            for funcs in module_map.values():
                for func in funcs:
                    expanded.append('\n\n\n' + inspect.getsource(func))
        sources_inlined = True
    return expanded


def __del_docstring(func_source: str) -> str:
"""Delete docstring from source code.
Delete docstring (strings between \"\"\" or ''' ) from source code.
Args:
func_source (str): Source code of a function.
"""
func_source_ = ''
is_docstring = False
omit = False
Expand All @@ -97,6 +68,7 @@ def __del_docstring(func_source: str) -> str:
is_docstring = not is_docstring
else:
if not is_docstring:
line = line.replace('\\\\', '\\')
func_source_ += line + '\n'
return func_source_

Expand Down Expand Up @@ -125,47 +97,22 @@ def __generate_app_html_tmpl(tmpl_fpath, task):
return tmpl


def __estimate_task_from_source(source):
task_ = {
'CLSCORE': {'classdef': 0, 'import': 0, 'call': 0},
'MMDETCORE': {'classdef': 0, 'import': 0, 'call': 0}
}
def __estimate_source_task(source):
    """Determine the task type declared by a generated source script.

    The file at `source` is loaded as a module, its ``ModuleCore`` class is
    instantiated as ``ModuleCore(None, None)``, and the ``task_type``
    attribute of that instance is returned.

    NOTE: loading the module executes all of its top-level code, so this
    must only be used on trusted scripts. Any side effects the script has
    at import time will run.

    Args:
        source: Path to the Python source file to inspect.

    Returns:
        The value of the ``task_type`` attribute of the ``ModuleCore``
        instance.

    Raises:
        ImportError: If `source` cannot be loaded as a Python module
            (for example, the path has no recognised Python suffix).
        AttributeError: If the source file does not define ``ModuleCore``.
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # is loaded, so import it explicitly here.
    import importlib.util

    spec = importlib.util.spec_from_file_location('ModuleCore', source)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when no loader matches the path.
        raise ImportError(f'Cannot load {source!r} as a Python module.')
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    ModuleCoreClass = getattr(module, 'ModuleCore', None)
    if ModuleCoreClass is None:
        raise AttributeError("ModuleCore class not found in the specified source file.")
    # NOTE(review): the meaning of the two constructor arguments is not
    # visible here; None is passed for both, as before.
    module_instance = ModuleCoreClass(None, None)
    return module_instance.task_type


def __estimate_source_vanilla(source):
is_vanilla = True
with open(source, 'r') as infh:
for codeline in infh:
if 'class CLSCORE' in codeline:
task_['CLSCORE']['classdef'] += 1
elif 'import cvtk.ml.torchuitls' in codeline:
task_['CLSCORE']['import'] += 1
elif 'from cvtk.ml.torchutils import' in codeline and 'CLSCORE' in codeline:
task_['CLSCORE']['import'] += 1
elif 'CLSCORE(' in codeline:
task_['CLSCORE']['call'] += 1
elif 'class MMDETCORE' in codeline:
task_['MMDETCORE']['classdef'] += 1
elif 'import cvtk.ml.mmdetutils' in codeline:
task_['MMDETCORE']['import'] += 1
elif 'from cvtk.ml.mmdetutils import' in codeline and 'MMDETCORE' in codeline:
task_['MMDETCORE']['import'] += 1
elif 'MMDETCORE(' in codeline:
task_['MMDETCORE']['call'] += 1
is_cls_cvtk = (task_['CLSCORE']['import'] > 0) and (task_['CLSCORE']['call'] > 0)
is_cls_vanilla = (task_['CLSCORE']['classdef'] > 0) and (task_['CLSCORE']['call'] > 0)
is_cls = is_cls_cvtk or is_cls_vanilla
is_det_cvtk = (task_['MMDETCORE']['import'] > 0) and (task_['MMDETCORE']['call'] > 0)
is_det_vanilla = (task_['MMDETCORE']['classdef'] > 0) and (task_['MMDETCORE']['call'] > 0)
is_det = is_det_cvtk or is_det_vanilla

if is_cls and (not is_det):
task = 'cls'
elif (not is_cls) and is_det:
task = 'det'
else:
raise ValueError('The task type cannot be determined from the source code. Make sure your source code contains CLSCORE or MMDETCORE class definition or importation, and call.')
if is_cls_cvtk or is_det_cvtk:
is_vanilla = False
elif is_cls_vanilla or is_det_vanilla:
is_vanilla = True
else:
raise ValueError('The source code cannot be determined from the source code. Make sure your source code contains importation of cvtk.ml.torchutils or cvtk.ml.mmdetutils.')

return task, is_vanilla
if ('import cvtk' in codeline) or ('from cvtk' in codeline):
is_vanilla = False
break
return is_vanilla
Loading

0 comments on commit 6109e13

Please sign in to comment.