-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtrain.py
178 lines (153 loc) · 5.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import gc
import os
from dataclasses import asdict, dataclass
from os import PathLike
from pathlib import Path
import mlflow
import torch
import torch.onnx
from azure.storage.blob import BlobServiceClient
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger
from architectures.fine_tune_clsify_head import TransformerModule
from config import TrainConfig
from data import LexGlueDataModule
def training_loop(
    config: TrainConfig,
) -> tuple[TransformerModule, LexGlueDataModule]:
    """Train the model, checkpoint the best-F1 weights and evaluate on the test set.

    The model and data module are built from `config`, training runs inside an MLflow
    run, and every config field whose name does not start with ``mlflow_`` is logged
    as a run parameter. After fitting, the final weights are evaluated on the test
    set, followed by the best checkpoint when one was written.

    Args:
        config: Training configuration (a dataclass instance, since it is passed to
            `dataclasses.asdict`).

    Returns:
        The `(model, datamodule)` pair used for training.
        NOTE(review): `trainer.test(ckpt_path=...)` presumably loads the checkpoint
        weights into `model`, so the returned module should hold the best-F1 weights
        whenever that evaluation ran - confirm.
    """
    model = TransformerModule(
        pretrained_model=config.pretrained_model,
        num_classes=config.num_classes,
        lr=config.lr,
    )
    datamodule = LexGlueDataModule(
        pretrained_model=config.pretrained_model,
        max_length=config.max_length,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        debug_mode_sample=config.debug_mode_sample,
    )

    # Wire up MLflow context manager to Azure ML.
    mlflow.set_experiment(config.mlflow_experiment_name)
    with mlflow.start_run(
        run_name=config.mlflow_run_name,
        description=config.mlflow_description,
    ) as run:
        # Connect Lightning's MLFlowLogger plugin to azureml-mlflow as defined in the
        # context manager. Passing `run_id` makes the logger reuse the active run
        # instead of patching its private `_run_id` attribute afterwards.
        # TODO: MLflow metrics should show epochs rather than steps on the x-axis
        mlf_logger = MLFlowLogger(
            experiment_name=mlflow.get_experiment(run.info.experiment_id).name,
            tracking_uri=mlflow.get_tracking_uri(),
            log_model=True,
            run_id=run.info.run_id,
        )

        # The `mlflow_*` fields describe the run itself and are not logged as
        # parameters.
        mlflow.log_params(
            {k: v for k, v in asdict(config).items() if not k.startswith("mlflow_")}
        )

        # Keep the model with the highest F1 score.
        checkpoint_callback = ModelCheckpoint(
            filename="{epoch}-{Val_F1_Score:.2f}",
            monitor="Val_F1_Score",
            mode="max",
            verbose=True,
            save_top_k=1,
        )

        # Run the training loop.
        trainer = Trainer(
            callbacks=[
                EarlyStopping(
                    monitor="Val_F1_Score",
                    min_delta=config.min_delta,
                    patience=config.patience,
                    verbose=True,
                    mode="max",
                ),
                checkpoint_callback,
            ],
            default_root_dir=config.model_checkpoint_dir,
            fast_dev_run=bool(config.debug_mode_sample),
            max_epochs=config.max_epochs,
            max_time=config.max_time,
            precision="bf16-mixed" if torch.cuda.is_available() else "32-true",
            logger=mlf_logger,
        )
        trainer.fit(model=model, datamodule=datamodule)
        best_model_path = checkpoint_callback.best_model_path

        # Evaluate the last and the best models on the test sample.
        trainer.test(model=model, datamodule=datamodule)
        # With `fast_dev_run` enabled no checkpoint is written and
        # `best_model_path` stays empty; skip the second evaluation rather than hand
        # Lightning an empty checkpoint path.
        if best_model_path:
            trainer.test(
                model=model,
                datamodule=datamodule,
                ckpt_path=best_model_path,
            )

    return model, datamodule
