Add instruction and conversational data support #211

Merged: 21 commits, Nov 14, 2024. Changes shown from 20 commits.
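For context on what "instruction and conversational data support" means in practice, here is a small sketch that writes one record in each newly supported JSONL shape. The required field names follow the constants added in this PR (`JSONL_REQUIRED_COLUMNS_MAP`, `REQUIRED_COLUMNS_MESSAGE`, and `POSSIBLE_ROLES_CONVERSATION` in `src/together/constants.py`); the example contents and file names are made up.

```python
import json

# Conversational format: a "messages" list where every message has a "role"
# (one of "system", "user", "assistant") and a "content" string.
conversation_record = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
}

# Instruction format: a "prompt" plus a "completion" per record.
instruction_record = {"prompt": "Translate to French: cheese", "completion": "fromage"}

# One JSON object per line, one format per file.
with open("conversation_data.jsonl", "w") as f:
    f.write(json.dumps(conversation_record) + "\n")
with open("instruction_data.jsonl", "w") as f:
    f.write(json.dumps(instruction_record) + "\n")
```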
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "together"
version = "1.3.3"
version = "1.3.4"
authors = [
"Together AI <[email protected]>"
]
39 changes: 21 additions & 18 deletions src/together/cli/api/finetune.py
@@ -11,8 +11,13 @@
from tabulate import tabulate

from together import Together
from together.cli.api.utils import INT_WITH_MAX
from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
from together.utils import (
finetune_price_to_dollars,
log_warn,
log_warn_once,
parse_timestamp,
)
from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits


@@ -93,6 +98,13 @@ def fine_tuning(ctx: click.Context) -> None:
default=False,
help="Whether to skip the launch confirmation message",
)
@click.option(
"--train-on-inputs",
type=BOOL_WITH_AUTO,
default="auto",
help="Whether to mask the user messages in conversational data or prompts in instruction data. "
"`auto` will automatically determine whether to mask the inputs based on the data format.",
)
def create(
ctx: click.Context,
training_file: str,
@@ -112,6 +124,7 @@ def create(
suffix: str,
wandb_api_key: str,
confirm: bool,
train_on_inputs: bool | Literal["auto"],
) -> None:
"""Start fine-tuning"""
client: Together = ctx.obj
@@ -133,6 +146,7 @@ def create(
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
train_on_inputs=train_on_inputs,
)

model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
@@ -150,6 +164,10 @@ def create(
"batch_size": model_limits.lora_training.max_batch_size,
"learning_rate": 1e-3,
}
log_warn_once(
f"LoRA rank default has been changed to {default_values['lora_r']} as the max available for the model.\n"
f"Learning rate default for LoRA FT has been changed to {default_values['learning_rate']}."
)
for arg in default_values:
arg_source = ctx.get_parameter_source(arg)  # type: ignore[attr-defined]
if arg_source == ParameterSource.DEFAULT:
@@ -186,22 +204,7 @@ def create(

if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
response = client.fine_tuning.create(
training_file=training_file,
model=model,
n_epochs=n_epochs,
validation_file=validation_file,
n_evals=n_evals,
n_checkpoints=n_checkpoints,
batch_size=batch_size,
learning_rate=learning_rate,
warmup_ratio=warmup_ratio,
lora=lora,
lora_r=lora_r,
lora_dropout=lora_dropout,
lora_alpha=lora_alpha,
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
**training_args,
verbose=True,
)

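The CLI diff above applies LoRA-specific defaults only when the user left an option at its CLI default, using click's parameter-source tracking. A minimal standalone sketch of that pattern follows; the command, option, and default values are illustrative and not together's.

```python
import click
from click.core import ParameterSource
from click.testing import CliRunner


@click.command()
@click.option("--learning-rate", type=float, default=1e-5)
@click.pass_context
def demo(ctx: click.Context, learning_rate: float) -> None:
    lora_defaults = {"learning_rate": 1e-3}  # stand-in for model_limits-derived values
    if ctx.get_parameter_source("learning_rate") == ParameterSource.DEFAULT:
        # The user did not pass --learning-rate, so the LoRA-specific default wins.
        learning_rate = lora_defaults["learning_rate"]
    click.echo(f"learning_rate={learning_rate}")


if __name__ == "__main__":
    runner = CliRunner()
    print(runner.invoke(demo, []).output.strip())                           # learning_rate=0.001
    print(runner.invoke(demo, ["--learning-rate", "2e-4"]).output.strip())  # learning_rate=0.0002
```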
21 changes: 21 additions & 0 deletions src/together/cli/api/utils.py
@@ -27,4 +27,25 @@ def convert(
)


class BooleanWithAutoParamType(click.ParamType):
name = "boolean_or_auto"

def convert(
self, value: str, param: click.Parameter | None, ctx: click.Context | None
) -> bool | Literal["auto"] | None:
if value == "auto":
return "auto"
try:
return bool(value)
except ValueError:
self.fail(
_("{value!r} is not a valid {type}.").format(
value=value, type=self.name
),
param,
ctx,
)


INT_WITH_MAX = AutoIntParamType()
BOOL_WITH_AUTO = BooleanWithAutoParamType()
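A standalone sketch of how the `boolean_or_auto` conversion above behaves when called directly; the class copies the pattern rather than importing together's `BOOL_WITH_AUTO`. One Python detail to keep in mind: `bool(value)` on a string is `False` only for the empty string, so any non-empty spelling, including `"false"`, converts to `True`.

```python
from typing import Literal, Optional, Union

import click


class BooleanWithAuto(click.ParamType):
    """Copy of the boolean_or_auto pattern above, for local experimentation."""

    name = "boolean_or_auto"

    def convert(
        self,
        value: str,
        param: Optional[click.Parameter],
        ctx: Optional[click.Context],
    ) -> Union[bool, Literal["auto"]]:
        if value == "auto":
            return "auto"
        return bool(value)  # non-empty strings, including "false", are truthy


if __name__ == "__main__":
    p = BooleanWithAuto()
    print(p.convert("auto", None, None))   # auto
    print(p.convert("true", None, None))   # True
    print(p.convert("false", None, None))  # True (non-empty string)
    print(p.convert("", None, None))       # False
```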
19 changes: 19 additions & 0 deletions src/together/constants.py
@@ -1,3 +1,5 @@
import enum

# Session constants
TIMEOUT_SECS = 600
MAX_SESSION_LIFETIME_SECS = 180
@@ -29,3 +31,20 @@

# expected columns for Parquet files
PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]


class DatasetFormat(enum.Enum):
"""Dataset format enum."""

GENERAL = "general"
CONVERSATION = "conversation"
INSTRUCTION = "instruction"


JSONL_REQUIRED_COLUMNS_MAP = {
DatasetFormat.GENERAL: ["text"],
DatasetFormat.CONVERSATION: ["messages"],
DatasetFormat.INSTRUCTION: ["prompt", "completion"],
}
REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
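The new constants make the expected dataset shapes easy to check programmatically. Below is a rough sketch that matches a loaded record against them; the SDK performs its own file validation elsewhere (not shown in this diff), so this is only an illustration of the constants, and the `detect_format` helper is hypothetical.

```python
from typing import Optional

from together.constants import (
    DatasetFormat,
    JSONL_REQUIRED_COLUMNS_MAP,
    POSSIBLE_ROLES_CONVERSATION,
    REQUIRED_COLUMNS_MESSAGE,
)


def detect_format(record: dict) -> Optional[DatasetFormat]:
    """Return the first dataset format whose required columns are all present."""
    for dataset_format, columns in JSONL_REQUIRED_COLUMNS_MAP.items():
        if all(column in record for column in columns):
            return dataset_format
    return None


record = {"messages": [{"role": "user", "content": "hi"}]}
fmt = detect_format(record)
print(fmt)  # DatasetFormat.CONVERSATION

if fmt == DatasetFormat.CONVERSATION:
    for message in record["messages"]:
        assert all(key in message for key in REQUIRED_COLUMNS_MESSAGE)
        assert message["role"] in POSSIBLE_ROLES_CONVERSATION
```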
20 changes: 19 additions & 1 deletion src/together/resources/finetune.py
@@ -43,6 +43,7 @@ def createFinetuneRequest(
lora_trainable_modules: str | None = "all-linear",
suffix: str | None = None,
wandb_api_key: str | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
) -> FinetuneRequest:
if batch_size == "max":
log_warn_once(
@@ -95,6 +96,7 @@ def createFinetuneRequest(
training_type=training_type,
suffix=suffix,
wandb_key=wandb_api_key,
train_on_inputs=train_on_inputs,
)

return finetune_request
@@ -125,6 +127,7 @@ def create(
wandb_api_key: str | None = None,
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
) -> FinetuneResponse:
"""
Method to initiate a fine-tuning job
@@ -137,7 +140,7 @@ def create(
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
Defaults to 1.
batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
learning_rate (float, optional): Learning rate multiplier to use for training
Defaults to 0.00001.
warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
@@ -154,6 +157,12 @@ def create(
Defaults to False.
model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters of the model in fine-tuning.
Defaults to None.
train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
"auto" will automatically determine whether to mask the inputs based on the data format.
For datasets with the "text" field (general format), inputs will not be masked.
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".

Returns:
FinetuneResponse: Object containing information about fine-tuning job.
@@ -184,6 +193,7 @@ def create(
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
train_on_inputs=train_on_inputs,
)

if verbose:
@@ -436,6 +446,7 @@ async def create(
wandb_api_key: str | None = None,
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
train_on_inputs: bool | Literal["auto"] = "auto",
) -> FinetuneResponse:
"""
Async method to initiate a fine-tuning job
@@ -465,6 +476,12 @@ async def create(
Defaults to False.
model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
Defaults to None.
train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
"auto" will automatically determine whether to mask the inputs based on the data format.
For datasets with the "text" field (general format), inputs will not be masked.
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
(Instruction format), inputs will be masked.
Defaults to "auto".

Returns:
FinetuneResponse: Object containing information about fine-tuning job.
@@ -495,6 +512,7 @@ async def create(
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
train_on_inputs=train_on_inputs,
)

if verbose:
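End to end, a hedged usage sketch of the new `train_on_inputs` argument on the Python client; the file ID and model name are placeholders, and `TOGETHER_API_KEY` is assumed to be set in the environment.

```python
from together import Together

client = Together()  # picks up TOGETHER_API_KEY from the environment

job = client.fine_tuning.create(
    training_file="file-xxxxxxxx",  # an uploaded conversational or instruction JSONL file
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model name
    n_epochs=1,
    train_on_inputs="auto",  # mask inputs automatically based on the detected data format
)
print(job.id, job.status)
```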
4 changes: 3 additions & 1 deletion src/together/types/finetune.py
@@ -3,7 +3,7 @@
from enum import Enum
from typing import List, Literal

from pydantic import Field, validator, field_validator
from pydantic import StrictBool, Field, validator, field_validator

from together.types.abstract import BaseModel
from together.types.common import (
@@ -163,6 +163,7 @@ class FinetuneRequest(BaseModel):
# weights & biases api key
wandb_key: str | None = None
training_type: FullTrainingType | LoRATrainingType | None = None
train_on_inputs: StrictBool | Literal["auto"] = "auto"


class FinetuneResponse(BaseModel):
@@ -230,6 +231,7 @@ class FinetuneResponse(BaseModel):
# training file metadata
training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
training_file_size: int | None = Field(None, alias="TrainingFileSize")
train_on_inputs: StrictBool | Literal["auto"] | None = "auto"

Review comment (Member): Is None still possible as a response, because older jobs don't have that attribute?

Reply (PR author): Yup, it's for retrieve or any other command that can see old data.


@field_validator("training_type")
@classmethod
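Finally, the `StrictBool | Literal["auto"]` annotation on the request and response models means real booleans and the literal string `"auto"` validate, while other truthy strings are rejected rather than coerced. A throwaway pydantic model (not together's `FinetuneRequest`, which carries many more fields) illustrates the behaviour:

```python
from typing import Literal, Union

from pydantic import BaseModel, StrictBool, ValidationError


class Example(BaseModel):
    train_on_inputs: Union[StrictBool, Literal["auto"]] = "auto"


print(Example().train_on_inputs)                       # auto
print(Example(train_on_inputs=False).train_on_inputs)  # False

try:
    Example(train_on_inputs="yes")  # neither a strict bool nor the literal "auto"
except ValidationError as exc:
    print("rejected:", type(exc).__name__)
```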