-
Notifications
You must be signed in to change notification settings - Fork 18
/
main.py
119 lines (92 loc) · 3.49 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import rich.logging
import torch
import hydra
import warnings
import logging
from tracklab.utils import monkeypatch_hydra, \
progress # needed to avoid complex hydra stacktraces when errors occur in "instantiate(...)"
from hydra.utils import instantiate
from omegaconf import OmegaConf
from tracklab.datastruct import TrackerState
from tracklab.pipeline import Pipeline
from tracklab.utils import wandb
# Set before Hydra runs the task function; Hydra consults HYDRA_FULL_ERROR to
# decide whether to print complete stack traces (see the Hydra documentation).
os.environ["HYDRA_FULL_ERROR"] = "1"
# Module-level logger shared by every function in this file.
log = logging.getLogger(__name__)
# Installs a blanket "ignore" filter: warnings issued through the `warnings`
# module after this point are suppressed process-wide.
warnings.filterwarnings("ignore")
@hydra.main(version_base=None, config_path="pkg://tracklab.configs", config_name="config")
def main(cfg):
    """Build the configured modules, train them, then optionally track and evaluate.

    Args:
        cfg: Configuration object supplied by Hydra (composed from
            ``pkg://tracklab.configs`` with ``config`` as the root config).

    Returns:
        Always ``0``.
    """
    device = init_environment(cfg)

    # The dataset is created first because the evaluator and every module
    # receive it as ``tracking_dataset``.
    dataset = instantiate(cfg.dataset)
    evaluator = instantiate(cfg.eval, tracking_dataset=dataset)

    # One instantiated module per name listed in cfg.pipeline; an unset
    # pipeline yields an empty module list.
    modules = []
    if cfg.pipeline is not None:
        modules = [
            instantiate(cfg.modules[name], device=device, tracking_dataset=dataset)
            for name in cfg.pipeline
        ]
    pipeline = Pipeline(models=modules)

    # Train only the modules that ask for it.
    for stage in modules:
        if stage.training_enabled:
            stage.train()

    if cfg.test_tracking:
        log.info(f"Starting tracking operation on {cfg.dataset.eval_set} set.")

        # The tracker state wraps the evaluation split and is shared between
        # the engine (which runs tracking) and the evaluator (which reads it).
        eval_split = dataset.sets[cfg.dataset.eval_set]
        state = TrackerState(eval_split, pipeline=pipeline, **cfg.state)
        engine = instantiate(cfg.engine, modules=pipeline, tracker_state=state)

        # Run tracking and visualization
        engine.track_dataset()

        evaluate(cfg, evaluator, state)

        # Report where the state was saved, when a save file is configured.
        if state.save_file is not None:
            log.info(f"Saved state at : {state.save_file.resolve()}")

    close_enviroment()
    return 0
def set_sharing_strategy() -> None:
    """Select the ``"file_system"`` sharing strategy for ``torch.multiprocessing``."""
    strategy = "file_system"
    torch.multiprocessing.set_sharing_strategy(strategy)
def init_environment(cfg):
    """Prepare process-wide state (progress display, torch, wandb, logging).

    Args:
        cfg: Configuration object; ``use_rich`` and ``print_config`` are read
            here, and the whole object is handed to ``wandb.init``.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
    """
    # For Hydra and Slurm compatibility
    progress.use_rich = cfg.use_rich
    set_sharing_strategy()  # Do not touch

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    log.info(f"Using device: '{device}'.")

    wandb.init(cfg)
    if cfg.print_config:
        log.info(OmegaConf.to_yaml(cfg))

    root_logger = log.root
    rich_enabled = cfg.use_rich
    # With rich enabled, plain stream handlers only keep errors and a
    # RichHandler (added below) emits INFO and above; otherwise the stream
    # handlers are set to INFO.
    # TODO : Fix for mmcv fix. This should be done in a nicer way
    stream_level = logging.ERROR if rich_enabled else logging.INFO
    for handler in root_logger.handlers:
        # Exact type match on purpose: subclasses of StreamHandler (for
        # example logging.FileHandler) keep their current level.
        if type(handler) is logging.StreamHandler:
            handler.setLevel(stream_level)
    if rich_enabled:
        root_logger.addHandler(rich.logging.RichHandler(level=logging.INFO))

    return device
def close_enviroment() -> None:
    """Counterpart of ``init_environment``: call ``wandb.finish()``."""
    wandb.finish()
def evaluate(cfg, evaluator, tracker_state):
    """Run the evaluator on the tracker state unless evaluation must be skipped.

    Evaluation is skipped, with a warning naming the reason, when:

    * ``eval_tracking`` is present in ``cfg`` and falsy (it defaults to
      ``True`` when absent), or
    * ``cfg.dataset.nframes`` is not ``-1``, i.e. only part of each video was
      tracked.

    Args:
        cfg: Configuration object exposing ``get(key, default)`` and
            ``dataset.nframes``.
        evaluator: Object whose ``run(tracker_state)`` performs the evaluation.
        tracker_state: Tracking results, passed through to ``evaluator.run``.
    """
    # Truthiness test instead of ``== False``: a falsy non-bool value such as
    # None (``eval_tracking: null``) used to fall through to the "partial
    # video" warning below, reporting the wrong reason for the skip.
    if not cfg.get("eval_tracking", True):
        log.warning("Skipping evaluation because 'eval_tracking' was set to False.")
        return

    if cfg.dataset.nframes != -1:
        log.warning(
            "Skipping evaluation because only part of video was tracked (i.e. 'cfg.dataset.nframes' was not set "
            "to -1)"
        )
        return

    log.info("Starting evaluation.")
    evaluator.run(tracker_state)
# Script entry point: only runs when the file is executed directly, and hands
# control to the Hydra-decorated main().
if __name__ == "__main__":
    main()