diff --git a/documentation/docs/guide/experiments/experiment-settings.md b/documentation/docs/guide/experiments/experiment-settings.md
index 695a89917..f46f75e6b 100644
--- a/documentation/docs/guide/experiments/experiment-settings.md
+++ b/documentation/docs/guide/experiments/experiment-settings.md
@@ -11,6 +11,7 @@ import DSvalidationStrategy from '../../tooltips/experiments/_validation-strateg
import DSvalidationSize from '../../tooltips/experiments/_validation-size.mdx';
import DSdataSample from '../../tooltips/experiments/_data-sample.mdx';
import DSpromptColumn from '../../tooltips/experiments/_prompt-column.mdx';
+import DSPromptColumnSeparator from '../../tooltips/experiments/_prompt-column-separator.mdx';
import DSsystemColumn from '../../tooltips/experiments/_system-column.mdx';
import DSanswerColumn from '../../tooltips/experiments/_answer-column.mdx';
import DSparentIdColumn from '../../tooltips/experiments/_parent-id-column.mdx';
@@ -141,6 +142,10 @@ The settings under each category are listed and described below.
+### Prompt column separator
+
+<DSPromptColumnSeparator />
+
### Answer column
diff --git a/documentation/docs/tooltips/experiments/_prompt-column-separator.mdx b/documentation/docs/tooltips/experiments/_prompt-column-separator.mdx
new file mode 100644
index 000000000..d96f3d85d
--- /dev/null
+++ b/documentation/docs/tooltips/experiments/_prompt-column-separator.mdx
@@ -0,0 +1 @@
+If multiple prompt columns are selected, the columns are concatenated with the separator defined here. If only a single prompt column is selected, this setting is ignored.
\ No newline at end of file
diff --git a/documentation/docs/tooltips/experiments/_prompt-column.mdx b/documentation/docs/tooltips/experiments/_prompt-column.mdx
index d51c06e97..a7377020e 100644
--- a/documentation/docs/tooltips/experiments/_prompt-column.mdx
+++ b/documentation/docs/tooltips/experiments/_prompt-column.mdx
@@ -1 +1 @@
-The column in the dataset containing the user prompt.
\ No newline at end of file
+The column or columns in the dataset containing the user prompt. If multiple columns are selected, they are concatenated using the separator defined in **Prompt column separator**.
\ No newline at end of file
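
As a side note for reviewers, here is a minimal sketch (not part of the patch; column names and values are invented) of how multiple prompt columns end up joined with the new `prompt_column_separator`, decoding it with `unicode_escape` the same way the `get_texts` change further down does:

```python
import codecs

import pandas as pd

# Hypothetical two-column prompt dataset.
df = pd.DataFrame(
    {
        "instruction": ["Summarize the text."],
        "input": ["H2O LLM Studio is a no-code GUI."],
    }
)

# The UI stores the separator as an escaped string such as "\n\n";
# unicode_escape turns it into real newlines before joining.
join_str = codecs.decode("\\n\\n", "unicode_escape")
texts = (
    df[["instruction", "input"]]
    .astype(str)
    .apply(lambda x: join_str.join(x), axis=1)
    .values
)
print(texts[0])
# Summarize the text.
#
# H2O LLM Studio is a no-code GUI.
```
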
diff --git a/llm_studio/app_utils/hugging_face_utils.py b/llm_studio/app_utils/hugging_face_utils.py
index 4e8d82011..5c7d00e94 100644
--- a/llm_studio/app_utils/hugging_face_utils.py
+++ b/llm_studio/app_utils/hugging_face_utils.py
@@ -40,7 +40,9 @@ def get_model_card(cfg, model, repo_id) -> huggingface_hub.ModelCard:
text_answer_separator=cfg.dataset.text_answer_separator,
trust_remote_code=cfg.environment.trust_remote_code,
end_of_sentence=(
- cfg._tokenizer_eos_token if cfg.dataset.add_eos_token_to_prompt else ""
+ cfg.tokenizer._tokenizer_eos_token
+ if cfg.dataset.add_eos_token_to_prompt
+ else ""
),
)
if cfg.problem_type not in NON_GENERATION_PROBLEM_TYPES:
diff --git a/llm_studio/app_utils/sections/chat_update.py b/llm_studio/app_utils/sections/chat_update.py
index 66cd4c1eb..63111adac 100644
--- a/llm_studio/app_utils/sections/chat_update.py
+++ b/llm_studio/app_utils/sections/chat_update.py
@@ -107,7 +107,7 @@ async def answer_chat(q: Q) -> str:
else:
prev_message = prev_message[0]
if cfg.dataset.add_eos_token_to_answer:
- prev_message += cfg._tokenizer_eos_token
+ prev_message += cfg.tokenizer._tokenizer_eos_token
full_prompt += prev_message
logger.info(f"Full prompt: {full_prompt}")
diff --git a/llm_studio/app_utils/sections/experiment.py b/llm_studio/app_utils/sections/experiment.py
index 28fa1494a..cec7515dc 100644
--- a/llm_studio/app_utils/sections/experiment.py
+++ b/llm_studio/app_utils/sections/experiment.py
@@ -2070,7 +2070,11 @@ def get_experiment_summary_code_card(cfg) -> str:
)
text = text.replace(
"{{end_of_sentence}}",
- str(cfg._tokenizer_eos_token) if cfg.dataset.add_eos_token_to_prompt else "",
+ (
+ str(cfg.tokenizer._tokenizer_eos_token)
+ if cfg.dataset.add_eos_token_to_prompt
+ else ""
+ ),
)
text = text.replace("{{trust_remote_code}}", str(cfg.environment.trust_remote_code))
diff --git a/llm_studio/app_utils/utils.py b/llm_studio/app_utils/utils.py
index fbb0056ab..1abfb77d4 100644
--- a/llm_studio/app_utils/utils.py
+++ b/llm_studio/app_utils/utils.py
@@ -34,6 +34,7 @@
from sqlitedict import SqliteDict
from llm_studio.app_utils.db import Experiment
+from llm_studio.python_configs.base import DefaultConfigProblemBase
from llm_studio.src import possible_values
from llm_studio.src.utils.config_utils import (
_get_type_annotation_error,
@@ -98,12 +99,12 @@ def find_free_port():
def start_process(
- cfg: Any, gpu_list: List, process_queue: List, env_vars: Dict
+ cfg: DefaultConfigProblemBase, gpu_list: List, process_queue: List, env_vars: Dict
) -> subprocess.Popen:
"""Starts train.py for a given configuration setting
Args:
- cfg: config
+ cfg: DefaultConfigProblemBase config
gpu_list: list of GPUs to use for the training
process_queue: list of processes to wait for before starting the training
env_vars: dictionary of ENV variables to pass to the training process
@@ -346,7 +347,7 @@ async def poll(self):
await self.update_ui()
-def s3_download_coroutine(q, filename):
+def s3_download_coroutine(q: Q, filename: str):
download_folder = f"{get_data_dir(q)}/tmp"
download_folder = get_valid_temp_data_folder(q, download_folder)
@@ -370,7 +371,7 @@ def extract_if_zip(file, actual_path):
async def s3_download(
- q, bucket, filename, aws_access_key, aws_secret_key
+ q: Q, bucket, filename, aws_access_key, aws_secret_key
) -> Tuple[str, str]:
"""Downloads a file from s3
@@ -447,7 +448,7 @@ def azure_file_options(conn_string: str, container: str) -> List[str]:
return []
-async def download_progress(q, title, seen_so_far, total_len):
+async def download_progress(q: Q, title, seen_so_far, total_len):
if seen_so_far is not None and total_len is not None:
percentage = seen_so_far / total_len
value = percentage
@@ -469,7 +470,7 @@ async def download_progress(q, title, seen_so_far, total_len):
async def azure_download(
- q: Any, conn_string: str, container: str, filename: str
+ q: Q, conn_string: str, container: str, filename: str
) -> Tuple[str, str]:
"""Downloads a file from azure
@@ -531,7 +532,7 @@ async def azure_download(
return azure_path, "".join(filename.split(".")[:-1])
-async def local_download(q: Any, filename: str) -> Tuple[str, str]:
+async def local_download(q: Q, filename: str) -> Tuple[str, str]:
"""Downloads a file from local path
Args:
@@ -558,7 +559,7 @@ async def local_download(q: Any, filename: str) -> Tuple[str, str]:
async def kaggle_download(
- q: Any, command: str, kaggle_access_key: str, kaggle_secret_key: str
+ q: Q, command: str, kaggle_access_key: str, kaggle_secret_key: str
) -> Tuple[str, str]:
""" "Downloads a file from kaggle
@@ -769,6 +770,23 @@ def get_dataset(
return dataset, v
+def escape_python_string(s: str) -> str:
+ """Escapes a python string
+
+ Args:
+ s: string to escape
+
+ Returns:
+ Escaped string
+ """
+
+ s = s.replace("\\", "\\\\")
+ s = s.replace("\n", "\\n")
+ s = s.replace("\t", "\\t")
+ s = s.replace("\r", "\\r")
+ return s
+
+
def get_ui_element(
k: str,
v: Any,
@@ -883,7 +901,7 @@ def get_ui_element(
ui.textbox(
name=pre + k,
label=title_label,
- value=val,
+ value=escape_python_string(val),
required=False,
password=password,
tooltip=tooltip,
@@ -965,11 +983,11 @@ def get_ui_element(
return t
-def get_dataset_elements(cfg: Any, q: Q) -> List:
+def get_dataset_elements(cfg: DefaultConfigProblemBase, q: Q) -> List:
"""For a given configuration setting return the according dataset ui components.
Args:
- cfg: configuration settings
+ cfg: DefaultConfigProblemBase configuration settings
q: Q
Returns:
@@ -1061,11 +1079,13 @@ def get_dataset_elements(cfg: Any, q: Q) -> List:
return items
-def check_dependencies(cfg: Any, pre: str, k: str, q: Q, dataset_import: bool = False):
+def check_dependencies(
+ cfg: DefaultConfigProblemBase, pre: str, k: str, q: Q, dataset_import: bool = False
+):
"""Checks all dependencies for a given key
Args:
- cfg: configuration settings
+ cfg: DefaultConfigProblemBase configuration settings
pre: prefix for client keys
k: key to be checked
q: Q
@@ -1107,7 +1127,7 @@ def check_dependencies(cfg: Any, pre: str, k: str, q: Q, dataset_import: bool =
return True
-def is_visible(k: str, cfg: Any, q: Q) -> bool:
+def is_visible(k: str, cfg: DefaultConfigProblemBase, q: Q) -> bool:
"""Returns a flag whether a given key should be visible on UI.
Args:
@@ -1145,7 +1165,7 @@ def get_grid_value(v: Any, type_annotation: Any) -> List[str]:
def get_ui_elements(
- cfg: Any,
+ cfg: DefaultConfigProblemBase,
q: Q,
limit: Optional[List[str]] = None,
pre: str = "experiment/start",
@@ -1349,7 +1369,7 @@ def get_ui_elements(
def parse_ui_elements(
- cfg: Any, q: Q, limit: Union[List, str] = "", pre: str = ""
+ cfg: DefaultConfigProblemBase, q: Q, limit: Union[List, str] = "", pre: str = ""
) -> Any:
"""Sets configuration settings with arguments from app
@@ -1891,11 +1911,13 @@ def set_grid_to_cfg(cfg: Any, grid: Dict[str, List]) -> Any:
return cfg
-def start_experiment(cfg: Any, q: Q, pre: str, gpu_list: Optional[List] = None) -> None:
+def start_experiment(
+ cfg: DefaultConfigProblemBase, q: Q, pre: str, gpu_list: Optional[List] = None
+) -> None:
"""Starts an experiment
Args:
- cfg: configuration settings
+ cfg: DefaultConfigProblemBase configuration settings
q: Q
pre: prefix for client keys
gpu_list: list of GPUs available
@@ -2022,7 +2044,7 @@ def dir_file_table(current_path: str) -> pd.DataFrame:
return pd.DataFrame({current_path: results})
-def get_download_link(q, artifact_path):
+def get_download_link(q: Q, artifact_path):
new_path = os.path.relpath(artifact_path, get_output_dir(q))
new_path = os.path.join(get_download_dir(q), new_path)
url_path = os.path.relpath(new_path, get_output_dir(q))
@@ -2148,17 +2170,17 @@ def remove_temp_files(q: Q):
os.remove(file)
-def get_gpu_usage():
- usage = 0.0
- all_gpus = GPUtil.getGPUs()
+def get_gpu_usage() -> float:
+ usage: float = 0.0
+ all_gpus: List[GPUtil.GPU] = GPUtil.getGPUs()
for gpu in all_gpus:
- usage += gpu.load
+ usage += float(gpu.load)
usage /= len(all_gpus)
- return usage * 100
+ return usage * 100.0
-def get_single_gpu_usage(sig_figs=1, highlight=None):
+def get_single_gpu_usage(sig_figs: int = 1, highlight: Optional[str] = None):
all_gpus = GPUtil.getGPUs()
items = []
for i, gpu in enumerate(all_gpus):
@@ -2184,11 +2206,11 @@ def get_single_gpu_usage(sig_figs=1, highlight=None):
return items
-def copy_config(cfg: Any, q: Q) -> Any:
+def copy_config(cfg: DefaultConfigProblemBase, q: Q) -> Any:
"""Makes a copy of the config
Args:
- cfg: config object
+ cfg: DefaultConfigProblemBase config object
Returns:
copy of the config
"""
@@ -2217,7 +2239,7 @@ def make_label(title: str, appendix: str = "") -> str:
return label
-def get_cfg_list_items(cfg) -> List:
+def get_cfg_list_items(cfg: DefaultConfigProblemBase) -> List:
items = parse_cfg_dataclass(cfg)
x = []
for item in items:
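
For context on the `escape_python_string` helper added above, a quick standalone check (the function body is copied verbatim) of why textbox values need escaping: without it, a separator such as `"\n\n"` would render as blank lines in the settings UI instead of the literal characters the user typed.

```python
def escape_python_string(s: str) -> str:
    # Same logic as the helper added in llm_studio/app_utils/utils.py above.
    s = s.replace("\\", "\\\\")
    s = s.replace("\n", "\\n")
    s = s.replace("\t", "\\t")
    s = s.replace("\r", "\\r")
    return s


print(escape_python_string("\n\n"))    # prints: \n\n
print(escape_python_string("a\tb\r"))  # prints: a\tb\r
```
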
diff --git a/llm_studio/python_configs/text_causal_language_modeling_config.py b/llm_studio/python_configs/text_causal_language_modeling_config.py
index 7ee9371cb..e361411d3 100644
--- a/llm_studio/python_configs/text_causal_language_modeling_config.py
+++ b/llm_studio/python_configs/text_causal_language_modeling_config.py
@@ -40,6 +40,7 @@ class ConfigNLPCausalLMDataset(DefaultConfig):
system_column: str = "system"
prompt_column: Tuple[str, ...] = ("instruction", "input")
+ prompt_column_separator: str = "\n\n"
answer_column: str = "output"
parent_id_column: str = "parent_id"
diff --git a/llm_studio/src/augmentations/nlp_aug.py b/llm_studio/src/augmentations/nlp_aug.py
index f62199fff..4f1c6b5c6 100644
--- a/llm_studio/src/augmentations/nlp_aug.py
+++ b/llm_studio/src/augmentations/nlp_aug.py
@@ -45,7 +45,7 @@ def forward(self, batch: Dict) -> Dict:
.bool()
# & special_mask
).bool()
- input_ids[mask] = self.cfg._tokenizer_mask_token_id
+ input_ids[mask] = self.cfg.tokenizer._tokenizer_mask_token_id
batch["input_ids"] = input_ids.clone()
batch["attention_mask"][mask] = 0
if batch["labels"].shape[1] == batch["input_ids"].shape[1]:
diff --git a/llm_studio/src/datasets/conversation_chain_handler.py b/llm_studio/src/datasets/conversation_chain_handler.py
index 93045669b..8bf879571 100644
--- a/llm_studio/src/datasets/conversation_chain_handler.py
+++ b/llm_studio/src/datasets/conversation_chain_handler.py
@@ -57,7 +57,7 @@ def __init__(
# Do not set self.cfg = cfg, as ConversationChainHandler
# will be used with PatchedAttribute context manager.
self.conversation_chain_ids = self.get_conversation_chain_ids(cfg, df)
- self.prompts = get_texts(df, cfg, separator="")
+ self.prompts = get_texts(df, cfg)
self.answers = self.get_answers(df, cfg)
self.systems = self.get_systems(cfg, df)
diff --git a/llm_studio/src/datasets/text_causal_language_modeling_ds.py b/llm_studio/src/datasets/text_causal_language_modeling_ds.py
index 6c2bc1a85..291a3e52b 100644
--- a/llm_studio/src/datasets/text_causal_language_modeling_ds.py
+++ b/llm_studio/src/datasets/text_causal_language_modeling_ds.py
@@ -110,7 +110,7 @@ def parse_prompt(cfg: Any, prompt: str):
f"{codecs.decode(cfg.dataset.text_prompt_start, 'unicode_escape')}{prompt}"
)
if cfg.dataset.add_eos_token_to_prompt:
- prompt += cfg._tokenizer_eos_token
+ prompt += cfg.tokenizer._tokenizer_eos_token
prompt = (
f"{prompt}"
f"{codecs.decode(cfg.dataset.text_answer_separator, 'unicode_escape')}"
@@ -120,7 +120,7 @@ def parse_prompt(cfg: Any, prompt: str):
@staticmethod
def parse_answer(cfg: Any, answer: str):
if cfg.dataset.add_eos_token_to_answer:
- answer += cfg._tokenizer_eos_token
+ answer += cfg.tokenizer._tokenizer_eos_token
return answer
@staticmethod
@@ -132,7 +132,7 @@ def parse_system(cfg: Any, system: str):
f"{codecs.decode(cfg.dataset.text_system_start, 'unicode_escape')}{system}"
)
if cfg.dataset.add_eos_token_to_system:
- system += cfg._tokenizer_eos_token
+ system += cfg.tokenizer._tokenizer_eos_token
return system
@staticmethod
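
To illustrate what the `cfg.tokenizer._tokenizer_eos_token` rename feeds into, here is a toy walkthrough of how `parse_prompt` assembles the final string (the start, separator, and eos values below are example placeholders, not taken from this diff):

```python
import codecs

# Example values only; the real ones come from cfg.dataset.* and the tokenizer.
text_prompt_start = "<|prompt|>"
text_answer_separator = "<|answer|>"
eos_token = "</s>"

prompt = "What does the prompt column separator do?"
prompt = f"{codecs.decode(text_prompt_start, 'unicode_escape')}{prompt}"
prompt += eos_token  # only when cfg.dataset.add_eos_token_to_prompt is True
prompt = f"{prompt}{codecs.decode(text_answer_separator, 'unicode_escape')}"
print(prompt)
# <|prompt|>What does the prompt column separator do?</s><|answer|>
```
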
diff --git a/llm_studio/src/datasets/text_utils.py b/llm_studio/src/datasets/text_utils.py
index e5504d890..cf86e0554 100644
--- a/llm_studio/src/datasets/text_utils.py
+++ b/llm_studio/src/datasets/text_utils.py
@@ -2,17 +2,16 @@
import json
import logging
import os
-from typing import Any
+from pandas import DataFrame
from transformers import AutoTokenizer
-logger = logging.getLogger(__name__)
-
+from llm_studio.python_configs.base import DefaultConfigProblemBase
-TEXT_SEPARATOR = ""
+logger = logging.getLogger(__name__)
-def get_texts(df, cfg, separator=None):
+def get_texts(df: DataFrame, cfg: DefaultConfigProblemBase):
if isinstance(cfg.dataset.prompt_column, str):
# single column dataset
texts = df[cfg.dataset.prompt_column].astype(str)
@@ -24,17 +23,15 @@ def get_texts(df, cfg, separator=None):
for column in columns:
df[column] = df[column].astype(str)
- if separator is None:
- separator = getattr(cfg, "_tokenizer_sep_token", TEXT_SEPARATOR)
+ join_str = codecs.decode(cfg.dataset.prompt_column_separator, "unicode_escape")
- join_str = f" {separator} "
texts = df[columns].astype(str)
texts = texts.apply(lambda x: join_str.join(x), axis=1).values
return texts
-def get_tokenizer(cfg: Any):
+def get_tokenizer(cfg: DefaultConfigProblemBase):
kwargs = dict(
revision=cfg.environment.huggingface_branch,
@@ -87,23 +84,19 @@ def get_tokenizer(cfg: Any):
tokenizer.bos_token = tokenizer.eos_token
if tokenizer.cls_token is None:
tokenizer.cls_token = tokenizer.eos_token
- if tokenizer.sep_token is None:
- tokenizer.sep_token = tokenizer.eos_token
-
- cfg._tokenizer_sep_token = tokenizer.sep_token
if tokenizer.unk_token_id is not None:
- cfg._tokenizer_mask_token_id = tokenizer.unk_token_id
+ cfg.tokenizer._tokenizer_mask_token_id = tokenizer.unk_token_id
elif tokenizer.mask_token_id is not None:
- cfg._tokenizer_mask_token_id = tokenizer.mask_token_id
+ cfg.tokenizer._tokenizer_mask_token_id = tokenizer.mask_token_id
elif tokenizer.pad_token_id is not None:
- cfg._tokenizer_mask_token_id = tokenizer.pad_token_id
+ cfg.tokenizer._tokenizer_mask_token_id = tokenizer.pad_token_id
else:
# setting the mask token id to the last token in the vocabulary
# this usually is a safe choice and mostly refers to eos token
- cfg._tokenizer_mask_token_id = len(tokenizer) - 1
+ cfg.tokenizer._tokenizer_mask_token_id = len(tokenizer) - 1
- cfg._tokenizer_eos_token = tokenizer.eos_token
+ cfg.tokenizer._tokenizer_eos_token = tokenizer.eos_token
if hasattr(cfg.prediction, "stop_tokens"):
set_stop_token_ids(cfg, tokenizer)
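
The tokenizer-attribute moves above all follow the same pattern; as a self-contained illustration (mock tokenizer and namespace objects, not the real `AutoTokenizer` or config classes), this is the fallback chain `get_tokenizer` uses when storing the mask token id on `cfg.tokenizer`:

```python
from types import SimpleNamespace


class FakeTokenizer:
    # Mock: only the attributes the fallback chain inspects.
    unk_token_id = None
    mask_token_id = None
    pad_token_id = 0

    def __len__(self) -> int:
        return 32000


def pick_mask_token_id(tokenizer) -> int:
    if tokenizer.unk_token_id is not None:
        return tokenizer.unk_token_id
    if tokenizer.mask_token_id is not None:
        return tokenizer.mask_token_id
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    # Last vocabulary id; usually a safe choice and mostly refers to the eos token.
    return len(tokenizer) - 1


cfg = SimpleNamespace(tokenizer=SimpleNamespace())
cfg.tokenizer._tokenizer_mask_token_id = pick_mask_token_id(FakeTokenizer())
print(cfg.tokenizer._tokenizer_mask_token_id)  # 0, taken from pad_token_id
```
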
diff --git a/llm_studio/src/metrics/text_causal_language_modeling_metrics.py b/llm_studio/src/metrics/text_causal_language_modeling_metrics.py
index 2fc0e5922..202b86f84 100644
--- a/llm_studio/src/metrics/text_causal_language_modeling_metrics.py
+++ b/llm_studio/src/metrics/text_causal_language_modeling_metrics.py
@@ -14,6 +14,7 @@
from torch import nn
from tqdm import tqdm
+from llm_studio.python_configs.base import DefaultConfigProblemBase
from llm_studio.src.datasets.text_utils import get_texts
from llm_studio.src.utils.logging_utils import TqdmToLogger
@@ -25,7 +26,7 @@
def sacrebleu_score(
- cfg: Any, results: Dict, val_df: pd.DataFrame, metric: Metric
+ cfg: DefaultConfigProblemBase, results: Dict, val_df: pd.DataFrame, metric: Metric
) -> NDArray:
scores = []
for predicted_text, target_text in zip(
@@ -39,7 +40,7 @@ def sacrebleu_score(
return np.array(scores)
-def call_openai_api(template, model, deployment_id=None):
+def call_openai_api(template: str, model: str):
if os.getenv("OPENAI_API_TYPE", "open_ai") == "azure":
endpoint = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
client: AzureOpenAI | OpenAI = AzureOpenAI(
@@ -85,7 +86,7 @@ def call_openai_api(template, model, deployment_id=None):
return score, ret
-def rate_reply(filled_eval_template, model):
+def rate_reply(filled_eval_template: str, model: str):
try:
return call_openai_api(filled_eval_template, model)
except Exception as e:
@@ -94,13 +95,13 @@ def rate_reply(filled_eval_template, model):
def gpt_score(
- cfg: Any,
+ cfg: DefaultConfigProblemBase,
results: Dict,
val_df: pd.DataFrame,
raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
vdf = val_df.copy()
- vdf["_PROMPT"] = get_texts(val_df, cfg, separator="")
+ vdf["_PROMPT"] = get_texts(val_df, cfg)
vdf["_PREDICTED_TEXT"] = results["predicted_text"]
vdf["_TARGET_TEXT"] = results["target_text"]
@@ -150,7 +151,7 @@ def gpt_score(
class Perplexity(nn.Module):
- def __init__(self, cfg: Any, reduce: bool = True):
+ def __init__(self, cfg: DefaultConfigProblemBase, reduce: bool = True):
super().__init__()
self.cfg = cfg
self.loss_fn = nn.CrossEntropyLoss()
@@ -170,7 +171,7 @@ def forward(self, logits, labels):
return perplexity
-def perplexity(cfg: Any, results: Dict, val_df: pd.DataFrame):
+def perplexity(cfg: DefaultConfigProblemBase, results: Dict, val_df: pd.DataFrame):
return results["perplexity"].detach().float().cpu().numpy()
diff --git a/llm_studio/src/utils/config_utils.py b/llm_studio/src/utils/config_utils.py
index 5ab942791..516d314aa 100644
--- a/llm_studio/src/utils/config_utils.py
+++ b/llm_studio/src/utils/config_utils.py
@@ -25,7 +25,7 @@ def rreload(module):
importlib.reload(attribute2)
-def _load_cls(module_path: str, cls_name: str) -> Any:
+def _load_cls(module_path: str, cls_name: str) -> DefaultConfigProblemBase:
"""Loads the python class.
Args:
@@ -50,12 +50,14 @@ def _load_cls(module_path: str, cls_name: str) -> Any:
module_path, cls_name
)
- cls = getattr(module, cls_name)
+ cls: DefaultConfigProblemBase = getattr(module, cls_name)()
return cls
-def load_config_py(config_path: str, config_name: str = "Config"):
+def load_config_py(
+ config_path: str, config_name: str = "Config"
+) -> DefaultConfigProblemBase:
"""Loads the config class.
Args:
@@ -66,7 +68,7 @@ def load_config_py(config_path: str, config_name: str = "Config"):
Loaded config class
"""
- return _load_cls(config_path, config_name)()
+ return _load_cls(config_path, config_name)
def _get_type_annotation_error(v: Any, type_annotation: Type) -> ValueError:
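
The `config_utils` change keeps the caller-facing contract intact; below is a simplified stand-in (not the real module loader) showing that instantiation merely moved from `load_config_py` into `_load_cls`:

```python
from types import SimpleNamespace


class Config:  # stand-in for a DefaultConfigProblemBase subclass
    pass


module = SimpleNamespace(Config=Config)


def _load_cls_old(module, cls_name):
    return getattr(module, cls_name)    # returned the class itself


def _load_cls_new(module, cls_name):
    return getattr(module, cls_name)()  # now returns an instance


cfg_old = _load_cls_old(module, "Config")()  # load_config_py used to add the ()
cfg_new = _load_cls_new(module, "Config")    # load_config_py now passes it through
assert isinstance(cfg_old, Config) and isinstance(cfg_new, Config)
```
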
diff --git a/tests/src/utils/test_data_utils.py b/tests/src/utils/test_data_utils.py
index 262bd60fa..55efc8ba4 100644
--- a/tests/src/utils/test_data_utils.py
+++ b/tests/src/utils/test_data_utils.py
@@ -98,6 +98,7 @@ def test_oasst_data_automatic_split(tmp_path: pathlib.Path):
cfg_mock.dataset.prompt_column = ("instruction",)
cfg_mock.dataset.answer_column = "output"
cfg_mock.dataset.parent_id_column = "parent_id"
+ cfg_mock.dataset.prompt_column_separator = "\n\n"
cfg_mock.dataset.validation_strategy = "automatic"
diff --git a/train.py b/train.py
index 9072badf5..6b9b14332 100644
--- a/train.py
+++ b/train.py
@@ -24,6 +24,7 @@
from tqdm import tqdm
from transformers.deepspeed import HfDeepSpeedConfig
+from llm_studio.python_configs.base import DefaultConfigProblemBase
from llm_studio.src.loggers import MainLogger
from llm_studio.src.utils.config_utils import (
load_config_py,
@@ -72,7 +73,7 @@
def run_eval(
- cfg,
+ cfg: DefaultConfigProblemBase,
model: torch.nn.Module,
val_dataloader: DataLoader,
val_df: pd.DataFrame,
@@ -154,7 +155,7 @@ def run_eval(
def run_train(
- cfg: Any,
+ cfg: DefaultConfigProblemBase,
model: torch.nn.Module,
optimizer,
scheduler,
@@ -166,7 +167,7 @@ def run_train(
"""Runs the training loop.
Args:
- cfg: config object
+ cfg: DefaultConfigProblemBase config object
model: model
train_dataloader: custom training Dataloader
train_df: train DataFrame
@@ -427,11 +428,11 @@ def run_train(
return val_loss, val_metric
-def run(cfg: Any) -> float:
+def run(cfg: DefaultConfigProblemBase) -> float:
"""Runs the routine.
Args:
- cfg: config object with all the hyperparameters
+ cfg: DefaultConfigProblemBase config object with all the hyperparameters
"""
os.makedirs(cfg.output_directory, exist_ok=True)
@@ -686,7 +687,7 @@ def run(cfg: Any) -> float:
parser_args, unknown = parser.parse_known_args(sys.argv)
if "config" in parser_args:
- cfg = load_config_py(parser_args.config)
+ cfg: DefaultConfigProblemBase = load_config_py(parser_args.config)
elif "yaml" in parser_args:
cfg = load_config_yaml(parser_args.yaml)
else: