Skip to content

Commit

Permalink
[add] add functions to treat labelstudio api.
Browse files Browse the repository at this point in the history
  • Loading branch information
jsun committed Aug 6, 2024
1 parent e664c23 commit 1810dcf
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 1,320 deletions.
1,313 changes: 0 additions & 1,313 deletions src/cvtk/format/__format.py

This file was deleted.

1 change: 1 addition & 0 deletions src/cvtk/ls/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._base import export, generate_app
156 changes: 156 additions & 0 deletions src/cvtk/ls/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@

import os
import shutil
import zipfile
import tempfile
import urllib
import json
import importlib
from ..ml.torchutils import __generate_source as generate_source_cls
from ..ml.mmdetutils import __generate_source as generate_source_det
from ..ml._subutils import __estimate_task_from_source, __generate_app_html_tmpl
import label_studio_sdk



def __get_client(host, port, api_key=None):
url = f'{host}:{port}'
if api_key is None and api_key == '':
api_key = os.getenv('LABEL_STUDIO_API_KEY')
if api_key is None:
raise ValueError(f'API KEY is required to access Label Studio API. '
f'Set the API KEY with the argument `api_key` or '
f'export the API KEY as an environment variable '
f'`LABEL_STUDIO_API_KEY` (e.g., export LABEL_STUDIO_API_KEY=cdc903.....z3r9xkmr')
return label_studio_sdk.Client(url=url, api_key=api_key)


def export(project: int,
output: str,
format: str='COCO',
host: str='http://localhost',
port: int=8080,
api_key: str|None=None,
indent: int=4,
ensure_ascii: bool=False) -> dict:
"""
Export annotations from Label Studio project.
Args:
project: An ID of Label Studio project to export.
output: A path to save the exported data.
format: The format of the exported data. The supported formats are `COCO`,
`JSON` (Label Studio JSON), `JSON_MIN`, `CSV`, `TSV`, `VOC` (Pascal VOC),
`YOLO`, and others (see Label Studio Documentations for details).
Note that Only COCO has been implemented so far.
host: Label Studio host. Default is 'localhost'.
port: Label Studio port. Default is 8080.
api_key: Label Studio API key. Default is None.
indent: JSON indent. Default is 4.
ensure_ascii: Ensure ASCII. Default is False
Returns:
dict: A dictionary of the exported data.
Examples:
>>> import os
>>> from cvtk.ls import export
>>>
>>> data = export(project=0, output='instances.coco.json', format='COCO',
host='localhost', port=8080,
api_key='f6dea26f0a0f81883e04681b4e649c600fe50fc')
>>> print(data)
{'info': {'contributor': 'Label Studio', 'description': '', ...., 'images': [...], 'annotations': [...]}
>>>
>>> os.environ['LABEL_STUDIO_API_KEY'] = 'f6dea26f0a0f81883e04681b4e649c600fe50fc'
>>> data = export(project=0, output='instances.coco.json', format='COCO',
host='localhost', port=8080)
>>>
"""
client = __get_client(host, port, api_key)
prj = client.get_project(project)
ls_data_root = os.getenv('LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT')
format = format.upper()

with tempfile.TemporaryDirectory() as temp_dpath:
tempf_output_ = os.path.join(temp_dpath, 'output.zip')

prj.export_tasks(export_type=format,
download_all_tasks=False,
download_resources=False,
export_location=tempf_output_)

if format == 'COCO':
with zipfile.ZipFile(tempf_output_, 'r') as zf:
zf.extract('result.json', path=temp_dpath)
shutil.copy(os.path.join(temp_dpath, 'result.json'), output)
else:
raise NotImplementedError(f'Export format `{format}` is not implemented yet.')

# modify the image path in the exported json file
exported_data = None
with open(output, 'r') as fh:
exported_data = json.load(fh)
for img in exported_data['images']:
img['file_name'] = img['file_name'].replace('\/', '/')
if '/data/local-files/?d=' in img['file_name']:
img['file_name'] = img['file_name'].replace('/data/local-files/?d=', '')
if ls_data_root is not None:
img['file_name'] = os.path.join(ls_data_root, img['file_name'])
img['file_name'] = urllib.parse.unquote(img['file_name'])

