Skip to content

Commit

Permalink
Update trainer for easier handling of accumulate, compile fixes, and …
Browse files Browse the repository at this point in the history
…proper reporting (#34511)

* Update trainer for easier handling of accumulate + proper reporting

* test

* Fixup tests

* Full fix

* Fix style

* rm comment

* Fix tests

* Minimize test + remove py 311 check

* Unused import

* Forward contrib credits from discussions

* Fix reported metrics

* Refactor, good as it's going to get

* rm pad tok id check

* object detection and audio are being annoying

* Fin

* Fin x2

---------

Co-authored-by: Gyanateet Dutta <[email protected]>
  • Loading branch information
2 people authored and ArthurZucker committed Nov 5, 2024
1 parent 5b36cda commit 8c62a92
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 48 deletions.
3 changes: 1 addition & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache, partial, wraps
from functools import partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from zipfile import is_zipfile
Expand Down Expand Up @@ -5014,7 +5014,6 @@ def _is_quantized_training_enabled(self):
return self.hf_quantizer.is_trainable

@property
@lru_cache
def loss_function(self):
if getattr(self.config, "loss_type", None) is not None:
loss_type = self.config.loss_type
Expand Down
73 changes: 39 additions & 34 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@
from accelerate.utils import (
DistributedDataParallelKwargs,
DistributedType,
GradientAccumulationPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
Expand Down Expand Up @@ -589,8 +588,10 @@ def __init__(
if not _is_peft_model(unwrapped_model)
else unwrapped_model.get_base_model().forward
)

self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters
forward_params = inspect.signature(model_forward).parameters
self.model_accepts_loss_kwargs = (
"loss_kwargs" in forward_params and forward_params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD
)

self.neftune_noise_alpha = args.neftune_noise_alpha

Expand Down Expand Up @@ -2424,7 +2425,7 @@ def _inner_training_loop(
update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
for inputs in batch_samples:
for i, inputs in enumerate(batch_samples):
step += 1
total_batched_samples += 1
is_last_step_and_steps_less_than_grad_acc = (
Expand Down Expand Up @@ -2470,7 +2471,13 @@ def _inner_training_loop(
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

with self.accelerator.accumulate(model):
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

if (
Expand Down Expand Up @@ -3602,15 +3609,11 @@ def training_step(
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
if num_items_in_batch is not None:
if self.compute_loss_func or self.model_accepts_loss_kwargs:
loss *= self.args.gradient_accumulation_steps
# Average tokens across devices is orthogonal to gradient accumulation
if self.args.average_tokens_across_devices:
loss *= self.args.world_size
self.accelerator.backward(loss, **kwargs)

return loss.detach() / self.args.gradient_accumulation_steps
# Finally we need to normalize the loss for reporting
if num_items_in_batch is None:
return loss.detach() / self.args.gradient_accumulation_steps
return loss.detach()

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
Expand All @@ -3622,9 +3625,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
labels = inputs.pop("labels")
else:
labels = None
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
Expand Down Expand Up @@ -3658,6 +3658,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes

return (loss, outputs) if return_outputs else loss

def is_local_process_zero(self) -> bool:
Expand Down Expand Up @@ -4902,24 +4905,21 @@ def _add_sm_patterns_to_gitignore(self) -> None:
self.repo.git_push()

def create_accelerator_and_postprocess(self):
# We explicitly don't rely on the `Accelerator` to do gradient accumulation
grad_acc_kwargs = {}
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
elif "num_steps" not in grad_acc_kwargs:
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps

grad_acc_kwargs["sync_with_dataloader"] = False

gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
if "num_steps" in grad_acc_kwargs:
if self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
else:
self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]

accelerator_config = self.args.accelerator_config.to_dict()

Expand Down Expand Up @@ -4950,7 +4950,6 @@ def create_accelerator_and_postprocess(self):

args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
}
if is_accelerate_available("0.28.0"):
args["dataloader_config"] = dataloader_config
Expand Down Expand Up @@ -5046,12 +5045,18 @@ def get_batch_samples(self, epoch_iterator, num_batches):
batch_samples += [next(epoch_iterator)]
except StopIteration:
break

# Keep default behavior the same
if not self.model_accepts_loss_kwargs:
return batch_samples, None

if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:
num_items_in_batch = sum(
[data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
)
except TypeError:
num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
except (TypeError, AttributeError):
pass

if self.args.average_tokens_across_devices:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
return batch_samples, num_items_in_batch
43 changes: 31 additions & 12 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,19 @@ def __getitem__(self, i):
return {"input_ids": self.x, "labels": self.x}


class SequenceClassificationDataset:
def __init__(self, length=64, vocab_size=100, num_labels=5):
self.length = length
self.sequences = [torch.randint(0, vocab_size, (64,)).tolist() for _ in range(length)]
self.labels = torch.randint(0, num_labels, (length,)).tolist()

def __len__(self):
return self.length

def __getitem__(self, i):
return {"input_ids": self.sequences[i], "label": self.labels[i]}


class DynamicShapesDataset:
def __init__(self, length=64, seed=42, batch_size=8):
self.length = length
Expand Down Expand Up @@ -1144,6 +1157,23 @@ def test_number_of_steps_in_training_with_ipex(self):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)

def test_torch_compile_loss_func_compatibility(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)

x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
tmp_dir,
per_device_train_batch_size=2,
torch_compile=True,
max_steps=1, # compile happens on the first step
)
trainer = Trainer(model=tiny_llama, args=args, train_dataset=train_dataset) # noqa
trainer.train()

@require_peft
@require_bitsandbytes
def test_bnb_compile(self):
Expand Down Expand Up @@ -3676,9 +3706,6 @@ def test_accelerator_config_from_dict(self):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)

if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
Expand All @@ -3691,8 +3718,6 @@ def test_accelerator_config_from_yaml(self):
"even_batches": False,
"use_seedable_sampler": False,
}
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
Expand All @@ -3706,9 +3731,6 @@ def test_accelerator_config_from_yaml(self):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)

if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)

def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
Expand Down Expand Up @@ -3754,10 +3776,7 @@ def test_accelerate_config_from_dataclass_grad_accum(self):
with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
self.assertEqual(trainer.args.gradient_accumulation_steps, 10)

def test_accelerator_config_from_partial(self):
# Checks that accelerator kwargs can be passed through
Expand Down

0 comments on commit 8c62a92

Please sign in to comment.