KD trainer w/ logprobs #2202
base: main
Conversation
Force-pushed from 9cc1a77 to a952e84
        return super()._save_checkpoint(model, trial, **kwargs)


class AxolotlMambaTrainer(AxolotlTrainer):
Not for this PR, but how would we feel about moving each of these trainers to their own file?
Having an integrations/trl/ folder would be neat too.
Force-pushed from ab49180 to 4a0ab11
fix loader default
@@ -13,6 +13,12 @@ class PreprocessCliArgs:
     debug_num_examples: int = field(default=1)
     prompter: Optional[str] = field(default=None)
     download: Optional[bool] = field(default=True)
+    iterable: Optional[bool] = field(
Is it worth documenting somewhere (or raising an issue as a reminder) to show users that we offer support for iterable datasets?
@@ -39,6 +39,8 @@ def preprocess(config: str, **kwargs) -> None:
        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
            config options.
    """
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
Shouldn't using `@filter_none_kwargs` address this?
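For reference, a decorator with that name would presumably look something like this (a minimal sketch; the actual `filter_none_kwargs` implementation in the codebase may differ):

import functools

def filter_none_kwargs(func):
    """Drop keyword arguments whose value is None before calling `func`."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        filtered = {k: v for k, v in kwargs.items() if v is not None}
        return func(*args, **filtered)
    return wrapper

If `preprocess` were decorated with it, the inline dict comprehension would be redundant.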
@@ -0,0 +1,58 @@
+### AXOLOTL COMMUNITY LICENSE AGREEMENT
Is it worth surfacing this community license somewhere in our docs/README and how it would be used? Fine to leave as a follow up/issue.
    Input args for knowledge distillation.
    """

    kd_trainer: Optional[bool] = None  # whether to use KD trainer
Slightly confused - what happens when `kd_trainer=False`?
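For illustration, if the flag simply gates trainer selection, `False` and `None` are both falsy and would behave identically (a hypothetical reading, not necessarily the PR's actual wiring):

from typing import Optional

def use_kd_trainer(kd_trainer: Optional[bool]) -> bool:
    # Hypothetical gating: kd_trainer=False and kd_trainer=None both fall
    # through to the default trainer, making the two equivalent here.
    return bool(kd_trainer)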
if "input_ids" not in sample: | ||
# If there's no "input_ids", just return sample unchanged | ||
return sample | ||
|
||
input_ids = sample["input_ids"] | ||
|
||
# Detect if it's a single example or a batch | ||
if not input_ids: | ||
# Edge case: empty | ||
return sample |
nit
if "input_ids" not in sample: | |
# If there's no "input_ids", just return sample unchanged | |
return sample | |
input_ids = sample["input_ids"] | |
# Detect if it's a single example or a batch | |
if not input_ids: | |
# Edge case: empty | |
return sample | |
# Return sample unchanged if "input_ids" is not present, or is empty | |
if "input_ids" not in sample or not sample["input_ids"]: | |
return sample | |
input_ids = sample["input_ids"] |
input_ids = sample["input_ids"]

# Edge case: if input_ids is empty
if not input_ids:
Should you do the same check above for `not sample["input_ids"]`?
@@ -172,10 +209,31 @@ def add_length(sample):


def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
nit: this is more of a filtering function with the signature `def filter_sequence_length(...) -> Union[bool, List[bool]]`, right?
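For what it's worth, that predicate reading would look roughly like this (a sketch under the reviewer's proposed signature; the bounds and batch handling are assumed from the surrounding diff, not taken from the PR's code):

from typing import List, Union

def filter_sequence_length(
    sample, sequence_len: int = 2048, min_sequence_len: int = 2
) -> Union[bool, List[bool]]:
    """Return True for rows whose token count falls within [min_sequence_len, sequence_len]."""
    def in_bounds(ids):
        return min_sequence_len <= len(ids) <= sequence_len

    ids = sample["input_ids"]
    if ids and isinstance(ids[0], list):  # batched sample: one bool per row
        return [in_bounds(row) for row in ids]
    return in_bounds(ids)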
    max_input_len = np.max(get_dataset_lengths(train_dataset))
    LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
except AttributeError:
    pass
Is there anything informative worth logging here?
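Something like this could work, reusing the `LOG.debug` call style from the snippet above (a sketch only):

try:
    max_input_len = np.max(get_dataset_lengths(train_dataset))
    LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
except AttributeError:
    # Iterable/streaming datasets don't expose lengths up front; say so
    # rather than passing silently.
    LOG.debug("max_input_len unavailable (iterable dataset?)", main_process_only=True)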
try:
    prior_len = len(train_dataset)
except TypeError:
    # handle iterable datasets case
How come an `isinstance` check like above wouldn't work here?
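Presumably that means something like the following, assuming the Hugging Face `datasets.IterableDataset` type is what makes `len()` raise here:

from datasets import IterableDataset

if isinstance(train_dataset, IterableDataset):
    prior_len = None  # streaming datasets have no known length
else:
    prior_len = len(train_dataset)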
# If it's a list, we assume we're dealing with a batch
if isinstance(labels[0], int):
    # Single example: return a single bool
    return np.sum(np.array(labels) != -100) > 0
nit
Suggested change:

return np.any(labels != -100)
results = []
for row_labels in labels:
    # Each row_labels is a list[int]
    results.append(np.sum(np.array(row_labels) != -100) > 0)
nit
Suggested change:

results = [np.any(row_labels != -100) for row_labels in labels]
"dataloader_prefetch_factor": 8, | ||
"dataloader_num_workers": 4, | ||
"dataloader_pin_memory": True, | ||
# "dataset_prepared_path": str(Path(temp_dir) / "last_run_prepared"), |
# "dataset_prepared_path": str(Path(temp_dir) / "last_run_prepared"), |
might as well chop while we're here
@@ -29,7 +29,9 @@ def get_ds_type(config_dataset: DictDefault):
     return ds_type


-def load_dataset_w_config(config_dataset, auth_token):
+def load_dataset_w_config(
+    config_dataset, auth_token, streaming=False
Worth folding `streaming` into `config_dataset` so you're doing `config_dataset.streaming`?
teacher_seq_len = target_token_ids.shape[1]

# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = student_logits[
When would the student be predicting a longer sequence length?
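One plausible scenario (an assumption, not confirmed in the PR): the teacher's logprobs were precomputed over a shorter window than the student's inputs, so the student logits get sliced down to match, e.g.:

import torch

# Shapes are illustrative only.
student_logits = torch.randn(2, 512, 32000)                # (batch, student_seq_len, vocab)
target_token_ids = torch.zeros(2, 384, dtype=torch.long)   # teacher covered only 384 positions

teacher_seq_len = target_token_ids.shape[1]
student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
assert student_logits_for_kd.shape[1] == teacher_seq_len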
I've taken a first pass and it looks pretty good overall. I think the KD logic checks out. I'll make another pass tomorrow.
    kd_temperature: float = 1.0,
) -> torch.Tensor:
    """
    A KD loss function that is TorchScript-friendly.
Could you add docs for what the parameters are? I'd mainly like to clarify what `target_mask` is used for, but might as well doc them all : )
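Something along these lines is presumably what's being asked for (parameter names other than `kd_temperature` and `target_mask` are assumed for illustration, not taken from the PR):

import torch

def kd_loss(
    student_logits: torch.Tensor,    # assumed name
    target_logprobs: torch.Tensor,   # assumed name
    target_mask: torch.Tensor,
    kd_temperature: float = 1.0,
) -> torch.Tensor:
    """A KD loss function that is TorchScript-friendly.

    Args:
        student_logits: Raw student scores, shape (batch, seq_len, vocab).
        target_logprobs: Teacher log-probabilities for the positions being distilled.
        target_mask: 1 where a position contributes to the loss, 0 where it is
            ignored (e.g. prompt or padding tokens).
        kd_temperature: Softmax temperature applied before comparing the two
            distributions; values > 1 soften the targets.

    Returns:
        Scalar loss tensor.
    """
    ...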
Description
Motivation and Context
How has this been tested?
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)