downstream_train.py
import os

# Restrict visible GPUs before torch / MinkowskiEngine initialize CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import argparse

import torch
import MinkowskiEngine as ME
from pytorch_lightning import Trainer

from trainer.downstream_trainer import SemanticKITTITrainer
# Provides get_dataset, get_data_loader, get_model, get_classifier_head and
# set_deterministic, among other helpers.
from utils import *
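
# Fine-tunes (or linearly evaluates) a pre-trained sparse 3D backbone on
# SemanticKITTI semantic segmentation: parse CLI arguments, build the data
# loaders, attach a classification head, and train with PyTorch Lightning.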
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='STSSL_fine_tune')
    parser.add_argument('--dataset-name', type=str, default='SemanticKITTI',
                        help='name of dataset (default: SemanticKITTI)')
    parser.add_argument('--data_dir', type=str, default='./Datasets/SemanticKITTI',
                        help='path to dataset (default: ./Datasets/SemanticKITTI)')
    parser.add_argument('--batch-size', type=int, default=2, metavar='N',
                        help='input training batch size (default: 2)')
    parser.add_argument('--epochs', type=int, default=15, metavar='N',
                        help='number of training epochs (default: 15)')
    parser.add_argument('--lr', type=float, default=1.0e-1,
                        help='learning rate (default: 1.0e-1)')
    parser.add_argument('--decay-lr', default=1e-4, action='store', type=float,
                        help='learning rate decay (default: 1e-4)')
    parser.add_argument('--tau', default=0.1, type=float,
                        help='tau temperature smoothing (default: 0.1)')
    parser.add_argument('--log_dir', type=str, default='checkpoint/downstream_task',
                        help='logging directory (default: checkpoint/downstream_task)')
    parser.add_argument('--load_path', type=str, default='None',
                        help='path to pre-trained model weights (default: None)')
    parser.add_argument('--checkpoint', type=str, default='classifier_checkpoint',
                        help='model checkpoint (default: classifier_checkpoint)')
    parser.add_argument('--use-cuda', action='store_true', default=False,
                        help='use CUDA (default: False)')
    parser.add_argument('--device-id', type=int, default=0,
                        help='GPU device id (default: 0)')
    parser.add_argument('--feature-size', type=int, default=128,
                        help='feature output size (default: 128)')
    parser.add_argument('--sparse_resolution', type=float, default=0.05,
                        help='sparse tensor resolution (default: 0.05)')
    parser.add_argument('--percentage_labels', type=float, default=1.0,
                        help='percentage of labels used for training (default: 1.0)')
    parser.add_argument('--num-points', type=int, default=80000,
                        help='number of points sampled from each point cloud (default: 80000)')
    parser.add_argument('--sparse-model', type=str, default='MinkUNet',
                        help='sparse model to be used (default: MinkUNet)')
    parser.add_argument('--linear-eval', action='store_true', default=False,
                        help='linear evaluation instead of fine-tuning (default: False)')
    parser.add_argument('--use-intensity', action='store_true', default=True,
                        help='use point intensity (default: True)')
    parser.add_argument('--contrastive', action='store_true', default=False,
                        help='use contrastive pre-trained weights (default: False)')
    parser.add_argument('--accum-steps', type=int, default=1,
                        help='number of steps to accumulate gradients (default: 1)')
    parser.add_argument('--num-workers', type=int, default=16,
                        help='number of data-loading workers (default: 16)')
    args = parser.parse_args()
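
    # Example invocation (single machine, CUDA enabled; <pretrained_weights>
    # is a placeholder for your checkpoint path):
    #   python downstream_train.py --use-cuda --batch-size 2 --epochs 15 \
    #       --load_path <pretrained_weights> --data_dir ./Datasets/SemanticKITTI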
    if args.use_cuda:
        dtype = torch.cuda.FloatTensor
        device = torch.device("cuda")
        print('GPU')
    else:
        dtype = torch.FloatTensor
        device = torch.device("cpu")
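
    # Seed RNGs for reproducible runs (helper assumed to come from utils).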
    set_deterministic()

    data_train, data_test = get_dataset(args, pre_training=False)
    train_loader = get_data_loader(data_train, args, pre_training=False)
    test_loader = get_data_loader(data_test, args, pre_training=False)
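
    # Label 0 in SemanticKITTI is the unlabeled/outlier class, so it is
    # excluded from the cross-entropy loss.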
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

    model = get_model(args, dtype)
    model_head = get_classifier_head(args, dtype)
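
    # With multiple GPUs, convert BatchNorm layers to MinkowskiSyncBatchNorm so
    # batch statistics are synchronized across DDP processes.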
    if torch.cuda.device_count() > 1:
        model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model)
        model_head = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model_head)
        model_sem_kitti = SemanticKITTITrainer(model, model_head, criterion, train_loader, test_loader, args)
        # check_val_every_n_epoch > max_epochs effectively skips validation during training.
        trainer = Trainer(gpus=-1, accelerator='ddp', check_val_every_n_epoch=args.epochs + 1,
                          max_epochs=args.epochs, accumulate_grad_batches=args.accum_steps)
        trainer.fit(model_sem_kitti)
    else:
        model_sem_kitti = SemanticKITTITrainer(model, model_head, criterion, train_loader, test_loader, args)
        trainer = Trainer(gpus=[0], check_val_every_n_epoch=args.epochs + 1,
                          max_epochs=args.epochs, accumulate_grad_batches=args.accum_steps)
        trainer.fit(model_sem_kitti)

    print("Finished downstream training")