print(img['file_name'])
with open(output, 'w') as f:
json.dump(exported_data, f, indent=indent, ensure_ascii=ensure_ascii)

return exported_data



def generate_app(project: str, source: str, label: str, model: str, weights: str, vanilla=False) -> None:
"""Generate a FastAPI application for inference of a classification or detection model
This function generates a FastAPI application for inference of a classification or detection model.
"""

if not os.path.exists(project):
os.makedirs(project)


coremodule = os.path.splitext(os.path.basename(source))[0]
data_label = os.path.basename(label)
model_cfg = os.path.basename(model)
model_weights = os.path.basename(weights)

shutil.copy2(source, os.path.join(project, coremodule + '.py'))
shutil.copy2(label, os.path.join(project, data_label))
if os.path.exists(model):
shutil.copy2(model, os.path.join(project, model_cfg))
shutil.copy2(weights, os.path.join(project, model_weights))

source_task_type, source_is_vanilla = __estimate_task_from_source(source)

# FastAPI script
tmpl = __generate_app_html_tmpl(importlib.resources.files('cvtk').joinpath(f'tmpl/_ls_backend.py'), source_task_type)
if vanilla:
if source_is_vanilla:
for i in range(len(tmpl)):
if tmpl[i][:9] == 'from cvtk':
if source_task_type == 'cls':
tmpl[i] = f'from {coremodule} import CLSCORE as MODULECORE'
elif source_task_type == 'det':
tmpl[i] = f'from {coremodule} import MMDETCORE as MODULECORE'
else:
raise ValueError('Unsupport Type.')
else:
print('The CLSCORE or MMDETCORE class definition is not found in the source code. The script will be generated with importation of cvtk.')
tmpl = ''.join(tmpl)
tmpl = tmpl.replace('__DATALABEL__', data_label)
tmpl = tmpl.replace('__MODELCFG__', model_cfg)
tmpl = tmpl.replace('__MODELWEIGHT__', model_weights)
with open(os.path.join(project, 'main.py'), 'w') as fh:
fh.write(tmpl)




Empty file removed src/cvtk/plot.py
Empty file.
6 changes: 3 additions & 3 deletions tests/run_test.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
coverage run -p -m unittest test_base
coverage run -p -m unittest test_coco
coverage run -p -m unittest test_mlbase
coverage run -p -m unittest test_scripts
coverage run -p -m unittest test_ml
coverage run -p -m unittest test_mmdet
coverage run -p -m unittest test_torch
coverage run -p -m unittest test_fastapi
coverage run -p -m unittest test_demoapp
coverage run -p -m unittest test_ls
coverage combine
coverage report -m
coverage html
Expand Down
6 changes: 2 additions & 4 deletions tests/test_fastapi.py → tests/test_demoapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, *args, **kwargs):
def __run_proc(self, task, task_vanilla, api_vanilla, code_generator):
task_module = 'vanilla' if task_vanilla else 'cvtk'
api_module = 'vanilla' if api_vanilla else 'cvtk'
dpath = testutils.set_ws(f'fastapi_demoapp__{task}_{task_module}_{api_module}_{code_generator}')
dpath = testutils.set_ws(f'demoapp__{task}_{task_module}_{api_module}_{code_generator}')

script = os.path.join(dpath, 'script.py')
model_weight = os.path.join(dpath, 'model.pth')
Expand Down Expand Up @@ -53,9 +53,7 @@ def __run_proc(self, task, task_vanilla, api_vanilla, code_generator):
if api_vanilla:
cmd_.append('--vanilla')
testutils.run_cmd(cmd_)

#testutils.run_cmd(['uvicorn', app_project, '--host', '0.0.0.0', '--port', '8080', '--reload'])



def test_cls(self):
self.__run_proc('cls', True, True, 'source')
Expand Down
1 change: 1 addition & 0 deletions tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def set_ws(dpath):


def run_cmd(cmd):
cmd = [str(_) for _ in cmd]
print('\nCOMMAND -----------------------------------------')
print(' '.join(cmd))
print('-------------------------------------------------\n')
Expand Down

0 comments on commit 1810dcf

Please sign in to comment.