forked from open-mmlab/mmflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gma_8x2_120k_mixed_368x768.py
45 lines (41 loc) · 1.31 KB
/
gma_8x2_120k_mixed_368x768.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Base config files this config builds on: the GMA model definition, the
# mixed training dataset, and the default runtime settings.
_base_ = [
    '../_base_/models/gma/gma.py',
    # NOTE(review): 'fianl' looks like a typo for 'final', but the string has
    # to match the dataset config's filename on disk -- confirm before
    # renaming anything.
    '../_base_/datasets/sintel_cleanx100_sintel_fianlx100_kitti2015x200_hd1kx5_flyingthings3d_raft_384x768.py',  # noqa
    '../_base_/default_runtime.py',
]
# Model settings for this experiment (presumably merged over the base GMA
# model config -- verify against the base file).
model = {
    # Decoder configuration.
    'decoder': {
        'type': 'GMADecoder',
        'net_type': 'Basic',
        'num_levels': 4,
        'radius': 4,
        # The decoder uses iters=12 here, while ``test_cfg`` below sets
        # iters=32.
        'iters': 12,
        'corr_op_cfg': {'type': 'CorrLookup', 'align_corners': True},
        'gru_type': 'SeqConv',
        'heads': 1,
        'motion_channels': 128,
        'position_only': False,
        'flow_loss': {'type': 'SequenceLoss', 'gamma': 0.85},
        'act_cfg': {'type': 'ReLU'},
    },
    'freeze_bn': False,
    'test_cfg': {'iters': 32},
}
# Optimizer settings (AdamW).
optimizer = {
    'type': 'AdamW',
    'lr': 1.25e-4,
    'betas': (0.9, 0.999),
    'eps': 1e-8,
    'weight_decay': 1e-5,
    'amsgrad': False,
}
# Gradient clipping with a max norm of 1.0.
optimizer_config = {'grad_clip': {'max_norm': 1.0}}
# Learning-rate schedule: 'OneCycle' policy with linear annealing.
# NOTE(review): total_steps (120100) is 100 more than the runner's
# max_iters (120000) -- presumably intentional headroom; confirm.
lr_config = {
    'policy': 'OneCycle',
    'max_lr': 1.25e-4,
    'total_steps': 120100,
    'pct_start': 0.05,
    'anneal_strategy': 'linear',
}
# Iteration-based runner with a 120000-iteration budget.
runner = {
    'type': 'IterBasedRunner',
    'max_iters': 120000,
}
# Checkpoint and evaluation intervals are both 10000; checkpointing is
# configured with by_epoch=False.
checkpoint_config = {
    'by_epoch': False,
    'interval': 10000,
}
evaluation = {
    'interval': 10000,
    'metric': 'EPE',
}
# Train on FlyingChairs and FlyingThings3D, and finetune on
# Sintel, KITTI2015 and HD1K
# Checkpoint URL (presumably the initial weights for this run); the literal
# is split across adjacent strings only to keep the line short.
load_from = (
    'https://download.openmmlab.com/mmflow/gma/'
    'gma_8x2_120k_flyingthings3d_400x720.pth'
)