From 059cb8e7141bfb47edae29a99fe96a46f9e42bb9 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 10 Mar 2023 07:30:07 +0900 Subject: [PATCH] Add train-type parameter to otx train --- .../guide/get_started/quick_start_guide/cli_commands.rst | 2 ++ otx/cli/manager/config_manager.py | 2 +- otx/cli/tools/train.py | 7 +++++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/source/guide/get_started/quick_start_guide/cli_commands.rst b/docs/source/guide/get_started/quick_start_guide/cli_commands.rst index f2c8ead424b..e6ed0619c3d 100644 --- a/docs/source/guide/get_started/quick_start_guide/cli_commands.rst +++ b/docs/source/guide/get_started/quick_start_guide/cli_commands.rst @@ -180,6 +180,8 @@ However, if you created a workspace with ``otx build``, the training process can Comma-separated paths to unlabeled data folders --unlabeled-file-list UNLABELED_FILE_LIST Comma-separated paths to unlabeled file list + --train-type TRAIN_TYPE + The currently supported options: dict_keys(['INCREMENTAL', 'SEMISUPERVISED', 'SELFSUPERVISED']). --load-weights LOAD_WEIGHTS Load model weights from previously saved checkpoint. --resume-from RESUME_FROM diff --git a/otx/cli/manager/config_manager.py b/otx/cli/manager/config_manager.py index b85d65761ac..56dddb837f9 100644 --- a/otx/cli/manager/config_manager.py +++ b/otx/cli/manager/config_manager.py @@ -186,7 +186,7 @@ def _get_train_type(self, ignore_args: bool = False) -> str: if arg_algo_backend: train_type = arg_algo_backend.get("train_type", {"value": "INCREMENTAL"}) # type: ignore return train_type.get("value", "INCREMENTAL") - if self.mode in ("build") and self.args.train_type: + if self.mode in ("build", "train") and self.args.train_type: self.train_type = self.args.train_type.upper() if self.train_type not in TASK_TYPE_TO_SUB_DIR_NAME: raise ValueError(f"{self.train_type} is not currently supported by otx.") diff --git a/otx/cli/tools/train.py b/otx/cli/tools/train.py index bcc8e7ba878..935f65d1315 100644 --- a/otx/cli/tools/train.py +++ b/otx/cli/tools/train.py @@ -28,6 +28,7 @@ from otx.api.serialization.label_mapper import label_schema_to_bytes from otx.api.usecases.adapters.model_adapter import ModelAdapter from otx.cli.manager import ConfigManager +from otx.cli.manager.config_manager import TASK_TYPE_TO_SUB_DIR_NAME from otx.cli.utils.hpo import run_hpo from otx.cli.utils.importing import get_impl_class from otx.cli.utils.io import read_binary, read_label_schema, save_model_data @@ -60,6 +61,12 @@ def get_args(): "--unlabeled-file-list", help="Comma-separated paths to unlabeled file list", ) + parser.add_argument( + "--train-type", + help=f"The currently supported options: {TASK_TYPE_TO_SUB_DIR_NAME.keys()}.", + type=str, + default="incremental", + ) parser.add_argument( "--load-weights", help="Load model weights from previously saved checkpoint.",