def convert_to_onnx(
    model: torch.nn.Module,
    save_path: PathLike | str,
    sequence_length: int,
    vocab_size: int,
) -> None:
    """Export `model` to an ONNX file at `save_path`.

    The model is switched to eval mode and handed to `torch.onnx.export` together
    with one dummy batch: random token ids, an all-ones attention mask and a zero
    label. The exported inputs are named ``input_ids``, ``attention_mask`` and
    ``label``.

    Args:
        model: Module to export. It is left in eval mode.
        save_path: Destination file; missing parent directories are created.
        sequence_length: Number of tokens in the dummy input sequence.
        vocab_size: Exclusive upper bound for the random dummy token ids.
    """
    model.eval()

    # Build the dummy batch on the device the model lives on so the export does not
    # fail with a device mismatch; parameterless modules fall back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    dummy_input_ids = torch.randint(
        0,
        vocab_size,
        (1, sequence_length),
        dtype=torch.long,
        device=device,
    )
    dummy_attention_mask = torch.ones(
        (1, sequence_length),
        dtype=torch.long,
        device=device,
    )
    dummy_label = torch.zeros(
        1,
        dtype=torch.long,
        device=device,
    )

    # Make sure the target directory exists before the exporter opens the file.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    torch.onnx.export(
        model=model,
        args=(dummy_input_ids, dummy_attention_mask, dummy_label),
        f=save_path,
        input_names=["input_ids", "attention_mask", "label"],
    )
def copy_dir_to_abs(localdir: str, abs_container: str) -> None:
    """Copy the contents of a local directory to an Azure Blob Storage container.

    Every regular file below `localdir` is uploaded. Each blob is named after the
    file's path relative to `localdir`, written with forward slashes so the
    directory structure is kept regardless of the local operating system.

    Args:
        localdir (str): The path to the local directory to copy.
        abs_container (str): The name of the Azure Blob Storage container.

    Raises:
        ValueError: If the `CONNECTION_STRING` environment variable is unset or
            empty.
    """
    # Fail early with a clear message instead of passing `None` to the Azure SDK.
    connection_string = os.getenv("CONNECTION_STRING")
    if not connection_string:
        raise ValueError(
            "The CONNECTION_STRING environment variable must be set to an Azure "
            "Storage connection string."
        )

    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    container = blob_service_client.get_container_client(abs_container)

    # Iterate through the local directory and upload each file while keeping the dir
    # structure.
    # NOTE(review): `upload_blob` is called without `overwrite`, so an already
    # existing blob presumably makes the upload fail - confirm this is intended.
    root = Path(localdir)
    for filepath in root.rglob("*"):
        if not filepath.is_file():
            continue
        # `as_posix()` keeps "/" as the separator in blob names, also on Windows.
        blob_name = filepath.relative_to(root).as_posix()
        with filepath.open("rb") as file_data:
            container.upload_blob(name=blob_name, data=file_data)

    print("Successfully copied directory to Azure Blob Storage.")
if __name__ == "__main__":
    # Release cached CUDA memory and run the garbage collector before starting, so
    # leftovers from earlier work do not eat into the available vRAM.
    torch.cuda.empty_cache()
    gc.collect()

    train_config = TrainConfig()

    # Fit the model. The data module is returned as well because its tokenizer
    # supplies the vocabulary size needed for the ONNX dummy input.
    trained_model, data_module = training_loop(train_config)

    # Export the trained model to ONNX inside the checkpoint directory; the file is
    # named after the MLflow run.
    onnx_filename = train_config.mlflow_run_name + "_model.onnx.pb"
    onnx_path = os.path.join(train_config.model_checkpoint_dir, onnx_filename)
    convert_to_onnx(
        model=trained_model,
        save_path=onnx_path,
        sequence_length=train_config.max_length,
        vocab_size=data_module.tokenizer.vocab_size,
    )

    # Upload the whole checkpoint directory to the "model-artifacts" blob container.
    copy_dir_to_abs(
        localdir=train_config.model_checkpoint_dir,
        abs_container="model-artifacts",
    )