KD trainer w/ logprobs #2202

Draft · winglian wants to merge 70 commits into main from kd-trainer
Conversation

winglian (Collaborator):
Description

Motivation and Context

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

@winglian changed the title from "KD trainer" to "KD trainer w/ logprobs" on Dec 19, 2024
@SalmanMohammadi self-requested a review on January 2, 2025 16:58
@winglian force-pushed the kd-trainer branch 2 times, most recently from 9cc1a77 to a952e84, on January 4, 2025 00:45
        return super()._save_checkpoint(model, trial, **kwargs)


class AxolotlMambaTrainer(AxolotlTrainer):
SalmanMohammadi (Contributor):
Not for this PR, but how would we feel about moving each of these trainers to their own file?

SalmanMohammadi (Contributor), Jan 6, 2025:

Having an integrations/trl/ folder would be neat too.

@@ -13,6 +13,12 @@ class PreprocessCliArgs:
    debug_num_examples: int = field(default=1)
    prompter: Optional[str] = field(default=None)
    download: Optional[bool] = field(default=True)
+   iterable: Optional[bool] = field(
SalmanMohammadi (Contributor):

Is it worth documenting somewhere (or raising an issue as a reminder) to show users that we offer support for iterable datasets?

@@ -39,6 +39,8 @@ def preprocess(config: str, **kwargs) -> None:
        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
            config options.
    """
+   kwargs = {k: v for k, v in kwargs.items() if v is not None}
SalmanMohammadi (Contributor):

Shouldn't using @filter_none_kwargs address this?
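For context, a decorator like @filter_none_kwargs would typically strip None-valued kwargs before the wrapped function runs, making the inline dict comprehension redundant. A minimal sketch of the idea (the actual axolotl helper may differ in its details):

    import functools

    def filter_none_kwargs(func):
        """Drop keyword arguments whose value is None before calling func.

        Sketch only; axolotl's real decorator may differ.
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            filtered = {k: v for k, v in kwargs.items() if v is not None}
            return func(*args, **filtered)

        return wrapper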

@@ -0,0 +1,58 @@
+### AXOLOTL COMMUNITY LICENSE AGREEMENT
SalmanMohammadi (Contributor):

Is it worth surfacing this community license somewhere in our docs/README, along with how it would be used? Fine to leave as a follow-up issue.

    Input args for knowledge distillation.
    """

    kd_trainer: Optional[bool] = None  # whether to use KD trainer
SalmanMohammadi (Contributor):

Slightly confused - what happens when kd_trainer=False?
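A hypothetical illustration of the ambiguity, assuming trainer selection is gated on the flag's truthiness (names here are illustrative, not axolotl's actual internals):

    from typing import Optional

    def select_trainer_name(kd_trainer: Optional[bool]) -> str:
        # With a truthiness check, False and None behave identically,
        # which is the ambiguity the comment above is pointing at.
        return "AxolotlKDTrainer" if kd_trainer else "AxolotlTrainer"

    assert select_trainer_name(False) == select_trainer_name(None)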

Comment on lines +103 to +112

    if "input_ids" not in sample:
        # If there's no "input_ids", just return sample unchanged
        return sample

    input_ids = sample["input_ids"]

    # Detect if it's a single example or a batch
    if not input_ids:
        # Edge case: empty
        return sample
SalmanMohammadi (Contributor):

nit

Suggested change:

    -if "input_ids" not in sample:
    -    # If there's no "input_ids", just return sample unchanged
    -    return sample
    -input_ids = sample["input_ids"]
    -# Detect if it's a single example or a batch
    -if not input_ids:
    -    # Edge case: empty
    -    return sample
    +# Return sample unchanged if "input_ids" is not present, or is empty
    +if "input_ids" not in sample or not sample["input_ids"]:
    +    return sample
    +input_ids = sample["input_ids"]

    input_ids = sample["input_ids"]

    # Edge case: if input_ids is empty
    if not input_ids:
SalmanMohammadi (Contributor):

Should you do the same check above for not sample["input_ids"]?

@@ -172,10 +209,31 @@ def add_length(sample):


def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
SalmanMohammadi (Contributor):

nit: this is more of a filtering function with the signature def filter_sequence_length(...) -> Union[bool, List[bool]], right?
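A sketch of what the suggested rename might look like (illustrative only; drop_long_seq is the name the PR actually uses, and the body here is a guess at the existing behavior):

    from typing import Dict, List, Union

    def filter_sequence_length(
        sample: Dict, sequence_len: int = 2048, min_sequence_len: int = 2
    ) -> Union[bool, List[bool]]:
        """Predicate suitable for datasets' .filter(): keep samples whose token
        count falls within [min_sequence_len, sequence_len]. Returns one bool
        for a single example, or a list of bools for a batch."""
        input_ids = sample["input_ids"]
        if input_ids and isinstance(input_ids[0], int):
            return min_sequence_len <= len(input_ids) <= sequence_len
        return [min_sequence_len <= len(row) <= sequence_len for row in input_ids]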

        max_input_len = np.max(get_dataset_lengths(train_dataset))
        LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
    except AttributeError:
        pass
SalmanMohammadi (Contributor):

Is there anything informative worth logging here?

    try:
        prior_len = len(train_dataset)
    except TypeError:
        # handle iterable datasets case
SalmanMohammadi (Contributor):

How come an isinstance check like above wouldn't work here?

    # If it's a list, we assume we're dealing with a batch
    if isinstance(labels[0], int):
        # Single example: return a single bool
        return np.sum(np.array(labels) != -100) > 0
SalmanMohammadi (Contributor):

nit

Suggested change:

    -        return np.sum(np.array(labels) != -100) > 0
    +        return np.any(np.array(labels) != -100)

Comment on lines +325 to +328

    results = []
    for row_labels in labels:
        # Each row_labels is a list[int]
        results.append(np.sum(np.array(row_labels) != -100) > 0)
SalmanMohammadi (Contributor):

nit

Suggested change:

    -    results = []
    -    for row_labels in labels:
    -        # Each row_labels is a list[int]
    -        results.append(np.sum(np.array(row_labels) != -100) > 0)
    +    results = [np.any(np.array(row_labels) != -100) for row_labels in labels]

"dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4,
"dataloader_pin_memory": True,
# "dataset_prepared_path": str(Path(temp_dir) / "last_run_prepared"),
SalmanMohammadi (Contributor):

Suggested change:

    -    # "dataset_prepared_path": str(Path(temp_dir) / "last_run_prepared"),

Might as well chop while we're here.

@@ -29,7 +29,9 @@ def get_ds_type(config_dataset: DictDefault):
    return ds_type


-def load_dataset_w_config(config_dataset, auth_token):
+def load_dataset_w_config(
+    config_dataset, auth_token, streaming=False
SalmanMohammadi (Contributor):

Worth folding streaming into config_dataset so you're doing config_dataset.streaming?
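i.e. something like the following (a sketch of the suggestion, assuming config_dataset behaves like a dict-backed DictDefault):

    def load_dataset_w_config(config_dataset, auth_token):
        # Sketch: read streaming off the dataset config entry itself instead
        # of threading it through as a separate parameter.
        streaming = bool(config_dataset.get("streaming", False))
        ...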

    teacher_seq_len = target_token_ids.shape[1]

    # Slice student logits to match teacher-provided sequence length
    student_logits_for_kd = student_logits[
SalmanMohammadi (Contributor):

When would the student be predicting a longer sequence length?
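For reference, the excerpt above is cut off mid-expression; the slice presumably continues along the sequence dimension, along these lines (a reconstruction, not the exact PR code):

    import torch

    def slice_student_logits(student_logits: torch.Tensor, teacher_seq_len: int) -> torch.Tensor:
        # Truncate student logits [batch, seq, vocab] to the teacher-provided
        # sequence length before computing the KD loss.
        return student_logits[:, :teacher_seq_len, :]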

SalmanMohammadi (Contributor):

I've taken a first pass and it looks pretty good overall. I think the KD logic checks out. I'll make another pass tomorrow.

    kd_temperature: float = 1.0,
) -> torch.Tensor:
    """
    A KD loss function that is TorchScript-friendly.
SalmanMohammadi (Contributor):

Could you add docs for what the parameters are? I'd mainly like to clarify what target_mask is used for, but might as well doc them all : )
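For example, the docstring might spell the parameters out along these lines (only kd_temperature and target_mask appear in the visible diff; the function name, other parameter names, and shapes here are inferred, so treat this purely as a sketch):

    import torch

    def kd_loss(
        student_logits: torch.Tensor,
        target_token_ids: torch.Tensor,
        target_logprobs: torch.Tensor,
        target_mask: torch.Tensor,
        kd_temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        A KD loss function that is TorchScript-friendly.

        Args:
            student_logits: [batch, seq, vocab] raw logits from the student model.
            target_token_ids: [batch, seq, k] token ids of the teacher's top-k predictions.
            target_logprobs: [batch, seq, k] the teacher's log-probabilities for those tokens.
            target_mask: [batch, seq, k] mask selecting which positions contribute to the
                loss (e.g. zeroing out prompt/padding positions).
            kd_temperature: softmax temperature applied when distilling.

        Returns:
            Scalar loss tensor.
        """
        ...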

@winglian added the scheduled_release label ("This PR is slated for the upcoming release") on Jan 23, 2025