diff --git a/.gitignore b/.gitignore
index f2fd634b8..28c5727df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,7 +4,7 @@ notebooks/
 demo_data/
 output*/
 tmp/
-data*/
+data/
 examples/data_oasst2
 examples/output_oasst2
 data_old/
diff --git a/README.md b/README.md
index f3e427710..0f4b03b7c 100644
--- a/README.md
+++ b/README.md
@@ -53,6 +53,7 @@ Using CLI for fine-tuning LLMs:
 
 ## What's New
 
+- [PR 788](https://github.com/h2oai/h2o-llmstudio/pull/788) New problem type for Causal Regression Modeling that allows training single-target regression models with LLMs.
 - [PR 747](https://github.com/h2oai/h2o-llmstudio/pull/747) Fully removed RLHF in favor of DPO/IPO/KTO optimization.
 - [PR 741](https://github.com/h2oai/h2o-llmstudio/pull/741) Removing separate max length settings for prompt and answer in favor of a single `max_length` settings better resembling `chat_template` functionality from `transformers`.
 - [PR 592](https://github.com/h2oai/h2o-llmstudio/pull/599) Added `KTOPairLoss` for DPO modeling allowing to train models with simple preference data. Data currently needs to be manually prepared by randomly matching positive and negative examples as pairs.
diff --git a/documentation/docs/tooltips/experiments/_answer-column.mdx b/documentation/docs/tooltips/experiments/_answer-column.mdx
index ed1372301..4892d7783 100644
--- a/documentation/docs/tooltips/experiments/_answer-column.mdx
+++ b/documentation/docs/tooltips/experiments/_answer-column.mdx
@@ -1,3 +1,3 @@
 The column in the dataset containing the expected output.
 
-For classification, this needs to be an integer column starting from zero containing the class label.
\ No newline at end of file
+For classification, this needs to be an integer column starting from zero containing the class label, while for regression, it needs to be a float column.
\ No newline at end of file
diff --git a/documentation/docs/tooltips/experiments/_metric.mdx b/documentation/docs/tooltips/experiments/_metric.mdx
index 7a4d95348..590d16ac7 100644
--- a/documentation/docs/tooltips/experiments/_metric.mdx
+++ b/documentation/docs/tooltips/experiments/_metric.mdx
@@ -12,3 +12,7 @@ Causal Classification Modeling
 - AUC: Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC).
 - Accuracy: Compute the accuracy of the model.
 - LogLoss: Compute the log loss of the model.
+
+Causal Regression Modeling
+- MSE: Compute Mean Squared Error of the model.
+- MAE: Compute Mean Absolute Error of the model.
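As context for the tooltip changes above: the answer column for the new problem type must hold float targets, scored with MSE or MAE. A minimal, illustrative way to prepare such a dataset is sketched below; it is not part of this diff, and the column names simply mirror the `instruction` / `regression_label` columns used by the integration-test config later in this PR.

```python
# Illustrative only: build a tiny single-target regression dataset.
# Column names ("instruction", "regression_label") mirror the integration
# test config in this PR; any prompt/answer columns work as long as the
# answer column holds floats.
import pandas as pd

df = pd.DataFrame(
    {
        "instruction": [
            "Rate the politeness of: 'Thanks so much for your help!'",
            "Rate the politeness of: 'Figure it out yourself.'",
        ],
        "regression_label": [0.95, 0.1],  # float targets, evaluated with MSE/MAE
    }
)
df.to_parquet("train_regression.pq", index=False)
```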
diff --git a/documentation/docs/tooltips/experiments/_problem-type.mdx b/documentation/docs/tooltips/experiments/_problem-type.mdx index 2d1e74512..a34cd1e20 100644 --- a/documentation/docs/tooltips/experiments/_problem-type.mdx +++ b/documentation/docs/tooltips/experiments/_problem-type.mdx @@ -2,8 +2,11 @@ Defines the problem type of the experiment, which also defines the settings H2O - Causal Language Modeling: Used to fine-tune large language models -- DPO Modeling: Used to fine-tune large language models using Direct Preference Optimization +- Causal Classification Modeling: Used to fine-tune causal classification models + +- Causal Regression Modeling: Used to fine-tune causal regression models - Sequence To Sequence Modeling: Used to fine-tune large sequence to sequence models -- Causal Classification Modeling: Used to fine-tune causal classification models \ No newline at end of file +- DPO Modeling: Used to fine-tune large language models using Direct Preference Optimization + diff --git a/llm_studio/app_utils/config.py b/llm_studio/app_utils/config.py index 9aa6eb789..eeea1b218 100644 --- a/llm_studio/app_utils/config.py +++ b/llm_studio/app_utils/config.py @@ -53,9 +53,10 @@ def get_size(x): "start_page": "home", "problem_types": [ "text_causal_language_modeling_config", - "text_dpo_modeling_config", - "text_sequence_to_sequence_modeling_config", "text_causal_classification_modeling_config", + "text_causal_regression_modeling_config", + "text_sequence_to_sequence_modeling_config", + "text_dpo_modeling_config", ], "problem_categories": ["text"], "dataset_keys": [ diff --git a/llm_studio/app_utils/hugging_face_utils.py b/llm_studio/app_utils/hugging_face_utils.py index 5c7d00e94..5b5f6aa84 100644 --- a/llm_studio/app_utils/hugging_face_utils.py +++ b/llm_studio/app_utils/hugging_face_utils.py @@ -253,6 +253,15 @@ def publish_model_to_hugging_face( repo_type="model", commit_message="Upload classification_head.pth", ) + # push regression head to hub + if os.path.isfile(f"{path_to_experiment}/regression_head.pth"): + api.upload_file( + path_or_fileobj=f"{path_to_experiment}/regression_head.pth", + path_in_repo="regression_head.pth", + repo_id=repo_id, + repo_type="model", + commit_message="Upload regression_head.pth", + ) # push config to hub api.upload_file( diff --git a/llm_studio/app_utils/sections/chat.py b/llm_studio/app_utils/sections/chat.py index c395a83b7..d380699e5 100644 --- a/llm_studio/app_utils/sections/chat.py +++ b/llm_studio/app_utils/sections/chat.py @@ -145,7 +145,7 @@ async def should_start_chat(q: Q): box="first", items=[ ui.text( - "Chatbot is not available for text classification problems. " + "Chatbot is not available for this problem type. " "Please select a text generation problem." 
                )
            ],
diff --git a/llm_studio/app_utils/sections/experiment.py b/llm_studio/app_utils/sections/experiment.py
index e8148bd18..76db8da41 100644
--- a/llm_studio/app_utils/sections/experiment.py
+++ b/llm_studio/app_utils/sections/experiment.py
@@ -1814,6 +1814,7 @@ async def experiment_download_model(q: Q):
         "added_tokens.json",
         "model_card.md",
         "classification_head.pth",
+        "regression_head.pth",
     ]
     FILES_TO_PUSH = set(
         FILES_TO_PUSH
diff --git a/llm_studio/python_configs/text_causal_regression_modeling_config.py b/llm_studio/python_configs/text_causal_regression_modeling_config.py
new file mode 100644
index 000000000..cff596106
--- /dev/null
+++ b/llm_studio/python_configs/text_causal_regression_modeling_config.py
@@ -0,0 +1,175 @@
+import os
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Tuple
+
+import llm_studio.src.datasets.text_causal_regression_ds
+import llm_studio.src.plots.text_causal_classification_modeling_plots
+from llm_studio.python_configs.base import DefaultConfig, DefaultConfigProblemBase
+from llm_studio.python_configs.text_causal_classification_modeling_config import (
+    ConfigNLPCausalClassificationAugmentation as ConfigNLPCausalRegressionAugmentation,
+)
+from llm_studio.python_configs.text_causal_classification_modeling_config import (
+    ConfigNLPCausalClassificationDataset,
+)
+from llm_studio.python_configs.text_causal_classification_modeling_config import (
+    ConfigNLPCausalClassificationLogging as ConfigNLPCausalRegressionLogging,
+)
+from llm_studio.python_configs.text_causal_classification_modeling_config import (
+    ConfigNLPCausalClassificationTokenizer as ConfigNLPCausalRegressionTokenizer,
+)
+from llm_studio.python_configs.text_causal_classification_modeling_config import (
+    ConfigNLPCausalClassificationTraining,
+)
+from llm_studio.python_configs.text_causal_language_modeling_config import (
+    ConfigNLPCausalLMArchitecture,
+    ConfigNLPCausalLMEnvironment,
+)
+from llm_studio.src import possible_values
+from llm_studio.src.losses import text_causal_regression_modeling_losses
+from llm_studio.src.metrics import text_causal_regression_modeling_metrics
+from llm_studio.src.models import text_causal_regression_modeling_model
+from llm_studio.src.utils.modeling_utils import generate_experiment_name
+
+
+@dataclass
+class ConfigNLPCausalRegressionDataset(ConfigNLPCausalClassificationDataset):
+    dataset_class: Any = llm_studio.src.datasets.text_causal_regression_ds.CustomDataset
+    num_classes: int = 1
+
+    def __post_init__(self):
+        self.prompt_column = (
+            # a single column name is wrapped into a one-element tuple;
+            # tuple() on a string would split the name into characters
+            (self.prompt_column,)
+            if isinstance(self.prompt_column, str)
+            else tuple(self.prompt_column)
+        )
+        super().__post_init__()
+
+        self._visibility["num_classes"] = -1
+
+
+@dataclass
+class ConfigNLPCausalRegressionTraining(ConfigNLPCausalClassificationTraining):
+    loss_class: Any = text_causal_regression_modeling_losses.Losses
+    loss_function: str = "MSELoss"
+
+    learning_rate: float = 0.0001
+    differential_learning_rate_layers: Tuple[str, ...]
= ("regression_head",) + differential_learning_rate: float = 0.00001 + + def __post_init__(self): + super().__post_init__() + self._possible_values["loss_function"] = self.loss_class.names() + + self._possible_values["differential_learning_rate_layers"] = ( + possible_values.String( + values=("backbone", "embed", "regression_head"), + allow_custom=False, + placeholder="Select optional layers...", + ) + ) + + +@dataclass +class ConfigNLPCausalRegressionArchitecture(ConfigNLPCausalLMArchitecture): + model_class: Any = text_causal_regression_modeling_model.Model + + def __post_init__(self): + super().__post_init__() + + +@dataclass +class ConfigNLPCausalRegressionPrediction(DefaultConfig): + metric_class: Any = text_causal_regression_modeling_metrics.Metrics + metric: str = "MSE" + batch_size_inference: int = 0 + + def __post_init__(self): + super().__post_init__() + + self._possible_values["metric"] = self.metric_class.names() + self._possible_values["batch_size_inference"] = (0, 512, 1) + + self._visibility["metric_class"] = -1 + + +@dataclass +class ConfigNLPCausalRegressionEnvironment(ConfigNLPCausalLMEnvironment): + _model_card_template: str = "text_causal_regression_model_card_template.md" + _summary_card_template: str = ( + "text_causal_regression_experiment_summary_card_template.md" + ) + + def __post_init__(self): + super().__post_init__() + + +@dataclass +class ConfigProblemBase(DefaultConfigProblemBase): + output_directory: str = f"output/{os.path.basename(__file__).split('.')[0]}" + experiment_name: str = field(default_factory=generate_experiment_name) + llm_backbone: str = "h2oai/h2o-danube2-1.8b-chat" + + dataset: ConfigNLPCausalRegressionDataset = field( + default_factory=ConfigNLPCausalRegressionDataset + ) + tokenizer: ConfigNLPCausalRegressionTokenizer = field( + default_factory=ConfigNLPCausalRegressionTokenizer + ) + architecture: ConfigNLPCausalRegressionArchitecture = field( + default_factory=ConfigNLPCausalRegressionArchitecture + ) + training: ConfigNLPCausalRegressionTraining = field( + default_factory=ConfigNLPCausalRegressionTraining + ) + augmentation: ConfigNLPCausalRegressionAugmentation = field( + default_factory=ConfigNLPCausalRegressionAugmentation + ) + prediction: ConfigNLPCausalRegressionPrediction = field( + default_factory=ConfigNLPCausalRegressionPrediction + ) + environment: ConfigNLPCausalRegressionEnvironment = field( + default_factory=ConfigNLPCausalRegressionEnvironment + ) + logging: ConfigNLPCausalRegressionLogging = field( + default_factory=ConfigNLPCausalRegressionLogging + ) + + def __post_init__(self): + super().__post_init__() + + self._visibility["output_directory"] = -1 + + self._possible_values["llm_backbone"] = possible_values.String( + values=( + "h2oai/h2o-danube2-1.8b-base", + "h2oai/h2o-danube2-1.8b-chat", + "h2oai/h2ogpt-4096-llama2-7b", + "h2oai/h2ogpt-4096-llama2-7b-chat", + "h2oai/h2ogpt-4096-llama2-13b", + "h2oai/h2ogpt-4096-llama2-13b-chat", + "h2oai/h2ogpt-4096-llama2-70b", + "h2oai/h2ogpt-4096-llama2-70b-chat", + "tiiuae/falcon-7b", + "mistralai/Mistral-7B-v0.1", + "HuggingFaceH4/zephyr-7b-beta", + "google/gemma-2b", + "google/gemma-7b", + "stabilityai/stablelm-3b-4e1t", + "microsoft/phi-2", + "facebook/opt-125m", + ), + allow_custom=True, + ) + + def check(self) -> Dict[str, List]: + errors: Dict[str, List] = {"title": [], "message": []} + + if self.dataset.parent_id_column not in ["None", None]: + errors["title"] += ["Parent ID column is not supported for regression"] + errors["message"] += [ + "Parent ID column is not 
supported for regression datasets." + ] + + return errors diff --git a/llm_studio/src/datasets/text_causal_regression_ds.py b/llm_studio/src/datasets/text_causal_regression_ds.py new file mode 100644 index 000000000..615e1f3d4 --- /dev/null +++ b/llm_studio/src/datasets/text_causal_regression_ds.py @@ -0,0 +1,38 @@ +import logging +from typing import Any, Dict + +import numpy as np +import pandas as pd + +from llm_studio.src.datasets.text_causal_language_modeling_ds import ( + CustomDataset as TextCausalLanguageModelingCustomDataset, +) +from llm_studio.src.utils.exceptions import LLMDataException + +logger = logging.getLogger(__name__) + + +class CustomDataset(TextCausalLanguageModelingCustomDataset): + def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"): + super().__init__(df=df, cfg=cfg, mode=mode) + self.answers_float = df[cfg.dataset.answer_column].astype(float).values.tolist() + + if cfg.dataset.parent_id_column != "None": + raise LLMDataException( + "Parent ID column is not supported for regression datasets." + ) + + def __getitem__(self, idx: int) -> Dict: + sample = super().__getitem__(idx) + sample["class_label"] = self.answers_float[idx] + return sample + + def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict: + output["logits"] = output["logits"].float() + preds = output["logits"] + preds = np.array(preds).astype(float).astype(str).reshape(-1) + output["predicted_text"] = preds + return super().postprocess_output(cfg, df, output) + + def clean_output(self, output, cfg): + return output diff --git a/llm_studio/src/losses/text_causal_regression_modeling_losses.py b/llm_studio/src/losses/text_causal_regression_modeling_losses.py new file mode 100644 index 000000000..eae934779 --- /dev/null +++ b/llm_studio/src/losses/text_causal_regression_modeling_losses.py @@ -0,0 +1,53 @@ +import logging +from typing import Any, KeysView + +from torch import Tensor, nn + +__all__ = ["Losses"] + + +logger = logging.getLogger(__name__) + + +class MSELoss(nn.Module): + def __init__(self, cfg: Any): + super().__init__() + self.cfg = cfg + self.loss_fn = nn.MSELoss() + + def forward(self, logits: Tensor, labels: Tensor) -> Tensor: + return self.loss_fn(logits, labels.reshape(-1)) + + +class MAELoss(nn.Module): + def __init__(self, cfg: Any): + super().__init__() + self.cfg = cfg + self.loss_fn = nn.L1Loss() + + def forward(self, logits: Tensor, labels: Tensor) -> Tensor: + return self.loss_fn(logits, labels.reshape(-1)) + + +class Losses: + """Losses factory.""" + + _losses = { + "MSELoss": MSELoss, + "MAELoss": MAELoss, + } + + @classmethod + def names(cls) -> KeysView: + return cls._losses.keys() + + @classmethod + def get(cls, name: str) -> Any: + """Access to Losses. 
+
+        Args:
+            name: losses name
+        Returns:
+            A class to build the Losses
+        """
+        return cls._losses.get(name, MSELoss)
diff --git a/llm_studio/src/metrics/text_causal_regression_modeling_metrics.py b/llm_studio/src/metrics/text_causal_regression_modeling_metrics.py
new file mode 100644
index 000000000..0d91b277b
--- /dev/null
+++ b/llm_studio/src/metrics/text_causal_regression_modeling_metrics.py
@@ -0,0 +1,62 @@
+import logging
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+import pandas as pd
+from numpy.typing import NDArray
+
+logger = logging.getLogger(__name__)
+
+
+def mse_score(
+    cfg: Any,
+    results: Dict,
+    val_df: pd.DataFrame,
+    raw_results: bool = False,
+) -> Union[NDArray, Tuple[NDArray, List[str]]]:
+    predicted_text = np.array([float(text) for text in results["predicted_text"]])
+    target_text = np.array([float(text) for text in results["target_text"]])
+    return ((target_text - predicted_text) ** 2).astype("float")
+
+
+def mae_score(
+    cfg: Any,
+    results: Dict,
+    val_df: pd.DataFrame,
+    raw_results: bool = False,
+) -> Union[NDArray, Tuple[NDArray, List[str]]]:
+    predicted_text = np.array([float(text) for text in results["predicted_text"]])
+    target_text = np.array([float(text) for text in results["target_text"]])
+    return np.abs(target_text - predicted_text).astype("float")
+
+
+class Metrics:
+    """
+    Metrics factory. Returns:
+        - metric value
+        - should it be maximized or minimized
+        - Reduce function
+
+    Maximized or minimized is needed for early stopping (saving best checkpoint)
+    Reduce function to generate a single metric value, usually "mean" or "none"
+    """
+
+    _metrics = {
+        "MSE": (mse_score, "min", "mean"),
+        "MAE": (mae_score, "min", "mean"),
+    }
+
+    @classmethod
+    def names(cls) -> List[str]:
+        return sorted(cls._metrics.keys())
+
+    @classmethod
+    def get(cls, name: str) -> Any:
+        """Access to Metrics.
+
+        Args:
+            name: metrics name
+        Returns:
+            A class to build the Metrics
+        """
+        return cls._metrics.get(name, cls._metrics["MSE"])
diff --git a/llm_studio/src/models/text_causal_regression_modeling_model.py b/llm_studio/src/models/text_causal_regression_modeling_model.py
new file mode 100644
index 000000000..89b676c28
--- /dev/null
+++ b/llm_studio/src/models/text_causal_regression_modeling_model.py
@@ -0,0 +1,94 @@
+import logging
+from typing import Any, Dict
+
+from torch import nn
+from transformers import AutoModelForCausalLM
+
+from llm_studio.src.utils.data_utils import batch_padding
+from llm_studio.src.utils.modeling_utils import (
+    create_nlp_backbone,
+    forward,
+    prepare_lora,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class Model(nn.Module):
+    """
+    Model for causal regression modeling problem type.
+ """ + + def __init__(self, cfg: Any): + """ + Args: + cfg: config with all the hyperparameters + """ + + super(Model, self).__init__() + + self.cfg = cfg + self.backbone, self.backbone_config = create_nlp_backbone( + cfg, model_class=AutoModelForCausalLM + ) + + if cfg.training.lora: + self.backbone = prepare_lora(cfg, self.backbone) + + self.regression_head = nn.Linear( + self.backbone_config.vocab_size, cfg.dataset.num_classes, bias=False + ) + + self.loss_fn = self.cfg.training.loss_class.get( + self.cfg.training.loss_function + )(self.cfg) + + def forward( + self, + batch: Dict, + padding: bool = True, + ) -> Dict: + # disable cache if gradient checkpointing is enabled + if self.cfg.architecture.gradient_checkpointing: + self.backbone.config.use_cache = False + + outputs: Dict = {} + mask_key = "prompt_attention_mask" + pad_keys = [ + "prompt_input_ids", + "prompt_attention_mask", + "special_tokens_mask", + "labels", + ] + + if padding: + batch = batch_padding( + self.cfg, + batch, + self.training, + mask_key=mask_key, + pad_keys=pad_keys, + padding_side=self.cfg.tokenizer._padding_side, + ) + + output = forward( + self.backbone, + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + ) + + output.logits = self.regression_head(output[0][:, -1].float()) + + if "labels" in batch: + loss = self.loss_fn( + output.logits, batch["class_label"].unsqueeze(1).float() + ) + outputs["loss"] = loss + + outputs["logits"] = output.logits + + # enable cache again if gradient checkpointing is enabled + if self.cfg.architecture.gradient_checkpointing: + self.backbone.config.use_cache = True + + return outputs diff --git a/llm_studio/src/utils/config_utils.py b/llm_studio/src/utils/config_utils.py index 516d314aa..10a18a2c1 100644 --- a/llm_studio/src/utils/config_utils.py +++ b/llm_studio/src/utils/config_utils.py @@ -209,4 +209,7 @@ def load_config_yaml(path: str): # Note that importing ConfigProblemBase from the python_configs # and using cfg.problem_type below will not work because of circular imports GENERATION_PROBLEM_TYPES = ["text_causal_language_modeling", "text_dpo_modeling"] -NON_GENERATION_PROBLEM_TYPES = ["text_causal_classification_modeling"] +NON_GENERATION_PROBLEM_TYPES = [ + "text_causal_classification_modeling", + "text_causal_regression_modeling", +] diff --git a/llm_studio/src/utils/modeling_utils.py b/llm_studio/src/utils/modeling_utils.py index 785a09d52..058df68ea 100644 --- a/llm_studio/src/utils/modeling_utils.py +++ b/llm_studio/src/utils/modeling_utils.py @@ -146,6 +146,14 @@ def save_checkpoint( checkpoint["model"]["classification_head.weight"], os.path.join(path, "classification_head.pth"), ) + if ( + cfg.environment._local_rank == 0 + and "regression_head.weight" in checkpoint["model"] + ): + torch.save( + checkpoint["model"]["regression_head.weight"], + os.path.join(path, "regression_head.pth"), + ) def load_model_weights( diff --git a/model_cards/text_causal_classification_experiment_summary_card_template.md b/model_cards/text_causal_classification_experiment_summary_card_template.md index e3204b652..a108a4113 100644 --- a/model_cards/text_causal_classification_experiment_summary_card_template.md +++ b/model_cards/text_causal_classification_experiment_summary_card_template.md @@ -28,6 +28,7 @@ You can make classification predictions by following the example below: ```python from transformers import AutoModelForCausalLM, AutoTokenizer +import torch model_name = "{{repo_id}}" # either local folder or Hugging Face model name # Important: 
The prompt needs to be in the same format the model was trained with.
diff --git a/model_cards/text_causal_classification_model_card_template.md b/model_cards/text_causal_classification_model_card_template.md
index e5e3c3fdd..72f26e3be 100644
--- a/model_cards/text_causal_classification_model_card_template.md
+++ b/model_cards/text_causal_classification_model_card_template.md
@@ -45,6 +45,7 @@ You can make classification predictions by following the example below:
 
 ```python
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
 
 model_name = "{{repo_id}}" # either local folder or Hugging Face model name
 # Important: The prompt needs to be in the same format the model was trained with.
diff --git a/model_cards/text_causal_regression_experiment_summary_card_template.md b/model_cards/text_causal_regression_experiment_summary_card_template.md
new file mode 100644
index 000000000..56192fbff
--- /dev/null
+++ b/model_cards/text_causal_regression_experiment_summary_card_template.md
@@ -0,0 +1,61 @@
+### Usage with HF transformers
+
+To use the model with the `transformers` library on a machine with GPUs:
+- First, push the model to a huggingface repo by clicking the Push checkpoint to huggingface button below
+- Make sure you have the `transformers` library installed in the machine's environment
+
+```bash
+pip install transformers=={{transformers_version}}
+```
+
+Also make sure you are providing your Hugging Face token if the model is stored in a private repo.
+    - You can log in to huggingface_hub by running
+    ```python
+    import huggingface_hub
+    huggingface_hub.login()
+    ```
+
+You will also need to download the regression head, either manually, or by running the following code:
+
+```python
+from huggingface_hub import hf_hub_download
+
+model_name = "{{repo_id}}" # either local folder or Hugging Face model name
+hf_hub_download(repo_id=model_name, filename="regression_head.pth", local_dir="./")
+```
+
+You can make regression predictions by following the example below:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+model_name = "{{repo_id}}" # either local folder or Hugging Face model name
+# Important: The prompt needs to be in the same format the model was trained with.
+# You can find an example prompt in the experiment logs.
+prompt = "{{text_prompt_start}}How are you?{{end_of_sentence}}{{text_answer_separator}}"
+
+tokenizer = AutoTokenizer.from_pretrained(
+    model_name,
+    trust_remote_code={{trust_remote_code}},
+)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype="auto",
+    device_map={"": "cuda:0"},
+    trust_remote_code={{trust_remote_code}},
+).cuda().eval()
+
+head_weights = torch.load("regression_head.pth", map_location="cuda")
+# settings can be arbitrary here as we overwrite with saved weights
+head = torch.nn.Linear(1, 1, bias=False).to("cuda")
+head.weight.data = head_weights
+
+inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
+
+out = model(**inputs).logits
+
+logits = head(out[:,-1])
+
+print(logits)
+```
diff --git a/model_cards/text_causal_regression_model_card_template.md b/model_cards/text_causal_regression_model_card_template.md
new file mode 100644
index 000000000..84eff5ec6
--- /dev/null
+++ b/model_cards/text_causal_regression_model_card_template.md
@@ -0,0 +1,106 @@
+---
+language:
+- en
+library_name: transformers
+inference: false
+thumbnail: https://h2o.ai/etc.clientlibs/h2o/clientlibs/clientlib-site/resources/images/favicon.ico
+tags:
+- gpt
+- llm
+- large language model
+- h2o-llmstudio
+---
+# Model Card
+## Summary
+
+This model was trained using [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
+- Base model: [{{base_model}}](https://huggingface.co/{{base_model}})
+
+
+## Usage
+
+To use the model with the `transformers` library on a machine with GPUs, first make sure you have the `transformers` library installed.
+
+```bash
+pip install transformers=={{transformers_version}}
+```
+
+Also make sure you are providing your Hugging Face token if the model is stored in a private repo.
+    - You can log in to huggingface_hub by running
+    ```python
+    import huggingface_hub
+    huggingface_hub.login()
+    ```
+
+You will also need to download the regression head, either manually, or by running the following code:
+
+```python
+from huggingface_hub import hf_hub_download
+
+model_name = "{{repo_id}}" # either local folder or Hugging Face model name
+hf_hub_download(repo_id=model_name, filename="regression_head.pth", local_dir="./")
+```
+
+You can make regression predictions by following the example below:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+model_name = "{{repo_id}}" # either local folder or Hugging Face model name
+# Important: The prompt needs to be in the same format the model was trained with.
+# You can find an example prompt in the experiment logs.
+prompt = "{{text_prompt_start}}How are you?{{end_of_sentence}}{{text_answer_separator}}"
+
+tokenizer = AutoTokenizer.from_pretrained(
+    model_name,
+    trust_remote_code={{trust_remote_code}},
+)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype="auto",
+    device_map={"": "cuda:0"},
+    trust_remote_code={{trust_remote_code}},
+).cuda().eval()
+
+head_weights = torch.load("regression_head.pth", map_location="cuda")
+# settings can be arbitrary here as we overwrite with saved weights
+head = torch.nn.Linear(1, 1, bias=False).to("cuda")
+head.weight.data = head_weights
+
+inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
+
+out = model(**inputs).logits
+
+logits = head(out[:,-1])
+
+print(logits)
+```
+
+## Quantization and sharding
+
+You can load the models using quantization by specifying ```load_in_8bit=True``` or ```load_in_4bit=True```.
Also, sharding on multiple GPUs is possible by setting ```device_map=auto```. + +## Model Architecture + +``` +{{model_architecture}} +``` + +## Model Configuration + +This model was trained using H2O LLM Studio and with the configuration in [cfg.yaml](cfg.yaml). Visit [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio) to learn how to train your own large language models. + + +## Disclaimer + +Please read this disclaimer carefully before using the large language model provided in this repository. Your use of the model signifies your agreement to the following terms and conditions. + +- Biases and Offensiveness: The large language model is trained on a diverse range of internet text data, which may contain biased, racist, offensive, or otherwise inappropriate content. By using this model, you acknowledge and accept that the generated content may sometimes exhibit biases or produce content that is offensive or inappropriate. The developers of this repository do not endorse, support, or promote any such content or viewpoints. +- Limitations: The large language model is an AI-based tool and not a human. It may produce incorrect, nonsensical, or irrelevant responses. It is the user's responsibility to critically evaluate the generated content and use it at their discretion. +- Use at Your Own Risk: Users of this large language model must assume full responsibility for any consequences that may arise from their use of the tool. The developers and contributors of this repository shall not be held liable for any damages, losses, or harm resulting from the use or misuse of the provided model. +- Ethical Considerations: Users are encouraged to use the large language model responsibly and ethically. By using this model, you agree not to use it for purposes that promote hate speech, discrimination, harassment, or any form of illegal or harmful activities. +- Reporting Issues: If you encounter any biased, offensive, or otherwise inappropriate content generated by the large language model, please report it to the repository maintainers through the provided channels. Your feedback will help improve the model and mitigate potential issues. +- Changes to this Disclaimer: The developers of this repository reserve the right to modify or update this disclaimer at any time without prior notice. It is the user's responsibility to periodically review the disclaimer to stay informed about any changes. + +By using the large language model provided in this repository, you agree to accept and comply with the terms and conditions outlined in this disclaimer. If you do not agree with any part of this disclaimer, you should refrain from using the model and any content generated by it. 
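The new model card mentions quantization and sharding but only shows a full-precision example. Below is a minimal sketch of loading a published regression model in 4-bit with automatic sharding and applying the downloaded regression head; it assumes a CUDA machine with `bitsandbytes` installed, and the repo id is a placeholder.

```python
# Sketch: 4-bit quantized, auto-sharded inference with the regression head.
# Assumes a CUDA GPU, bitsandbytes installed, and that the model plus
# regression_head.pth were pushed to the (placeholder) repo below.
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

repo_id = "your-org/your-regression-model"  # placeholder repo id
hf_hub_download(repo_id=repo_id, filename="regression_head.pth", local_dir="./")

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",  # shard layers across all visible GPUs
).eval()

# Recreate the head and overwrite its weights with the saved tensor.
head = torch.nn.Linear(1, 1, bias=False)
head.weight.data = torch.load("regression_head.pth", map_location="cpu")
head = head.to(model.device).float()

# The prompt should follow the same format the model was trained with.
inputs = tokenizer("How are you?", return_tensors="pt").to(model.device)
with torch.no_grad():
    last_token_logits = model(**inputs).logits[:, -1].float()
    print(head(last_token_logits))  # one float prediction per input row
```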
diff --git a/tests/integration/test_causal_regression_modeling_cfg.yaml b/tests/integration/test_causal_regression_modeling_cfg.yaml new file mode 100644 index 000000000..707a4514f --- /dev/null +++ b/tests/integration/test_causal_regression_modeling_cfg.yaml @@ -0,0 +1,91 @@ +architecture: + backbone_dtype: int4 + gradient_checkpointing: true + intermediate_dropout: 0.0 + pretrained: true + pretrained_weights: '' +augmentation: + random_parent_probability: 0.0 + skip_parent_probability: 0.0 + token_mask_probability: 0.0 +dataset: + add_eos_token_to_answer: false + add_eos_token_to_prompt: false + add_eos_token_to_system: false + answer_column: regression_label + chatbot_author: H2O.ai + chatbot_name: h2oGPT + data_sample: 0.03 + data_sample_choice: + - Train + - Validation + limit_chained_samples: false + mask_prompt_labels: true + num_classes: 1 + parent_id_column: None + personalize: false + prompt_column: + - instruction + system_column: None + text_answer_separator: '' + text_prompt_start: '' + text_system_start: '' + train_dataframe: /tmp/train_full.pq + validation_dataframe: None + validation_size: 0.2 + validation_strategy: automatic +environment: + compile_model: false + deepspeed_reduce_bucket_size: 1000000 + deepspeed_stage3_param_persistence_threshold: 1000000 + deepspeed_stage3_prefetch_bucket_size: 1000000 + find_unused_parameters: false + gpus: + - '0' + huggingface_branch: main + mixed_precision: true + mixed_precision_dtype: float16 + number_of_workers: 8 + seed: -1 + trust_remote_code: true + use_deepspeed: false +experiment_name: test-regression-modeling +llm_backbone: facebook/opt-125m +logging: + logger: None + neptune_project: test_org/test_project +output_directory: /tmp/output +prediction: + batch_size_inference: 0 + metric: MSE +problem_type: text_causal_regression_modeling +tokenizer: + add_prompt_answer_tokens: false + max_length: 512 + padding_quantile: 1.0 + tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}' +training: + batch_size: 2 + differential_learning_rate: 1.0e-05 + differential_learning_rate_layers: + - classification_head + drop_last_batch: true + epochs: 1 + evaluate_before_training: false + evaluation_epochs: 1.0 + grad_accumulation: 1 + gradient_clip: 0.0 + learning_rate: 0.0001 + lora: true + use_dora: false + lora_alpha: 16 + lora_dropout: 0.05 + lora_r: 4 + lora_target_modules: '' + loss_function: MSELoss + optimizer: AdamW + save_checkpoint: "last" + schedule: Cosine + train_validation_data: false + warmup_epochs: 0.0 + weight_decay: 0.0 \ No newline at end of file diff --git a/tests/integration/test_causal_regression_modeling_cpu_cfg.yaml b/tests/integration/test_causal_regression_modeling_cpu_cfg.yaml new file mode 100644 index 000000000..f40affee3 --- /dev/null +++ b/tests/integration/test_causal_regression_modeling_cpu_cfg.yaml @@ -0,0 +1,90 @@ +architecture: + backbone_dtype: float32 + gradient_checkpointing: true + intermediate_dropout: 0.0 + pretrained: true + pretrained_weights: '' +augmentation: + random_parent_probability: 0.0 + skip_parent_probability: 0.0 + token_mask_probability: 0.0 +dataset: + add_eos_token_to_answer: false + add_eos_token_to_prompt: false + add_eos_token_to_system: false + answer_column: regression_label + chatbot_author: H2O.ai + chatbot_name: h2oGPT + data_sample: 0.03 + data_sample_choice: + - Train + - Validation + limit_chained_samples: false + mask_prompt_labels: true + num_classes: 1 + parent_id_column: None + personalize: false + prompt_column: + - instruction + system_column: None + 
text_answer_separator: '' + text_prompt_start: '' + text_system_start: '' + train_dataframe: /tmp/train_full.pq + validation_dataframe: None + validation_size: 0.2 + validation_strategy: automatic +environment: + compile_model: false + deepspeed_reduce_bucket_size: 1000000 + deepspeed_stage3_param_persistence_threshold: 1000000 + deepspeed_stage3_prefetch_bucket_size: 1000000 + find_unused_parameters: false + gpus: + - '' + huggingface_branch: main + mixed_precision: false + number_of_workers: 8 + seed: -1 + trust_remote_code: true + use_deepspeed: false +experiment_name: solid-spaniel +llm_backbone: h2oai/llama2-0b-unit-test +logging: + logger: None + neptune_project: '' +output_directory: /tmp/output +prediction: + batch_size_inference: 0 + metric: MSE +problem_type: text_causal_regression_modeling +tokenizer: + add_prompt_answer_tokens: false + max_length: 32 + padding_quantile: 1.0 + tokenizer_kwargs: '{"use_fast": true, "add_prefix_space": false}' +training: + batch_size: 6 + differential_learning_rate: 1.0e-05 + differential_learning_rate_layers: + - classification_head + drop_last_batch: true + epochs: 1 + evaluate_before_training: false + evaluation_epochs: 1.0 + grad_accumulation: 1 + gradient_clip: 0.0 + learning_rate: 0.0001 + lora: true + use_dora: false + lora_alpha: 16 + lora_dropout: 0.05 + lora_r: 4 + lora_target_modules: '' + loss_function: MSELoss + optimizer: AdamW + save_checkpoint: "last" + schedule: Cosine + train_validation_data: false + warmup_epochs: 0.0 + weight_decay: 0.0 \ No newline at end of file diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index bd0245d4f..e992bae4c 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -68,6 +68,36 @@ def test_oasst_classification_training_gpu(tmp_path, settings): ) +@pytest.mark.parametrize( + "settings", + [ + ["MSE", "test_causal_regression_modeling_cfg"], + ], +) +def test_oasst_regression_training_gpu(tmp_path, settings): + metric, config_name = settings + run_oasst( + tmp_path, + config_name=config_name, + metric=metric, + ) + + +@pytest.mark.parametrize( + "settings", + [ + ["MSE", "test_causal_regression_modeling_cpu_cfg"], + ], +) +def test_oasst_regression_training_cpu(tmp_path, settings): + metric, config_name = settings + run_oasst( + tmp_path, + config_name=config_name, + metric=metric, + ) + + @pytest.mark.parametrize( "settings", [ @@ -115,6 +145,7 @@ def run_oasst(tmp_path, config_name, metric): df = pd.read_parquet(train_path) df["multiclass_label"] = np.random.choice(["0", "1", "2"], size=len(df)) df["binary_label"] = np.random.choice(["0", "1"], size=len(df)) + df["regression_label"] = np.random.uniform(0, 1, size=len(df)) df.to_parquet(train_path) with open( diff --git a/tests/python_configs/test_base.py b/tests/python_configs/test_base.py index 7da0675f3..2ce3db7e1 100644 --- a/tests/python_configs/test_base.py +++ b/tests/python_configs/test_base.py @@ -4,6 +4,9 @@ from llm_studio.python_configs.text_causal_language_modeling_config import ( ConfigProblemBase as CausalConfigProblemBase, ) +from llm_studio.python_configs.text_causal_regression_modeling_config import ( + ConfigProblemBase as CausalRegressionConfigProblemBase, +) from llm_studio.python_configs.text_sequence_to_sequence_modeling_config import ( ConfigProblemBase as Seq2SeqConfigProblemBase, ) @@ -18,6 +21,7 @@ def test_from_dict(): CausalConfigProblemBase, Seq2SeqConfigProblemBase, CausalClassificationConfigProblemBase, + 
CausalRegressionConfigProblemBase, ]: cfg = cfg_class() cfg_dict = convert_cfg_base_to_nested_dictionary(cfg) @@ -33,3 +37,8 @@ def test_from_dict(): def test_classification_config_is_in_non_generating_problem_types(): cfg = CausalClassificationConfigProblemBase() assert cfg.problem_type in NON_GENERATION_PROBLEM_TYPES + + +def test_regression_config_is_in_non_generating_problem_types(): + cfg = CausalRegressionConfigProblemBase() + assert cfg.problem_type in NON_GENERATION_PROBLEM_TYPES
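To make the metric plumbing concrete, here is a small sketch of how the per-sample scores from the metrics module introduced earlier in this diff are reduced to the single value used for early stopping; the `results` dict is a hand-rolled stand-in for what the evaluation loop provides.

```python
# Sketch: per-sample regression metrics and their "mean" reduction.
# The results dict is a stand-in; in LLM Studio it is produced by the
# evaluation loop from the model predictions and the answer column.
import numpy as np

from llm_studio.src.metrics.text_causal_regression_modeling_metrics import Metrics

results = {
    "predicted_text": ["0.8", "0.1"],  # predictions are serialized as strings
    "target_text": ["1.0", "0.0"],
}

metric_fn, direction, reduction = Metrics.get("MSE")  # direction="min", reduction="mean"
per_sample = metric_fn(cfg=None, results=results, val_df=None)
print(per_sample)                          # [0.04 0.01] -> squared error per row
print(getattr(np, reduction)(per_sample))  # 0.025 -> value used for checkpointing
```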