
Commit a6ee075: pushing to runpod
SalmanMohammadi committed Jan 9, 2025
1 parent 796fd14 commit a6ee075
Showing 4 changed files with 22 additions and 18 deletions.
examples/qwen2/prm.yaml (8 additions, 8 deletions)
@@ -31,27 +31,27 @@ wandb_name:
 wandb_log_model:
 
 
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 4
+gradient_accumulation_steps: 1
+micro_batch_size: 8
+num_epochs: 1
 optimizer: adamw_torch
 lr_scheduler: cosine
 learning_rate: 0.0002
 
 train_on_inputs: false
 group_by_length: false
-float32: true
-fp16: false
-tf32: false
+bf16: true
+fp16:
+tf32:
 gradient_checkpointing: true
 gradient_checkpointing_kwargs:
   use_reentrant: false
 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
 logging_steps: 1
-xformers_attention: false
-flash_attention: false
+xformers_attention:
+flash_attention: true
 
 warmup_ratio: 0.1
 evals_per_epoch:
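Worth a sanity check: the batch settings above are a like-for-like trade. The effective per-device batch size, micro_batch_size × gradient_accumulation_steps, is 2 × 4 = 8 before the change and 8 × 1 = 8 after it, so each optimizer step sees the same number of samples with fewer accumulation passes. A minimal sketch of that arithmetic, assuming a single-device run (the config does not pin a world size):

# Sketch: effective batch size before and after this commit.
# world_size = 1 is an assumption; the config does not set it.
def effective_batch_size(micro_batch_size: int,
                         gradient_accumulation_steps: int,
                         world_size: int = 1) -> int:
    """Samples consumed per optimizer step."""
    return micro_batch_size * gradient_accumulation_steps * world_size

old = effective_batch_size(micro_batch_size=2, gradient_accumulation_steps=4)
new = effective_batch_size(micro_batch_size=8, gradient_accumulation_steps=1)
assert old == new == 8  # same optimizer-step batch, fewer accumulation passes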
src/axolotl/core/trainer_builder.py (4 additions, 1 deletion)
@@ -1970,7 +1970,10 @@ def build(self, total_num_steps):
             trainer_kwargs["processing_class"] = self.tokenizer
         else:
             trainer_kwargs["tokenizer"] = self.tokenizer
-        if not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer]) and self.cfg.datasets is not None:
+        if (
+            not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
+            and self.cfg.datasets is not None
+        ):
             trainer_kwargs["dataset_tags"] = [
                 d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
             ]
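This hunk only rewraps the one-line condition for line length; the logic is unchanged. Incidentally, `not (x in seq)` is equivalent to the more idiomatic `x not in seq`, as a quick check shows (the class name below is a hypothetical stand-in, not taken from the diff):

# `not (x in seq)` and `x not in seq` are equivalent membership tests;
# the latter is the idiomatic spelling (flake8 flags the former as E713).
trainer_classes = ["AxolotlRewardTrainer", "AxolotlPRMTrainer"]
cls_name = "AxolotlSFTTrainer"  # hypothetical stand-in for trainer_cls
assert (not (cls_name in trainer_classes)) == (cls_name not in trainer_classes)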
src/axolotl/prompt_strategies/stepwise_supervised.py (9 additions, 8 deletions)
@@ -4,13 +4,11 @@
 """
 
 from itertools import chain
-
-from typing import Dict, Generator, List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from transformers import BatchEncoding, PreTrainedTokenizer
 
-from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
-from axolotl.prompters import Prompter
+from axolotl.prompt_tokenizers import IGNORE_INDEX
 from axolotl.utils.dict import DictDefault
 
 
@@ -55,7 +53,9 @@ def tokenize_prompt(
 
         # Handle labels
         if self.train_on_last_step_only:
-            labels = [-100] * (len(prompt["labels"]) - 1) + [int(prompt["labels"][-1])]
+            labels = [IGNORE_INDEX] * (len(prompt["labels"]) - 1) + [
+                int(prompt["labels"][-1])
+            ]
         else:
             labels = [int(label) for label in prompt["labels"]]
 
@@ -67,13 +67,13 @@
 
         # Create step-wise labels
         labels = [
-            [-100] * (len(completion) - 1) + [label]
+            [IGNORE_INDEX] * (len(completion) - 1) + [label]  # type: ignore
             for completion, label in zip(completions_ids, labels)
         ]
 
         # Join all steps
         completion_ids = list(chain(*completions_ids))
-        labels = list(chain(*labels))
+        labels = list(chain(*labels))  # type: ignore
 
         # Handle max lengths
         if self.max_completion_length:
@@ -86,7 +86,8 @@
 
         # Combine prompt and completion
         input_ids = prompt_ids + completion_ids
-        full_labels = [-100] * len(prompt_ids) + labels
+
+        full_labels = [IGNORE_INDEX] * len(prompt_ids) + labels
         # Apply max sequence length
         if self.sequence_len:
             input_ids = input_ids[: self.sequence_len]
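Replacing the literal -100 with IGNORE_INDEX is a readability fix: IGNORE_INDEX is -100, the default ignore_index of torch.nn.CrossEntropyLoss, so positions labeled with it contribute nothing to the loss. A standalone sketch of the step-wise labeling these hunks implement, with IGNORE_INDEX defined locally rather than imported from axolotl:

from itertools import chain

IGNORE_INDEX = -100  # matches torch.nn.CrossEntropyLoss(ignore_index=-100)

def stepwise_labels(completions_ids: list,
                    step_labels: list) -> list:
    """Label only the last token of each reasoning step; mask the rest."""
    per_step = [
        [IGNORE_INDEX] * (len(completion) - 1) + [label]
        for completion, label in zip(completions_ids, step_labels)
    ]
    # Flatten the per-step label lists into one sequence.
    return list(chain(*per_step))

# Two steps of 3 and 2 tokens, judged correct (1) and incorrect (0):
assert stepwise_labels([[5, 6, 7], [8, 9]], [1, 0]) == [
    IGNORE_INDEX, IGNORE_INDEX, 1, IGNORE_INDEX, 0
]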
src/axolotl/utils/data/sft.py (1 addition, 1 deletion)
@@ -6,9 +6,9 @@
 from typing import List, Tuple, Union
 
 from datasets import (
-    concatenate_datasets,
     Dataset,
     DatasetDict,
+    concatenate_datasets,
     load_dataset,
     load_from_disk,
 )
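The import move is consistent with a case-sensitive (ASCII) sort of the imported names, in which uppercase identifiers precede lowercase ones, so concatenate_datasets belongs after DatasetDict. Python's built-in sorted produces the same ordering:

# A case-sensitive (ASCII) sort puts uppercase names before lowercase ones,
# matching the reordered import block above.
names = ["concatenate_datasets", "Dataset", "DatasetDict",
         "load_dataset", "load_from_disk"]
assert sorted(names) == [
    "Dataset", "DatasetDict", "concatenate_datasets",
    "load_dataset", "load_from_disk",
]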
