From ec49756350502dc270a99c4ad288014eb5cfaf64 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 09:41:14 -0400 Subject: [PATCH 01/16] Update trainer for easier handling of accumulate + proper reporting --- src/transformers/trainer.py | 44 +++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e2ae622e2b6bf3..40c3b1bfa1489a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2445,7 +2445,7 @@ def _inner_training_loop( update_step += 1 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) - for inputs in batch_samples: + for i, inputs in enumerate(batch_samples): step += 1 total_batched_samples += 1 is_last_step_and_steps_less_than_grad_acc = ( @@ -2491,7 +2491,9 @@ def _inner_training_loop( if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - with self.accelerator.accumulate(model): + # We explicitly want to avoid relying on `accelerator.accumulate` for generation training + context = partial(self.accelerator.no_sync, model=model) if i == len(batch_samples) - 1 else contextlib.nullcontext + with context(): tr_loss_step = self.training_step(model, inputs, num_items_in_batch) if ( @@ -3643,15 +3645,13 @@ def training_step( with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: - if num_items_in_batch is not None: - if self.compute_loss_func or self.model_accepts_loss_kwargs: - loss *= self.args.gradient_accumulation_steps - # Average tokens across devices is orthogonal to gradient accumulation - if self.args.average_tokens_across_devices: - loss *= self.args.world_size + # Average tokens across devices is orthogonal to gradient accumulation + if num_items_in_batch is not None and self.args.average_tokens_across_devices: + loss *= self.args.world_size self.accelerator.backward(loss, **kwargs) - - return loss.detach() / self.args.gradient_accumulation_steps + if num_items_in_batch is None: + return loss.detach() / self.args.gradient_accumulation_steps + return loss.detach() def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ @@ -4953,24 +4953,21 @@ def _add_sm_patterns_to_gitignore(self) -> None: self.repo.git_push() def create_accelerator_and_postprocess(self): + # We explicitly don't rely on the `Accelerator` to do gradient accumulation grad_acc_kwargs = {} if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None: grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs # check if num_steps is attempted to be passed in gradient_accumulation_kwargs - if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1: - # raise because we do not know which setting is intended. - raise ValueError( - "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" - "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." - ) - elif "num_steps" not in grad_acc_kwargs: - # take the gradient_accumulation_steps setting from TrainingArguments. 
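For reference, a minimal standalone sketch of the loss handling this first patch moves toward, assuming a plain PyTorch model rather than the Trainer/Accelerate stack: the per-token loss is summed and normalized once by the item count of the whole accumulated batch instead of being divided by gradient_accumulation_steps. The names below (micro_batches and the way num_items_in_batch is computed) are illustrative only.

import contextlib

import torch
import torch.nn as nn

model = nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Sum-reduced loss so the normalization can be done manually, as the patch intends.
loss_fct = nn.CrossEntropyLoss(reduction="sum")

micro_batches = [(torch.randn(2, 8), torch.randint(0, 4, (2,))) for _ in range(4)]
# Item count of the *whole* accumulated batch, computed up front.
num_items_in_batch = sum(labels.numel() for _, labels in micro_batches)

optimizer.zero_grad()
for i, (inputs, labels) in enumerate(micro_batches):
    # The hunk above replaces `accelerator.accumulate` with an explicit
    # no_sync/nullcontext choice; a bare nn.Module has no gradient sync to skip,
    # so this sketch uses nullcontext and focuses on the normalization.
    with contextlib.nullcontext():
        logits = model(inputs)
        # Divide by the full batch's item count, not by gradient_accumulation_steps.
        loss = loss_fct(logits, labels) / num_items_in_batch
        loss.backward()
optimizer.step()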
- grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps - - grad_acc_kwargs["sync_with_dataloader"] = False - - gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) + if "num_steps" in grad_acc_kwargs: + if self.args.gradient_accumulation_steps > 1: + # raise because we do not know which setting is intended. + raise ValueError( + "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" + "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." + ) + else: + self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"] accelerator_config = self.args.accelerator_config.to_dict() @@ -5001,7 +4998,6 @@ def create_accelerator_and_postprocess(self): args = { "deepspeed_plugin": self.args.deepspeed_plugin, - "gradient_accumulation_plugin": gradient_accumulation_plugin, } if is_accelerate_available("0.28.0"): args["dataloader_config"] = dataloader_config From 4cdee533ea001d8f0c0d8754c2ba08f41c264af6 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 10:13:11 -0400 Subject: [PATCH 02/16] test --- tests/trainer/test_trainer.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b6fe807fa4961a..7b7427a5ad702d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1144,6 +1144,38 @@ def test_number_of_steps_in_training_with_ipex(self): train_output = trainer.train() self.assertEqual(train_output.global_step, 10) + def test_torch_compile_loss_func_compatibility(self): + from datasets import load_dataset + tiny_model = AutoModelForCausalLM.from_pretrained( + "/mnt/models/TinyLlama_v1.1", num_labels=5, + ) + + tokenizer = AutoTokenizer.from_pretrained("/mnt/models/TinyLlama_v1.1") + tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token}) + tiny_model.resize_token_embeddings(len(tokenizer)) + tiny_model.config.pad_token_id = tokenizer.pad_token_id + + dataset = load_dataset("yelp_review_full")["train"].select(range(100)) + def tokenize_function(examples): + return tokenizer( + examples["text"], + max_length=20, + padding="max_length", + truncation=True + ) + tokenized_datasets = dataset.map(tokenize_function, batched=True) + + with tempfile.TemporaryDirectory() as tmp_dir: + args = TrainingArguments( + tmp_dir, + learning_rate=1e-9, + torch_compile=True, + num_train_epochs=1, + logging_steps=1, + ) + # with self.assertRaises(ValueError): + _ = Trainer(model=tiny_model, args=args, train_dataset=tokenized_datasets, tokenizer=tokenizer) # noqa + @require_peft @require_bitsandbytes def test_bnb_compile(self): From fb8070f7a5205ff9c11c56e002f3f5ab824525bd Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 10:26:44 -0400 Subject: [PATCH 03/16] Fixup tests --- src/transformers/modeling_utils.py | 3 +-- src/transformers/trainer.py | 7 +++++-- tests/trainer/test_trainer.py | 14 ++++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8481fa7df9cd96..2ef4c3615c9fa2 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -28,7 +28,7 @@ import warnings from contextlib import contextmanager from dataclasses import dataclass -from functools import lru_cache, partial, wraps +from functools import partial, wraps from threading import Thread 
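A hedged, standalone sketch of the precedence rule the create_accelerator_and_postprocess hunk implements: an accelerator-level num_steps is only adopted when gradient_accumulation_steps was left at its default of 1, otherwise the intent is ambiguous and an error is raised. The helper name and its arguments are illustrative, not Trainer API.

from typing import Optional


def resolve_grad_accum_steps(training_args_steps: int, accelerator_num_steps: Optional[int]) -> int:
    if accelerator_num_steps is None:
        return training_args_steps
    if training_args_steps > 1:
        # Both sources were set explicitly; refuse to guess which one is intended.
        raise ValueError(
            "`num_steps` is set in the accelerator config but `gradient_accumulation_steps` "
            "is also greater than 1; set only one of the two."
        )
    return accelerator_num_steps


assert resolve_grad_accum_steps(1, 8) == 8
assert resolve_grad_accum_steps(4, None) == 4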
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from zipfile import is_zipfile @@ -5014,7 +5014,6 @@ def _is_quantized_training_enabled(self): return self.hf_quantizer.is_trainable @property - @lru_cache def loss_function(self): if getattr(self.config, "loss_type", None) is not None: loss_type = self.config.loss_type diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 40c3b1bfa1489a..fadab199cc04d5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -233,7 +233,6 @@ from accelerate.utils import ( DistributedDataParallelKwargs, DistributedType, - GradientAccumulationPlugin, load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, @@ -2492,7 +2491,11 @@ def _inner_training_loop( self.control = self.callback_handler.on_step_begin(args, self.state, self.control) # We explicitly want to avoid relying on `accelerator.accumulate` for generation training - context = partial(self.accelerator.no_sync, model=model) if i == len(batch_samples) - 1 else contextlib.nullcontext + context = ( + functools.partial(self.accelerator.no_sync, model=model) + if i == len(batch_samples) - 1 + else contextlib.nullcontext + ) with context(): tr_loss_step = self.training_step(model, inputs, num_items_in_batch) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 7b7427a5ad702d..6eb99a8eba18c0 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1146,8 +1146,10 @@ def test_number_of_steps_in_training_with_ipex(self): def test_torch_compile_loss_func_compatibility(self): from datasets import load_dataset + tiny_model = AutoModelForCausalLM.from_pretrained( - "/mnt/models/TinyLlama_v1.1", num_labels=5, + "/mnt/models/TinyLlama_v1.1", + num_labels=5, ) tokenizer = AutoTokenizer.from_pretrained("/mnt/models/TinyLlama_v1.1") @@ -1156,13 +1158,10 @@ def test_torch_compile_loss_func_compatibility(self): tiny_model.config.pad_token_id = tokenizer.pad_token_id dataset = load_dataset("yelp_review_full")["train"].select(range(100)) + def tokenize_function(examples): - return tokenizer( - examples["text"], - max_length=20, - padding="max_length", - truncation=True - ) + return tokenizer(examples["text"], max_length=20, padding="max_length", truncation=True) + tokenized_datasets = dataset.map(tokenize_function, batched=True) with tempfile.TemporaryDirectory() as tmp_dir: @@ -1171,7 +1170,6 @@ def tokenize_function(examples): learning_rate=1e-9, torch_compile=True, num_train_epochs=1, - logging_steps=1, ) # with self.assertRaises(ValueError): _ = Trainer(model=tiny_model, args=args, train_dataset=tokenized_datasets, tokenizer=tokenizer) # noqa From 9c6ed74b5b88c97a1c149b98fac9aca447768053 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 10:32:38 -0400 Subject: [PATCH 04/16] Full fix --- src/transformers/testing_utils.py | 11 +++++++++++ tests/trainer/test_trainer.py | 2 ++ 2 files changed, 13 insertions(+) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 0eef286732d81c..ae44f8d5527edd 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -29,6 +29,7 @@ import tempfile import time import unittest +from packaging import version from collections import defaultdict from collections.abc import Mapping from dataclasses import MISSING, fields @@ -1291,6 +1292,16 @@ def require_jumanpp(test_case): """ return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case) +def 
require_python_311(test_case): + """ + Decorator marking a test that requires python 3.11 + """ + + current_version = version.parse(sys.version.split()[0]) + min_version = version.parse("3.11") + return unittest.skipUnless(current_version >= min_version, "test requires python v3.11+")(test_case) + + def require_cython(test_case): """ diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 6eb99a8eba18c0..55678088790927 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -76,6 +76,7 @@ require_non_xpu, require_optuna, require_peft, + require_python_311, require_ray, require_safetensors, require_schedulefree, @@ -1144,6 +1145,7 @@ def test_number_of_steps_in_training_with_ipex(self): train_output = trainer.train() self.assertEqual(train_output.global_step, 10) + @require_python_311 def test_torch_compile_loss_func_compatibility(self): from datasets import load_dataset From aab546701822e31b8e005f09fc5a106c54a2219b Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 10:37:53 -0400 Subject: [PATCH 05/16] Fix style --- src/transformers/testing_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index ae44f8d5527edd..fb564a3eee9678 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -29,7 +29,6 @@ import tempfile import time import unittest -from packaging import version from collections import defaultdict from collections.abc import Mapping from dataclasses import MISSING, fields @@ -41,6 +40,7 @@ from unittest.mock import patch import urllib3 +from packaging import version from transformers import logging as transformers_logging @@ -1292,6 +1292,7 @@ def require_jumanpp(test_case): """ return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case) + def require_python_311(test_case): """ Decorator marking a test that requires python 3.11 @@ -1302,7 +1303,6 @@ def require_python_311(test_case): return unittest.skipUnless(current_version >= min_version, "test requires python v3.11+")(test_case) - def require_cython(test_case): """ Decorator marking a test that requires jumanpp From c56ffe627e144037afaaba7145b95a3743f0cb12 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 10:38:52 -0400 Subject: [PATCH 06/16] rm comment --- tests/trainer/test_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 55678088790927..9819125fd1420b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1173,7 +1173,6 @@ def tokenize_function(examples): torch_compile=True, num_train_epochs=1, ) - # with self.assertRaises(ValueError): _ = Trainer(model=tiny_model, args=args, train_dataset=tokenized_datasets, tokenizer=tokenizer) # noqa @require_peft From 4a8a2a3a65a85c763c472f0f2148e0a80fd6a611 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 11:15:49 -0400 Subject: [PATCH 07/16] Fix tests --- tests/trainer/test_trainer.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 9819125fd1420b..390298876d417d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -3707,9 +3707,6 @@ def test_accelerator_config_from_dict(self): self.assertEqual(trainer.accelerator.even_batches, False) self.assertEqual(trainer.accelerator.use_seedable_sampler, True) 
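For reference, the version-gated skip decorator introduced in patch 04 (and dropped again once the test no longer needed it) can be exercised standalone roughly as follows; the ExampleTest class is illustrative only.

import sys
import unittest

from packaging import version


def require_python_311(test_case):
    """Skip the decorated test unless the interpreter is Python 3.11 or newer."""
    current_version = version.parse(sys.version.split()[0])
    return unittest.skipUnless(current_version >= version.parse("3.11"), "test requires python v3.11+")(test_case)


class ExampleTest(unittest.TestCase):
    @require_python_311
    def test_new_syntax_feature(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()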
- if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True) - def test_accelerator_config_from_yaml(self): # Checks that accelerator kwargs can be passed through # and the accelerator is initialized respectively @@ -3722,8 +3719,6 @@ def test_accelerator_config_from_yaml(self): "even_batches": False, "use_seedable_sampler": False, } - if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: - accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True} json.dump(accelerator_config, f) config = RegressionModelConfig(a=1.5, b=2.5) model = RegressionPreTrainedModel(config) @@ -3737,9 +3732,6 @@ def test_accelerator_config_from_yaml(self): self.assertEqual(trainer.accelerator.even_batches, False) self.assertEqual(trainer.accelerator.use_seedable_sampler, False) - if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True) - def test_accelerator_config_from_dataclass(self): # Checks that accelerator kwargs can be passed through # and the accelerator is initialized respectively @@ -3785,10 +3777,7 @@ def test_accelerate_config_from_dataclass_grad_accum(self): with tempfile.TemporaryDirectory() as tmp_dir: args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config) trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10) - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False) - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False) - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True) + self.assertEqual(trainer.args.gradient_accumulation_steps, 10) def test_accelerator_config_from_partial(self): # Checks that accelerator kwargs can be passed through From 2f41eb7930043d60798ca2d1b581d53392eec216 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 13:25:03 -0400 Subject: [PATCH 08/16] Minimize test + remove py 311 check --- src/transformers/testing_utils.py | 10 ------- tests/trainer/test_trainer.py | 43 +++++++++++++++---------------- 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index fb564a3eee9678..92d68ed5f85df6 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1293,16 +1293,6 @@ def require_jumanpp(test_case): return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case) -def require_python_311(test_case): - """ - Decorator marking a test that requires python 3.11 - """ - - current_version = version.parse(sys.version.split()[0]) - min_version = version.parse("3.11") - return unittest.skipUnless(current_version >= min_version, "test requires python v3.11+")(test_case) - - def require_cython(test_case): """ Decorator marking a test that requires jumanpp diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 390298876d417d..5658372fa71308 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -76,7 +76,6 @@ require_non_xpu, require_optuna, require_peft, - require_python_311, require_ray, require_safetensors, require_schedulefree, @@ -273,6 +272,19 @@ def __getitem__(self, i): return {"input_ids": self.x, "labels": self.x} +class SequenceClassificationDataset: + def __init__(self, 
length=64, vocab_size=100, num_labels=5): + self.length = length + self.sequences = [torch.randint(0, vocab_size, (64,)).tolist() for _ in range(length)] + self.labels = torch.randint(0, num_labels, (length,)).tolist() + + def __len__(self): + return self.length + + def __getitem__(self, i): + return {"input_ids": self.sequences[i], "label": self.labels[i]} + + class DynamicShapesDataset: def __init__(self, length=64, seed=42, batch_size=8): self.length = length @@ -1145,35 +1157,22 @@ def test_number_of_steps_in_training_with_ipex(self): train_output = trainer.train() self.assertEqual(train_output.global_step, 10) - @require_python_311 def test_torch_compile_loss_func_compatibility(self): - from datasets import load_dataset - - tiny_model = AutoModelForCausalLM.from_pretrained( - "/mnt/models/TinyLlama_v1.1", - num_labels=5, - ) - - tokenizer = AutoTokenizer.from_pretrained("/mnt/models/TinyLlama_v1.1") - tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token}) - tiny_model.resize_token_embeddings(len(tokenizer)) - tiny_model.config.pad_token_id = tokenizer.pad_token_id - - dataset = load_dataset("yelp_review_full")["train"].select(range(100)) - - def tokenize_function(examples): - return tokenizer(examples["text"], max_length=20, padding="max_length", truncation=True) + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) - tokenized_datasets = dataset.map(tokenize_function, batched=True) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) with tempfile.TemporaryDirectory() as tmp_dir: args = TrainingArguments( tmp_dir, - learning_rate=1e-9, + per_device_train_batch_size=2, torch_compile=True, - num_train_epochs=1, + max_steps=1, # compile happens on the first step ) - _ = Trainer(model=tiny_model, args=args, train_dataset=tokenized_datasets, tokenizer=tokenizer) # noqa + trainer = Trainer(model=tiny_llama, args=args, train_dataset=train_dataset) # noqa + trainer.train() @require_peft @require_bitsandbytes From 27eaadc040d8c74d11619c352d908b6d8e0e2513 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 13:26:39 -0400 Subject: [PATCH 09/16] Unused import --- src/transformers/testing_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 92d68ed5f85df6..0eef286732d81c 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -40,7 +40,6 @@ from unittest.mock import patch import urllib3 -from packaging import version from transformers import logging as transformers_logging From 6fcb0b597d3acaf399e33dcfa301be5720679533 Mon Sep 17 00:00:00 2001 From: Gyanateet Dutta Date: Wed, 30 Oct 2024 14:23:12 -0400 Subject: [PATCH 10/16] Forward contrib credits from discussions From 43f6a2f4e0380140f03d801413507528832241c4 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 20:15:37 -0400 Subject: [PATCH 11/16] Fix reported metrics --- src/transformers/trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index fadab199cc04d5..2a93294879bdf5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3652,9 +3652,11 @@ def training_step( if num_items_in_batch is not None and self.args.average_tokens_across_devices: loss *= self.args.world_size self.accelerator.backward(loss, **kwargs) - if num_items_in_batch is None: - return loss.detach() / 
self.args.gradient_accumulation_steps - return loss.detach() + # Finally we need to normalize the loss for reporting + loss = loss.detach() / self.args.gradient_accumulation_steps + if self.args.average_tokens_across_devices: + loss /= self.args.world_size + return loss def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ From 238c9859a37de5f20bffb6638b62f63c926a9cf3 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 21:09:57 -0400 Subject: [PATCH 12/16] Refactor, good as it's going to get --- src/transformers/trainer.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2a93294879bdf5..33e8eab64ee61b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -600,8 +600,8 @@ def __init__( if not _is_peft_model(unwrapped_model) else unwrapped_model.get_base_model().forward ) - - self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters + params = inspect.signature(model_forward).parameters + self.model_accepts_loss_kwargs = "loss_kwargs" in params or "kwargs" in params self.neftune_noise_alpha = args.neftune_noise_alpha @@ -3648,15 +3648,11 @@ def training_step( with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: - # Average tokens across devices is orthogonal to gradient accumulation - if num_items_in_batch is not None and self.args.average_tokens_across_devices: - loss *= self.args.world_size self.accelerator.backward(loss, **kwargs) # Finally we need to normalize the loss for reporting - loss = loss.detach() / self.args.gradient_accumulation_steps - if self.args.average_tokens_across_devices: - loss /= self.args.world_size - return loss + if num_items_in_batch is None: + return loss.detach() / self.args.gradient_accumulation_steps + return loss.detach() def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ @@ -3668,13 +3664,12 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N labels = inputs.pop("labels") else: labels = None - if self.args.average_tokens_across_devices and num_items_in_batch is not None: - num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device) - num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu()) if self.model_accepts_loss_kwargs: loss_kwargs = {} if num_items_in_batch is not None: loss_kwargs["num_items_in_batch"] = num_items_in_batch + if self.processing_class is not None: + loss_kwargs["ignore_index"] = self.processing_class.pad_token_id inputs = {**inputs, **loss_kwargs} outputs = model(**inputs) # Save past state if it exists @@ -3704,6 +3699,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # We don't use .loss here since the model may return tuples instead of ModelOutput. 
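A small sketch of the reporting rule the training_step hunks converge on across patches 11 and 12: the logged value is only divided by gradient_accumulation_steps when the loss was not already normalized by the accumulated batch's item count. reported_loss is an illustrative helper, not a Trainer method.

import torch


def reported_loss(loss, gradient_accumulation_steps, num_items_in_batch=None):
    if num_items_in_batch is None:
        # Mean-reduced loss: each micro-batch contributes 1/steps to the update.
        return loss.detach() / gradient_accumulation_steps
    # Already normalized over the whole accumulated batch; report it as-is.
    return loss.detach()


print(reported_loss(torch.tensor(2.0), gradient_accumulation_steps=4))                          # tensor(0.5000)
print(reported_loss(torch.tensor(2.0), gradient_accumulation_steps=4, num_items_in_batch=128))  # tensor(2.)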
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + loss *= self.accelerator.num_processes + return (loss, outputs) if return_outputs else loss def is_local_process_zero(self) -> bool: @@ -5101,9 +5099,10 @@ def get_batch_samples(self, epoch_iterator, num_batches): if len(batch_samples) > 0 and "labels" in batch_samples[0]: # For now we don't support object detection try: - num_items_in_batch = sum( - [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples] - ) + num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) except TypeError: pass + + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() return batch_samples, num_items_in_batch From 93f36e8e9160e239c015f3e9043d1661654199e1 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 30 Oct 2024 21:17:48 -0400 Subject: [PATCH 13/16] rm pad tok id check --- src/transformers/trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 33e8eab64ee61b..b0f752b62ea827 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3668,8 +3668,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss_kwargs = {} if num_items_in_batch is not None: loss_kwargs["num_items_in_batch"] = num_items_in_batch - if self.processing_class is not None: - loss_kwargs["ignore_index"] = self.processing_class.pad_token_id inputs = {**inputs, **loss_kwargs} outputs = model(**inputs) # Save past state if it exists From c6332197716f6df9118e9ebce5a3f30658eca9d9 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 31 Oct 2024 09:41:14 -0400 Subject: [PATCH 14/16] object detection and audio are being annoying --- src/transformers/trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index b0f752b62ea827..fc8fd87af10962 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -600,8 +600,7 @@ def __init__( if not _is_peft_model(unwrapped_model) else unwrapped_model.get_base_model().forward ) - params = inspect.signature(model_forward).parameters - self.model_accepts_loss_kwargs = "loss_kwargs" in params or "kwargs" in params + self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters self.neftune_noise_alpha = args.neftune_noise_alpha @@ -5098,7 +5097,7 @@ def get_batch_samples(self, epoch_iterator, num_batches): # For now we don't support object detection try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) - except TypeError: + except (TypeError, AttributeError): pass if self.args.average_tokens_across_devices: From c452194990d3af6bad7dbd599493b5df90d66cb5 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 31 Oct 2024 09:44:30 -0400 Subject: [PATCH 15/16] Fin --- src/transformers/trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index fc8fd87af10962..3f092b89b8f61d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -5093,6 +5093,11 @@ def get_batch_samples(self, epoch_iterator, num_batches): 
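The token-counting half of get_batch_samples can be reproduced in isolation as below: the item count is the number of label positions not equal to the ignore index (-100), summed over the prefetched micro-batches, while the try/except leaves tasks whose labels are not plain tensors (object detection, some audio collators) on the old path. The tensors here are made up for illustration.

import torch

batch_samples = [
    {"labels": torch.tensor([[5, 7, -100, -100], [3, -100, -100, -100]])},
    {"labels": torch.tensor([[2, 4, 6, -100], [9, 1, -100, -100]])},
]

num_items_in_batch = None
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
    try:
        num_items_in_batch = sum(batch["labels"].ne(-100).sum() for batch in batch_samples)
    except (TypeError, AttributeError):
        pass

print(num_items_in_batch)  # tensor(8): 3 + 5 non-ignored label positions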
batch_samples += [next(epoch_iterator)] except StopIteration: break + + # Keep default behavior the same + if not self.model_accepts_loss_kwargs: + return batch_samples, None + if len(batch_samples) > 0 and "labels" in batch_samples[0]: # For now we don't support object detection try: From 7739461098f802ee1d03b754c295acaf494c8623 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 31 Oct 2024 09:50:37 -0400 Subject: [PATCH 16/16] Fin x2 --- src/transformers/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3f092b89b8f61d..1dab61d8fd91cf 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -600,7 +600,10 @@ def __init__( if not _is_peft_model(unwrapped_model) else unwrapped_model.get_base_model().forward ) - self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters + forward_params = inspect.signature(model_forward).parameters + self.model_accepts_loss_kwargs = ( + "loss_kwargs" in forward_params and forward_params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD + ) self.neftune_noise_alpha = args.neftune_noise_alpha
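Finally, the stricter capability check from the last patch can be illustrated on its own: a forward signature only counts as accepting loss kwargs if it exposes a **loss_kwargs var-keyword parameter, not merely any parameter with that name. The two example forwards below are hypothetical.

import inspect


def accepts_loss_kwargs(forward) -> bool:
    params = inspect.signature(forward).parameters
    return "loss_kwargs" in params and params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD


def forward_with_var_kwargs(input_ids, labels=None, **loss_kwargs):
    pass


def forward_with_plain_arg(input_ids, labels=None, loss_kwargs=None):
    pass


print(accepts_loss_kwargs(forward_with_var_kwargs))  # True
print(accepts_loss_kwargs(forward_with_plain_arg))   # False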