TrainerState's property num_input_tokens_seen is not updating #34567

Closed
SwayamInSync opened this issue Nov 1, 2024 · 5 comments · Fixed by #34593
@SwayamInSync

System Info

- `transformers` version: 4.46.0
- Python version: 3.10.15
- Huggingface_hub version: 0.26.1
- Safetensors version: 0.4.5
- Accelerate version: 1.0.1
- Accelerate config:    not found
- PyTorch version (GPU?): 2.5.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA A100 80GB PCIe

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Here is sample code to reproduce the issue:

from transformers import TrainerCallback, TrainingArguments, Trainer
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import Dataset
import torch

# Simple callback to monitor tokens
class TokenMonitorCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 10 == 0:  # Print every 10 steps
            print(f"Step {state.global_step}, Tokens seen: {state.num_input_tokens_seen}")

    def on_epoch_end(self, args, state, control, **kwargs):
        print(f"Epoch end - Total tokens processed: {state.num_input_tokens_seen}")

# Create a tiny dataset
texts = ["Hello world", "This is a test", "Another example"] * 10
dataset = Dataset.from_dict({"text": texts})

# Initialize model and tokenizer
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

# Tokenization function
def tokenize_function(examples):
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=32,
        return_tensors="pt"
    )
    # Create labels by shifting input_ids
    tokenized["labels"] = tokenized["input_ids"].clone()
    return tokenized

# Tokenize dataset
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)

# Training arguments
training_args = TrainingArguments(
    output_dir="./test-trainer",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    logging_steps=10,
    save_steps=1000,
    learning_rate=2e-5,
    report_to="none"
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    callbacks=[TokenMonitorCallback()]
)

# Start training
trainer.train()

The following is the output:

Epoch end - Total tokens processed: 0
Step 10, Tokens seen: 0
Epoch end - Total tokens processed: 0
TrainOutput(global_step=16, training_loss=5.371496677398682, metrics={'train_runtime': 56.2378, 'train_samples_per_second': 1.067, 'train_steps_per_second': 0.285, 'total_flos': 489931407360.0, 'train_loss': 5.371496677398682, 'epoch': 2.0})

Expected behavior

The expected behavior is that this property keeps updating within the training loop with the number of input tokens seen at every step.

@techkang
Contributor

techkang commented Nov 4, 2024

You need to set include_num_input_tokens_seen=True in training args.

if self.args.include_num_input_tokens_seen:
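
For example (a minimal sketch reusing the arguments from the reproduction above; the flag itself is the only change):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./test-trainer",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    logging_steps=10,
    include_num_input_tokens_seen=True,  # off by default, so num_input_tokens_seen stays 0
    report_to="none"
)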

@SwayamInSync
Author

Hey, thanks @techkang.
Yeah, I checked this, and I think it would be nice to add a line in the docs for the TrainerCallback section noting that include_num_input_tokens_seen needs to be set to True.

@SwayamInSync
Author

SwayamInSync commented Nov 4, 2024

I think it expects the input_tokens to be from a single batch:

 0/3924 [00:00<?, ?it/s]Traceback (most recent call last):
  File "main.py", line 187, in <module>
    main()
  File "main.py", line 175, in main
    trainer.train(resume_from_checkpoint=last_checkpoint)
  File ".venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 434, in train
    output = super().train(*args, **kwargs)
  File ".venv/lib/python3.10/site-packages/transformers/trainer.py", line 2122, in train
    return inner_training_loop(
  File ".venv/lib/python3.10/site-packages/transformers/trainer.py", line 2453, in _inner_training_loop
    self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item()
RuntimeError: a Tensor with 4 elements cannot be converted to Scalar
[rank0]: Traceback (most recent call last):
[rank0]:   File "main.py", line 187, in <module>
[rank0]:     main()
[rank0]:   File "main.py", line 175, in main
[rank0]:     trainer.train(resume_from_checkpoint=last_checkpoint)
[rank0]:   File ".venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 434, in train
[rank0]:     output = super().train(*args, **kwargs)
[rank0]:   File ".venv/lib/python3.10/site-packages/transformers/trainer.py", line 2122, in train
[rank0]:     return inner_training_loop(
[rank0]:   File ".venv/lib/python3.10/site-packages/transformers/trainer.py", line 2453, in _inner_training_loop
[rank0]:     self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item()
[rank0]: RuntimeError: a Tensor with 4 elements cannot be converted to Scalar

The error occurs because accelerator.gather() collects tensors from all processes and concatenates them, resulting in a tensor with multiple elements (4 elements from 4 GPUs), but .item() can only convert a single-element tensor to a scalar.

Calling sum() before item() should work:

self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().cpu().item()
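
For illustration, here is a standalone sketch (with made-up per-rank token counts, not the actual Trainer internals) of why the reduction is needed:

import torch

# Hypothetical token counts gathered from 4 processes, one element per GPU
gathered = torch.tensor([128, 128, 128, 128])

# gathered.item() would raise:
# RuntimeError: a Tensor with 4 elements cannot be converted to Scalar
total = gathered.sum().cpu().item()  # reduce to a scalar first -> 512
print(total)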

@SwayamInSync reopened this Nov 4, 2024
@techkang
Contributor

techkang commented Nov 4, 2024

You are correct and a PR is already proposed to fix this problem: #34554.

@SwayamInSync
Author

Awesome, can close this now
