
Update trainer for easier handling of accumulate, compile fixes, and proper reporting #34511

Merged 17 commits on Nov 4, 2024
3 changes: 1 addition & 2 deletions src/transformers/modeling_utils.py
@@ -28,7 +28,7 @@
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache, partial, wraps
from functools import partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from zipfile import is_zipfile
@@ -5014,7 +5014,6 @@ def _is_quantized_training_enabled(self):
return self.hf_quantizer.is_trainable

@property
@lru_cache
def loss_function(self):
if getattr(self.config, "loss_type", None) is not None:
loss_type = self.config.loss_type
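For context, here is a small, hypothetical illustration (not transformers code) of why stacking `@lru_cache` under `@property`, as removed in the hunk above, is fragile: the cached value is keyed on the instance, so it ignores later changes to the underlying attribute and also keeps a strong reference to the object alive.

from functools import lru_cache

class ToyModel:
    def __init__(self, loss_type):
        self.loss_type = loss_type

    @property
    @lru_cache
    def loss_function(self):
        return f"loss_for_{self.loss_type}"

model = ToyModel("ForCausalLM")
print(model.loss_function)   # loss_for_ForCausalLM
model.loss_type = "ForQuestionAnswering"
print(model.loss_function)   # still loss_for_ForCausalLM: the stale cached value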
11 changes: 11 additions & 0 deletions src/transformers/testing_utils.py
@@ -29,6 +29,7 @@
import tempfile
import time
import unittest
from packaging import version
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import MISSING, fields
@@ -1291,6 +1292,16 @@ def require_jumanpp(test_case):
"""
return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case)

def require_python_311(test_case):
"""
Decorator marking a test that requires Python 3.11 or higher
"""

current_version = version.parse(sys.version.split()[0])
min_version = version.parse("3.11")
return unittest.skipUnless(current_version >= min_version, "test requires python v3.11+")(test_case)



def require_cython(test_case):
"""
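Closing out the testing_utils.py changes, a quick hypothetical usage sketch of the new `require_python_311` decorator; the test class and test name below are invented for illustration, but the decorator is applied like the other `require_*` helpers:

import unittest

from transformers.testing_utils import require_python_311


class ExampleTests(unittest.TestCase):
    @require_python_311
    def test_needs_python_311(self):
        # Skipped automatically on Python < 3.11, runs otherwise.
        self.assertTrue(True)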
49 changes: 24 additions & 25 deletions src/transformers/trainer.py
@@ -233,7 +233,6 @@
from accelerate.utils import (
DistributedDataParallelKwargs,
DistributedType,
GradientAccumulationPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
@@ -2445,7 +2444,7 @@ def _inner_training_loop(
update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
for inputs in batch_samples:
for i, inputs in enumerate(batch_samples):
step += 1
total_batched_samples += 1
is_last_step_and_steps_less_than_grad_acc = (
@@ -2491,7 +2490,13 @@
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

with self.accelerator.accumulate(model):
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
Comment on lines +2496 to +2500 (muellerzr, Contributor Author):

For an explanation of what is going on here @Rocketknight1: during DDP we use `model.no_sync()` so that gradients are only communicated across all GPUs on the step taken outside the context, which speeds up training during gradient accumulation when synchronization is not needed. `accelerator.no_sync()` is the lower-level API behind `accumulate()` and makes that operation backend-independent, so on a single GPU it simply acts as a null context.

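As a companion to the comment above, a minimal hypothetical sketch (not the PR's code; the helper name `accumulation_context` is invented) of choosing between `Accelerator.no_sync` and a null context across the micro-batches of one optimizer update:

import contextlib
import functools

def accumulation_context(accelerator, model, micro_step, num_micro_steps):
    # Skip DDP gradient synchronization on every micro-batch except the last one;
    # the final backward then all-reduces the accumulated gradients a single time.
    # `Accelerator.no_sync` wraps `model.no_sync()` under DDP and degrades to a
    # null context on a single device, so the logic stays backend-independent.
    if micro_step != num_micro_steps - 1:
        return functools.partial(accelerator.no_sync, model=model)
    return contextlib.nullcontext

# Sketch of its use inside the accumulation loop:
# with accumulation_context(self.accelerator, model, i, len(batch_samples))():
#     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)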
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

if (
@@ -3643,15 +3648,13 @@ def training_step(
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
if num_items_in_batch is not None:
if self.compute_loss_func or self.model_accepts_loss_kwargs:
loss *= self.args.gradient_accumulation_steps
# Average tokens across devices is orthogonal to gradient accumulation
if self.args.average_tokens_across_devices:
loss *= self.args.world_size
# Average tokens across devices is orthogonal to gradient accumulation
if num_items_in_batch is not None and self.args.average_tokens_across_devices:
loss *= self.args.world_size
self.accelerator.backward(loss, **kwargs)

return loss.detach() / self.args.gradient_accumulation_steps
if num_items_in_batch is None:
return loss.detach() / self.args.gradient_accumulation_steps
return loss.detach()

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
@@ -4953,24 +4956,21 @@ def _add_sm_patterns_to_gitignore(self) -> None:
self.repo.git_push()

def create_accelerator_and_postprocess(self):
# We explicitly don't rely on the `Accelerator` to do gradient accumulation
grad_acc_kwargs = {}
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
elif "num_steps" not in grad_acc_kwargs:
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps

grad_acc_kwargs["sync_with_dataloader"] = False

gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
if "num_steps" in grad_acc_kwargs:
if self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`. "
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
else:
self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

accelerator_config = self.args.accelerator_config.to_dict()

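A hedged configuration sketch of how the two settings are expected to interact after this change, assuming the dict form of `accelerator_config` accepted by `TrainingArguments` and accelerate >= 0.28.0:

from transformers import TrainingArguments

# Option 1: set the accumulation step count directly on TrainingArguments.
args = TrainingArguments(output_dir="out", gradient_accumulation_steps=4)

# Option 2: pass `num_steps` through the accelerator config; with this change the
# value is copied onto `args.gradient_accumulation_steps` instead of being handed
# to a `GradientAccumulationPlugin`.
args_via_config = TrainingArguments(
    output_dir="out",
    accelerator_config={"gradient_accumulation_kwargs": {"num_steps": 4}},
)

# Setting `num_steps` while `gradient_accumulation_steps` > 1 still raises a
# ValueError, because the Trainer cannot tell which of the two values is intended.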
@@ -5001,7 +5001,6 @@ def create_accelerator_and_postprocess(self):

args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
}
if is_accelerate_available("0.28.0"):
args["dataloader_config"] = dataloader_config
32 changes: 32 additions & 0 deletions tests/trainer/test_trainer.py
@@ -76,6 +76,7 @@
require_non_xpu,
require_optuna,
require_peft,
require_python_311,
require_ray,
require_safetensors,
require_schedulefree,
@@ -1144,6 +1145,37 @@ def test_number_of_steps_in_training_with_ipex(self):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)

@require_python_311
def test_torch_compile_loss_func_compatibility(self):
from datasets import load_dataset

tiny_model = AutoModelForCausalLM.from_pretrained(
"/mnt/models/TinyLlama_v1.1",
num_labels=5,
)

tokenizer = AutoTokenizer.from_pretrained("/mnt/models/TinyLlama_v1.1")
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
tiny_model.resize_token_embeddings(len(tokenizer))
tiny_model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("yelp_review_full")["train"].select(range(100))

def tokenize_function(examples):
return tokenizer(examples["text"], max_length=20, padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
tmp_dir,
learning_rate=1e-9,
torch_compile=True,
num_train_epochs=1,
)
# with self.assertRaises(ValueError):
_ = Trainer(model=tiny_model, args=args, train_dataset=tokenized_datasets, tokenizer=tokenizer) # noqa

@require_peft
@require_bitsandbytes
def test_bnb_compile(self):