train.py
# Adapted from Tevatron code
import logging
import sys

from transformers import AutoTokenizer, AutoProcessor
from transformers import LlavaNextProcessor
from transformers import (
    HfArgumentParser,
)

from src.dataset import TrainDataset
from src.collator import TrainCollator
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.model import MMEBModel
from src.trainer import MMEBTrainer, GradCacheTrainer

import wandb
import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)


def main():
    # A hack for torch.distributed.launch: https://github.com/huggingface/transformers/issues/22171
    # The launcher passes "--local-rank=N"; rewrite it as "--local_rank N" so HfArgumentParser accepts it.
    # Iterate over a copy so removing entries does not skip arguments.
    for arg in list(sys.argv):
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    # Initialize Weights & Biases logging only on the main process (or when not running distributed).
    if (not dist.is_initialized()) or dist.get_rank() == 0:
        wandb.init(project=training_args.project_name, name=training_args.run_name)

    # Build the processor for the chosen backbone; LLaVA-Next pads on the left, other backbones on the right.
    if model_args.model_backbone == "llava":
        processor = LlavaNextProcessor.from_pretrained(
            model_args.processor_name if model_args.processor_name else model_args.model_name,
            trust_remote_code=True,
        )
        processor.tokenizer.padding_side = "left"
    else:
        processor = AutoProcessor.from_pretrained(
            model_args.processor_name if model_args.processor_name else model_args.model_name,
            trust_remote_code=True,
            num_crops=model_args.num_crops,
        )
        processor.tokenizer.padding_side = "right"

    train_dataset = TrainDataset(data_args, model_args)
    collator = TrainCollator(data_args, model_args, processor)

    model = MMEBModel.build(model_args)

    # Use the gradient-caching trainer when grad_cache is enabled, otherwise the standard MMEB trainer.
    trainer_cls = GradCacheTrainer if training_args.grad_cache else MMEBTrainer
    trainer = trainer_cls(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=collator,
    )
    train_dataset.trainer = trainer

    trainer.train()
    trainer.save_model(training_args.output_dir)

    # Save the processor alongside the model weights, once, from the main process.
    if trainer.is_world_process_zero():
        processor.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()
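
# ---------------------------------------------------------------------------
# Illustrative launch command (a sketch, not part of the original file).
# The flags shown map to fields this script actually reads: model_name,
# processor_name, model_backbone, num_crops from ModelArguments, and
# project_name, run_name, grad_cache, output_dir from TrainingArguments.
# The concrete values, the backbone identifier, and any dataset-specific
# flags defined in DataArguments are assumptions for illustration only.
#
#   torchrun --nproc_per_node=8 train.py \
#       --model_name <hf-model-id> \
#       --model_backbone llava \
#       --grad_cache True \
#       --project_name mmeb \
#       --run_name mmeb-train \
#       --output_dir ./checkpoints/mmeb
# ---------------------------------------------------------------------------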