Skip to content

Commit

Permalink
fix online dpo example (huggingface#1879)
Browse files Browse the repository at this point in the history
  • Loading branch information
edbeeching authored and kashif committed Jul 28, 2024
1 parent 4f043bc commit df69d1a
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions examples/scripts/online_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
--gradient_accumulation_steps 64 \
--total_episodes 30000 \
--model_name_or_path EleutherAI/pythia-14m \
- --sft_model_path EleutherAI/pythia-14m \
--reward_model_path EleutherAI/pythia-14m \
--non_eos_penalty \
--stop_token eos \
Expand All @@ -41,7 +40,6 @@
--gradient_accumulation_steps 4 \
--total_episodes 1000000 \
--model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
- --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--save_strategy no \
--non_eos_penalty \
Expand Down Expand Up @@ -96,7 +94,7 @@ def tokenize(element):
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
reward_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model_path, num_labels=1)
- ref_model = AutoModelForCausalLM.from_pretrained(config.sft_model_path)
+ ref_model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)

################
Expand Down

0 comments on commit df69d1a

Please sign in to comment.