diff --git a/documentation/docs/guide/experiments/experiment-settings.md b/documentation/docs/guide/experiments/experiment-settings.md index 695a89917..f46f75e6b 100644 --- a/documentation/docs/guide/experiments/experiment-settings.md +++ b/documentation/docs/guide/experiments/experiment-settings.md @@ -11,6 +11,7 @@ import DSvalidationStrategy from '../../tooltips/experiments/_validation-strateg import DSvalidationSize from '../../tooltips/experiments/_validation-size.mdx'; import DSdataSample from '../../tooltips/experiments/_data-sample.mdx'; import DSpromptColumn from '../../tooltips/experiments/_prompt-column.mdx'; +import DSPromptColumnSeparator from '../../tooltips/experiments/_prompt-column-separator.mdx'; import DSsystemColumn from '../../tooltips/experiments/_system-column.mdx'; import DSanswerColumn from '../../tooltips/experiments/_answer-column.mdx'; import DSparentIdColumn from '../../tooltips/experiments/_parent-id-column.mdx'; @@ -141,6 +142,10 @@ The settings under each category are listed and described below. +### Prompt column separator + + + ### Answer column diff --git a/documentation/docs/tooltips/experiments/_prompt-column-separator.mdx b/documentation/docs/tooltips/experiments/_prompt-column-separator.mdx new file mode 100644 index 000000000..d96f3d85d --- /dev/null +++ b/documentation/docs/tooltips/experiments/_prompt-column-separator.mdx @@ -0,0 +1 @@ +If multiple prompt columns are selected, the columns are concatenated with the separator defined here. If only a single prompt column is selected, this setting is ignored. \ No newline at end of file diff --git a/documentation/docs/tooltips/experiments/_prompt-column.mdx b/documentation/docs/tooltips/experiments/_prompt-column.mdx index d51c06e97..a7377020e 100644 --- a/documentation/docs/tooltips/experiments/_prompt-column.mdx +++ b/documentation/docs/tooltips/experiments/_prompt-column.mdx @@ -1 +1 @@ -The column in the dataset containing the user prompt. \ No newline at end of file +The column or columns in the dataset containing the user prompt. If multiple columns are selected, they are concatenated with the separator defined in **Prompt column separator**.
\ No newline at end of file diff --git a/llm_studio/app_utils/hugging_face_utils.py b/llm_studio/app_utils/hugging_face_utils.py index 4e8d82011..5c7d00e94 100644 --- a/llm_studio/app_utils/hugging_face_utils.py +++ b/llm_studio/app_utils/hugging_face_utils.py @@ -40,7 +40,9 @@ def get_model_card(cfg, model, repo_id) -> huggingface_hub.ModelCard: text_answer_separator=cfg.dataset.text_answer_separator, trust_remote_code=cfg.environment.trust_remote_code, end_of_sentence=( - cfg._tokenizer_eos_token if cfg.dataset.add_eos_token_to_prompt else "" + cfg.tokenizer._tokenizer_eos_token + if cfg.dataset.add_eos_token_to_prompt + else "" ), ) if cfg.problem_type not in NON_GENERATION_PROBLEM_TYPES: diff --git a/llm_studio/app_utils/sections/chat_update.py b/llm_studio/app_utils/sections/chat_update.py index 66cd4c1eb..63111adac 100644 --- a/llm_studio/app_utils/sections/chat_update.py +++ b/llm_studio/app_utils/sections/chat_update.py @@ -107,7 +107,7 @@ async def answer_chat(q: Q) -> str: else: prev_message = prev_message[0] if cfg.dataset.add_eos_token_to_answer: - prev_message += cfg._tokenizer_eos_token + prev_message += cfg.tokenizer._tokenizer_eos_token full_prompt += prev_message logger.info(f"Full prompt: {full_prompt}") diff --git a/llm_studio/app_utils/sections/experiment.py b/llm_studio/app_utils/sections/experiment.py index 28fa1494a..cec7515dc 100644 --- a/llm_studio/app_utils/sections/experiment.py +++ b/llm_studio/app_utils/sections/experiment.py @@ -2070,7 +2070,11 @@ def get_experiment_summary_code_card(cfg) -> str: ) text = text.replace( "{{end_of_sentence}}", - str(cfg._tokenizer_eos_token) if cfg.dataset.add_eos_token_to_prompt else "", + ( + str(cfg.tokenizer._tokenizer_eos_token) + if cfg.dataset.add_eos_token_to_prompt + else "" + ), ) text = text.replace("{{trust_remote_code}}", str(cfg.environment.trust_remote_code)) diff --git a/llm_studio/app_utils/utils.py b/llm_studio/app_utils/utils.py index fbb0056ab..1abfb77d4 100644 --- a/llm_studio/app_utils/utils.py +++ b/llm_studio/app_utils/utils.py @@ -34,6 +34,7 @@ from sqlitedict import SqliteDict from llm_studio.app_utils.db import Experiment +from llm_studio.python_configs.base import DefaultConfigProblemBase from llm_studio.src import possible_values from llm_studio.src.utils.config_utils import ( _get_type_annotation_error, @@ -98,12 +99,12 @@ def find_free_port(): def start_process( - cfg: Any, gpu_list: List, process_queue: List, env_vars: Dict + cfg: DefaultConfigProblemBase, gpu_list: List, process_queue: List, env_vars: Dict ) -> subprocess.Popen: """Starts train.py for a given configuration setting Args: - cfg: config + cfg: DefaultConfigProblemBase config gpu_list: list of GPUs to use for the training process_queue: list of processes to wait for before starting the training env_vars: dictionary of ENV variables to pass to the training process @@ -346,7 +347,7 @@ async def poll(self): await self.update_ui() -def s3_download_coroutine(q, filename): +def s3_download_coroutine(q: Q, filename: str): download_folder = f"{get_data_dir(q)}/tmp" download_folder = get_valid_temp_data_folder(q, download_folder) @@ -370,7 +371,7 @@ def extract_if_zip(file, actual_path): async def s3_download( - q, bucket, filename, aws_access_key, aws_secret_key + q: Q, bucket, filename, aws_access_key, aws_secret_key ) -> Tuple[str, str]: """Downloads a file from s3 @@ -447,7 +448,7 @@ def azure_file_options(conn_string: str, container: str) -> List[str]: return [] -async def download_progress(q, title, seen_so_far, total_len): +async def 
download_progress(q: Q, title, seen_so_far, total_len): if seen_so_far is not None and total_len is not None: percentage = seen_so_far / total_len value = percentage @@ -469,7 +470,7 @@ async def download_progress(q, title, seen_so_far, total_len): async def azure_download( - q: Any, conn_string: str, container: str, filename: str + q: Q, conn_string: str, container: str, filename: str ) -> Tuple[str, str]: """Downloads a file from azure @@ -531,7 +532,7 @@ async def azure_download( return azure_path, "".join(filename.split(".")[:-1]) -async def local_download(q: Any, filename: str) -> Tuple[str, str]: +async def local_download(q: Q, filename: str) -> Tuple[str, str]: """Downloads a file from local path Args: @@ -558,7 +559,7 @@ async def local_download(q: Any, filename: str) -> Tuple[str, str]: async def kaggle_download( - q: Any, command: str, kaggle_access_key: str, kaggle_secret_key: str + q: Q, command: str, kaggle_access_key: str, kaggle_secret_key: str ) -> Tuple[str, str]: """ "Downloads a file from kaggle @@ -769,6 +770,23 @@ def get_dataset( return dataset, v +def escape_python_string(s: str) -> str: + """Escapes a python string + + Args: + s: string to escape + + Returns: + Escaped string + """ + + s = s.replace("\\", "\\\\") + s = s.replace("\n", "\\n") + s = s.replace("\t", "\\t") + s = s.replace("\r", "\\r") + return s + + def get_ui_element( k: str, v: Any, @@ -883,7 +901,7 @@ def get_ui_element( ui.textbox( name=pre + k, label=title_label, - value=val, + value=escape_python_string(val), required=False, password=password, tooltip=tooltip, @@ -965,11 +983,11 @@ def get_ui_element( return t -def get_dataset_elements(cfg: Any, q: Q) -> List: +def get_dataset_elements(cfg: DefaultConfigProblemBase, q: Q) -> List: """For a given configuration setting return the according dataset ui components. Args: - cfg: configuration settings + cfg: DefaultConfigProblemBase configuration settings q: Q Returns: @@ -1061,11 +1079,13 @@ def get_dataset_elements(cfg: Any, q: Q) -> List: return items -def check_dependencies(cfg: Any, pre: str, k: str, q: Q, dataset_import: bool = False): +def check_dependencies( + cfg: DefaultConfigProblemBase, pre: str, k: str, q: Q, dataset_import: bool = False +): """Checks all dependencies for a given key Args: - cfg: configuration settings + cfg: DefaultConfigProblemBase configuration settings pre: prefix for client keys k: key to be checked q: Q @@ -1107,7 +1127,7 @@ def check_dependencies(cfg: Any, pre: str, k: str, q: Q, dataset_import: bool = return True -def is_visible(k: str, cfg: Any, q: Q) -> bool: +def is_visible(k: str, cfg: DefaultConfigProblemBase, q: Q) -> bool: """Returns a flag whether a given key should be visible on UI. 
Args: @@ -1145,7 +1165,7 @@ def get_grid_value(v: Any, type_annotation: Any) -> List[str]: def get_ui_elements( - cfg: Any, + cfg: DefaultConfigProblemBase, q: Q, limit: Optional[List[str]] = None, pre: str = "experiment/start", @@ -1349,7 +1369,7 @@ def get_ui_elements( def parse_ui_elements( - cfg: Any, q: Q, limit: Union[List, str] = "", pre: str = "" + cfg: DefaultConfigProblemBase, q: Q, limit: Union[List, str] = "", pre: str = "" ) -> Any: """Sets configuration settings with arguments from app @@ -1891,11 +1911,13 @@ def set_grid_to_cfg(cfg: Any, grid: Dict[str, List]) -> Any: return cfg -def start_experiment(cfg: Any, q: Q, pre: str, gpu_list: Optional[List] = None) -> None: +def start_experiment( + cfg: DefaultConfigProblemBase, q: Q, pre: str, gpu_list: Optional[List] = None +) -> None: """Starts an experiment Args: - cfg: configuration settings + cfg: DefaultConfigProblemBase configuration settings q: Q pre: prefix for client keys gpu_list: list of GPUs available @@ -2022,7 +2044,7 @@ def dir_file_table(current_path: str) -> pd.DataFrame: return pd.DataFrame({current_path: results}) -def get_download_link(q, artifact_path): +def get_download_link(q: Q, artifact_path): new_path = os.path.relpath(artifact_path, get_output_dir(q)) new_path = os.path.join(get_download_dir(q), new_path) url_path = os.path.relpath(new_path, get_output_dir(q)) @@ -2148,17 +2170,17 @@ def remove_temp_files(q: Q): os.remove(file) -def get_gpu_usage(): - usage = 0.0 - all_gpus = GPUtil.getGPUs() +def get_gpu_usage() -> float: + usage: float = 0.0 + all_gpus: List[GPUtil.GPU] = GPUtil.getGPUs() for gpu in all_gpus: - usage += gpu.load + usage += float(gpu.load) usage /= len(all_gpus) - return usage * 100 + return usage * 100.0 -def get_single_gpu_usage(sig_figs=1, highlight=None): +def get_single_gpu_usage(sig_figs: int = 1, highlight: Optional[str] = None): all_gpus = GPUtil.getGPUs() items = [] for i, gpu in enumerate(all_gpus): @@ -2184,11 +2206,11 @@ def get_single_gpu_usage(sig_figs=1, highlight=None): return items -def copy_config(cfg: Any, q: Q) -> Any: +def copy_config(cfg: DefaultConfigProblemBase, q: Q) -> Any: """Makes a copy of the config Args: - cfg: config object + cfg: DefaultConfigProblemBase config object Returns: copy of the config """ @@ -2217,7 +2239,7 @@ def make_label(title: str, appendix: str = "") -> str: return label -def get_cfg_list_items(cfg) -> List: +def get_cfg_list_items(cfg: DefaultConfigProblemBase) -> List: items = parse_cfg_dataclass(cfg) x = [] for item in items: diff --git a/llm_studio/python_configs/text_causal_language_modeling_config.py b/llm_studio/python_configs/text_causal_language_modeling_config.py index 7ee9371cb..e361411d3 100644 --- a/llm_studio/python_configs/text_causal_language_modeling_config.py +++ b/llm_studio/python_configs/text_causal_language_modeling_config.py @@ -40,6 +40,7 @@ class ConfigNLPCausalLMDataset(DefaultConfig): system_column: str = "system" prompt_column: Tuple[str, ...] 
= ("instruction", "input") + prompt_column_separator: str = "\n\n" answer_column: str = "output" parent_id_column: str = "parent_id" diff --git a/llm_studio/src/augmentations/nlp_aug.py b/llm_studio/src/augmentations/nlp_aug.py index f62199fff..4f1c6b5c6 100644 --- a/llm_studio/src/augmentations/nlp_aug.py +++ b/llm_studio/src/augmentations/nlp_aug.py @@ -45,7 +45,7 @@ def forward(self, batch: Dict) -> Dict: .bool() # & special_mask ).bool() - input_ids[mask] = self.cfg._tokenizer_mask_token_id + input_ids[mask] = self.cfg.tokenizer._tokenizer_mask_token_id batch["input_ids"] = input_ids.clone() batch["attention_mask"][mask] = 0 if batch["labels"].shape[1] == batch["input_ids"].shape[1]: diff --git a/llm_studio/src/datasets/conversation_chain_handler.py b/llm_studio/src/datasets/conversation_chain_handler.py index 93045669b..8bf879571 100644 --- a/llm_studio/src/datasets/conversation_chain_handler.py +++ b/llm_studio/src/datasets/conversation_chain_handler.py @@ -57,7 +57,7 @@ def __init__( # Do not set self.cfg = cfg, as ConversationChainHandler # will be used with PatchedAttribute context manager. self.conversation_chain_ids = self.get_conversation_chain_ids(cfg, df) - self.prompts = get_texts(df, cfg, separator="") + self.prompts = get_texts(df, cfg) self.answers = self.get_answers(df, cfg) self.systems = self.get_systems(cfg, df) diff --git a/llm_studio/src/datasets/text_causal_language_modeling_ds.py b/llm_studio/src/datasets/text_causal_language_modeling_ds.py index 6c2bc1a85..291a3e52b 100644 --- a/llm_studio/src/datasets/text_causal_language_modeling_ds.py +++ b/llm_studio/src/datasets/text_causal_language_modeling_ds.py @@ -110,7 +110,7 @@ def parse_prompt(cfg: Any, prompt: str): f"{codecs.decode(cfg.dataset.text_prompt_start, 'unicode_escape')}{prompt}" ) if cfg.dataset.add_eos_token_to_prompt: - prompt += cfg._tokenizer_eos_token + prompt += cfg.tokenizer._tokenizer_eos_token prompt = ( f"{prompt}" f"{codecs.decode(cfg.dataset.text_answer_separator, 'unicode_escape')}" @@ -120,7 +120,7 @@ def parse_prompt(cfg: Any, prompt: str): @staticmethod def parse_answer(cfg: Any, answer: str): if cfg.dataset.add_eos_token_to_answer: - answer += cfg._tokenizer_eos_token + answer += cfg.tokenizer._tokenizer_eos_token return answer @staticmethod @@ -132,7 +132,7 @@ def parse_system(cfg: Any, system: str): f"{codecs.decode(cfg.dataset.text_system_start, 'unicode_escape')}{system}" ) if cfg.dataset.add_eos_token_to_system: - system += cfg._tokenizer_eos_token + system += cfg.tokenizer._tokenizer_eos_token return system @staticmethod diff --git a/llm_studio/src/datasets/text_utils.py b/llm_studio/src/datasets/text_utils.py index e5504d890..cf86e0554 100644 --- a/llm_studio/src/datasets/text_utils.py +++ b/llm_studio/src/datasets/text_utils.py @@ -2,17 +2,16 @@ import json import logging import os -from typing import Any +from pandas import DataFrame from transformers import AutoTokenizer -logger = logging.getLogger(__name__) - +from llm_studio.python_configs.base import DefaultConfigProblemBase -TEXT_SEPARATOR = "" +logger = logging.getLogger(__name__) -def get_texts(df, cfg, separator=None): +def get_texts(df: DataFrame, cfg: DefaultConfigProblemBase): if isinstance(cfg.dataset.prompt_column, str): # single column dataset texts = df[cfg.dataset.prompt_column].astype(str) @@ -24,17 +23,15 @@ def get_texts(df, cfg, separator=None): for column in columns: df[column] = df[column].astype(str) - if separator is None: - separator = getattr(cfg, "_tokenizer_sep_token", TEXT_SEPARATOR) + join_str = 
codecs.decode(cfg.dataset.prompt_column_separator, "unicode_escape") - join_str = f" {separator} " texts = df[columns].astype(str) texts = texts.apply(lambda x: join_str.join(x), axis=1).values return texts -def get_tokenizer(cfg: Any): +def get_tokenizer(cfg: DefaultConfigProblemBase): kwargs = dict( revision=cfg.environment.huggingface_branch, @@ -87,23 +84,19 @@ def get_tokenizer(cfg: Any): tokenizer.bos_token = tokenizer.eos_token if tokenizer.cls_token is None: tokenizer.cls_token = tokenizer.eos_token - if tokenizer.sep_token is None: - tokenizer.sep_token = tokenizer.eos_token - - cfg._tokenizer_sep_token = tokenizer.sep_token if tokenizer.unk_token_id is not None: - cfg._tokenizer_mask_token_id = tokenizer.unk_token_id + cfg.tokenizer._tokenizer_mask_token_id = tokenizer.unk_token_id elif tokenizer.mask_token_id is not None: - cfg._tokenizer_mask_token_id = tokenizer.mask_token_id + cfg.tokenizer._tokenizer_mask_token_id = tokenizer.mask_token_id elif tokenizer.pad_token_id is not None: - cfg._tokenizer_mask_token_id = tokenizer.pad_token_id + cfg.tokenizer._tokenizer_mask_token_id = tokenizer.pad_token_id else: # setting the mask token id to the last token in the vocabulary # this usually is a safe choice and mostly refers to eos token - cfg._tokenizer_mask_token_id = len(tokenizer) - 1 + cfg.tokenizer._tokenizer_mask_token_id = len(tokenizer) - 1 - cfg._tokenizer_eos_token = tokenizer.eos_token + cfg.tokenizer._tokenizer_eos_token = tokenizer.eos_token if hasattr(cfg.prediction, "stop_tokens"): set_stop_token_ids(cfg, tokenizer) diff --git a/llm_studio/src/metrics/text_causal_language_modeling_metrics.py b/llm_studio/src/metrics/text_causal_language_modeling_metrics.py index 2fc0e5922..202b86f84 100644 --- a/llm_studio/src/metrics/text_causal_language_modeling_metrics.py +++ b/llm_studio/src/metrics/text_causal_language_modeling_metrics.py @@ -14,6 +14,7 @@ from torch import nn from tqdm import tqdm +from llm_studio.python_configs.base import DefaultConfigProblemBase from llm_studio.src.datasets.text_utils import get_texts from llm_studio.src.utils.logging_utils import TqdmToLogger @@ -25,7 +26,7 @@ def sacrebleu_score( - cfg: Any, results: Dict, val_df: pd.DataFrame, metric: Metric + cfg: DefaultConfigProblemBase, results: Dict, val_df: pd.DataFrame, metric: Metric ) -> NDArray: scores = [] for predicted_text, target_text in zip( @@ -39,7 +40,7 @@ def sacrebleu_score( return np.array(scores) -def call_openai_api(template, model, deployment_id=None): +def call_openai_api(template: str, model: str): if os.getenv("OPENAI_API_TYPE", "open_ai") == "azure": endpoint = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") client: AzureOpenAI | OpenAI = AzureOpenAI( @@ -85,7 +86,7 @@ def call_openai_api(template, model, deployment_id=None): return score, ret -def rate_reply(filled_eval_template, model): +def rate_reply(filled_eval_template: str, model: str): try: return call_openai_api(filled_eval_template, model) except Exception as e: @@ -94,13 +95,13 @@ def rate_reply(filled_eval_template, model): def gpt_score( - cfg: Any, + cfg: DefaultConfigProblemBase, results: Dict, val_df: pd.DataFrame, raw_results: bool = False, ) -> Union[NDArray, Tuple[NDArray, List[str]]]: vdf = val_df.copy() - vdf["_PROMPT"] = get_texts(val_df, cfg, separator="") + vdf["_PROMPT"] = get_texts(val_df, cfg) vdf["_PREDICTED_TEXT"] = results["predicted_text"] vdf["_TARGET_TEXT"] = results["target_text"] @@ -150,7 +151,7 @@ def gpt_score( class Perplexity(nn.Module): - def __init__(self, cfg: Any, reduce: 
bool = True): + def __init__(self, cfg: DefaultConfigProblemBase, reduce: bool = True): super().__init__() self.cfg = cfg self.loss_fn = nn.CrossEntropyLoss() @@ -170,7 +171,7 @@ def forward(self, logits, labels): return perplexity -def perplexity(cfg: Any, results: Dict, val_df: pd.DataFrame): +def perplexity(cfg: DefaultConfigProblemBase, results: Dict, val_df: pd.DataFrame): return results["perplexity"].detach().float().cpu().numpy() diff --git a/llm_studio/src/utils/config_utils.py b/llm_studio/src/utils/config_utils.py index 5ab942791..516d314aa 100644 --- a/llm_studio/src/utils/config_utils.py +++ b/llm_studio/src/utils/config_utils.py @@ -25,7 +25,7 @@ def rreload(module): importlib.reload(attribute2) -def _load_cls(module_path: str, cls_name: str) -> Any: +def _load_cls(module_path: str, cls_name: str) -> DefaultConfigProblemBase: """Loads the python class. Args: @@ -50,12 +50,14 @@ def _load_cls(module_path: str, cls_name: str) -> Any: module_path, cls_name ) - cls = getattr(module, cls_name) + cls: DefaultConfigProblemBase = getattr(module, cls_name)() return cls -def load_config_py(config_path: str, config_name: str = "Config"): +def load_config_py( + config_path: str, config_name: str = "Config" +) -> DefaultConfigProblemBase: """Loads the config class. Args: @@ -66,7 +68,7 @@ def load_config_py(config_path: str, config_name: str = "Config"): Loaded config class """ - return _load_cls(config_path, config_name)() + return _load_cls(config_path, config_name) def _get_type_annotation_error(v: Any, type_annotation: Type) -> ValueError: diff --git a/tests/src/utils/test_data_utils.py b/tests/src/utils/test_data_utils.py index 262bd60fa..55efc8ba4 100644 --- a/tests/src/utils/test_data_utils.py +++ b/tests/src/utils/test_data_utils.py @@ -98,6 +98,7 @@ def test_oasst_data_automatic_split(tmp_path: pathlib.Path): cfg_mock.dataset.prompt_column = ("instruction",) cfg_mock.dataset.answer_column = "output" cfg_mock.dataset.parent_id_column = "parent_id" + cfg_mock.dataset.prompt_column_separator = "\n\n" cfg_mock.dataset.validation_strategy = "automatic" diff --git a/train.py b/train.py index 9072badf5..6b9b14332 100644 --- a/train.py +++ b/train.py @@ -24,6 +24,7 @@ from tqdm import tqdm from transformers.deepspeed import HfDeepSpeedConfig +from llm_studio.python_configs.base import DefaultConfigProblemBase from llm_studio.src.loggers import MainLogger from llm_studio.src.utils.config_utils import ( load_config_py, @@ -72,7 +73,7 @@ def run_eval( - cfg, + cfg: DefaultConfigProblemBase, model: torch.nn.Module, val_dataloader: DataLoader, val_df: pd.DataFrame, @@ -154,7 +155,7 @@ def run_eval( def run_train( - cfg: Any, + cfg: DefaultConfigProblemBase, model: torch.nn.Module, optimizer, scheduler, @@ -166,7 +167,7 @@ def run_train( """Runs the training loop. Args: - cfg: config object + cfg: DefaultConfigProblemBase config object model: model train_dataloader: custom training Dataloader train_df: train DataFrame @@ -427,11 +428,11 @@ def run_train( return val_loss, val_metric -def run(cfg: Any) -> float: +def run(cfg: DefaultConfigProblemBase) -> float: """Runs the routine. 
Args: - cfg: config object with all the hyperparameters + cfg: DefaultConfigProblemBase config object with all the hyperparameters """ os.makedirs(cfg.output_directory, exist_ok=True) @@ -686,7 +687,7 @@ def run(cfg: Any) -> float: parser_args, unknown = parser.parse_known_args(sys.argv) if "config" in parser_args: - cfg = load_config_py(parser_args.config) + cfg: DefaultConfigProblemBase = load_config_py(parser_args.config) elif "yaml" in parser_args: cfg = load_config_yaml(parser_args.yaml) else:
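For reference, the user-visible behavior this patch changes is how multiple prompt columns are joined before tokenization. Below is a minimal, self-contained sketch of the updated `get_texts` join logic; the DataFrame contents and column names are invented for illustration, and the escaped separator string stands in for a value entered in the **Prompt column separator** setting (the config default is `"\n\n"`).

```python
import codecs

import pandas as pd

# Hypothetical stand-ins for cfg.dataset.prompt_column and
# cfg.dataset.prompt_column_separator in this sketch.
prompt_column = ("instruction", "input")
prompt_column_separator = "\\n\\n"  # escaped form, as typed into the UI textbox

df = pd.DataFrame(
    {
        "instruction": ["Summarize the text below."],
        "input": ["H2O LLM Studio fine-tunes large language models."],
    }
)

# Mirrors the updated get_texts(): decode escape sequences in the separator,
# then join the selected prompt columns row by row.
join_str = codecs.decode(prompt_column_separator, "unicode_escape")
texts = (
    df[list(prompt_column)]
    .astype(str)
    .apply(lambda x: join_str.join(x), axis=1)
    .values
)

print(texts[0])
# Summarize the text below.
#
# H2O LLM Studio fine-tunes large language models.
```

If only a single prompt column is selected, `get_texts` takes the single-column branch and returns that column directly, so the separator is ignored, matching the new tooltip text.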