# source: https://github.com/fastai/fastai2/blob/master/nbs/examples/train_imdbclassifier.py
# unique change: line 52 (off) to line 53 (on)
from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.distributed import *
from fastprogress import fastprogress
from fastai2.callback.mixup import *
from fastscript import *
from fastai2.text.all import *
torch.backends.cudnn.benchmark = True
fastprogress.MAX_COLS = 80

@call_parse
def main(
    gpu:    Param("GPU to run on", int)=None,
    lr:     Param("base Learning rate", float)=1e-2,
    bs:     Param("Batch size", int)=64,
    epochs: Param("Number of epochs", int)=5,
    fp16:   Param("Use mixed precision training", int)=0,
    dump:   Param("Print model; don't train", int)=0,
    runs:   Param("Number of times to repeat training", int)=1,
):
"Training of IMDB classifier."
if torch.cuda.is_available():
n_gpu = torch.cuda.device_count()
if gpu is None: gpu = list(range(n_gpu))[0]
torch.cuda.set_device(gpu)
else:
n_gpu = None
path = rank0_first(lambda:untar_data(URLs.IMDB))
dls = TextDataLoaders.from_folder(path, bs=bs, valid='test')
for run in range(runs):
print(f'Rank[{rank_distrib()}] Run: {run}; epochs: {epochs}; lr: {lr}; bs: {bs}')
learn = rank0_first(lambda: text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy))
if dump: print(learn.model); exit()
if fp16: learn = learn.to_fp16()
# TODO: DataParallel would hit floating point error, disabled for now.
# if gpu is None and n_gpu: ctx = partial(learn.parallel_ctx, device_ids=list(range(n_gpu)))
# Workaround: In PyTorch 1.4, need to set DistributedDataParallel() with find_unused_parameters=True,
# to avoid a crash that only happens in distributed mode of text_classifier_learner.fine_tune()
# if num_distrib() > 1 and torch.__version__.startswith("1.4"): DistributedTrainer.fup = True
DistributedTrainer.fup = True
with learn.distrib_ctx(cuda_id=gpu): # distributed traing requires "-m fastai2.launch"
print(f"Training in distributed data parallel context on GPU {gpu}", flush=True)
learn.fine_tune(epochs, lr)
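
# Minimal launch sketch (an assumption based on the fastai2-era tooling this script
# targets: the flag names are generated by @call_parse from the Param definitions
# above, and fastai2.launch is the distributed launcher the distrib_ctx comment
# refers to):
#
#   # multi-GPU, distributed data parallel
#   python -m fastai2.launch train_imdbclassifier-Copy1.py --epochs 5 --bs 64 --fp16 1
#
#   # single-GPU run, no launcher needed
#   python train_imdbclassifier-Copy1.py --gpu 0 --runs 1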