write_spconv2.py
# Copyright (c) OpenMMLab. All rights reserved.
import itertools

from mmcv.cnn.bricks.registry import CONV_LAYERS
from torch.nn.parameter import Parameter


def register_spconv2():
    """Register the sparse convolution ops of spconv 2.x to overwrite the
    default spconv 1.x ops in MMCV.

    Returns:
        bool: Whether spconv 2.x is available and has been registered.
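
    Example:
        >>> # A hypothetical usage sketch, assuming spconv 2.x is installed;
        >>> # ``build_conv_layer`` is MMCV's conv-layer builder, which reads
        >>> # from the ``CONV_LAYERS`` registry patched by this function.
        >>> from mmcv.cnn import build_conv_layer
        >>> if register_spconv2():
        ...     conv = build_conv_layer(
        ...         dict(type='SubMConv3d', indice_key='subm1'), 16, 32, 3)
    """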
try:
from spconv.pytorch import (SparseConv2d, SparseConv3d, SparseConv4d,
SparseConvTranspose2d,
SparseConvTranspose3d, SparseInverseConv2d,
SparseInverseConv3d, SparseModule,
SubMConv2d, SubMConv3d, SubMConv4d)
except ImportError:
return False
else:
CONV_LAYERS._register_module(SparseConv2d, 'SparseConv2d', force=True)
CONV_LAYERS._register_module(SparseConv3d, 'SparseConv3d', force=True)
CONV_LAYERS._register_module(SparseConv4d, 'SparseConv4d', force=True)
CONV_LAYERS._register_module(
SparseConvTranspose2d, 'SparseConvTranspose2d', force=True)
CONV_LAYERS._register_module(
SparseConvTranspose3d, 'SparseConvTranspose3d', force=True)
CONV_LAYERS._register_module(
SparseInverseConv2d, 'SparseInverseConv2d', force=True)
CONV_LAYERS._register_module(
SparseInverseConv3d, 'SparseInverseConv3d', force=True)
CONV_LAYERS._register_module(SubMConv2d, 'SubMConv2d', force=True)
CONV_LAYERS._register_module(SubMConv3d, 'SubMConv3d', force=True)
CONV_LAYERS._register_module(SubMConv4d, 'SubMConv4d', force=True)
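        # Patch SparseModule's (de)serialization so that checkpoints always
        # store kernel weights in the spconv 1.x layout used by MMCV; see
        # ``_save_to_state_dict`` and ``_load_from_state_dict`` below.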
SparseModule._load_from_state_dict = _load_from_state_dict
SparseModule._save_to_state_dict = _save_to_state_dict
return True


def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Rewrite this function to make the convolutional kernel weights
    compatible between spconv 1.x in MMCV and spconv 2.x.

    Kernel weights in MMCV spconv have shape (D, H, W, in_channels,
    out_channels), while those in spconv 2.x have shape (out_channels, D, H,
    W, in_channels). Weights are therefore permuted to the spconv 1.x layout
    before being written to the state dict.
    """
for name, param in self._parameters.items():
if param is not None:
param = param if keep_vars else param.detach()
if name == 'weight':
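                # Permute to the MMCV spconv 1.x layout, e.g. a hypothetical
                # kernel of shape (32, 3, 3, 3, 16), i.e. (out_channels, D,
                # H, W, in_channels), becomes (3, 3, 3, 16, 32).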
dims = list(range(1, len(param.shape))) + [0]
param = param.permute(*dims)
destination[prefix + name] = param
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()


def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    """Rewrite this function to make the convolutional kernel weights
    compatible between spconv 1.x in MMCV and spconv 2.x.

    Kernel weights in MMCV spconv have shape (D, H, W, in_channels,
    out_channels), while those in spconv 2.x have shape (out_channels, D, H,
    W, in_channels). Weights loaded from a checkpoint are therefore permuted
    back to the spconv 2.x layout.
    """
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs)
local_name_params = itertools.chain(self._parameters.items(),
self._buffers.items())
local_state = {k: v.data for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
# Backward compatibility: loading 1-dim tensor from
# 0.3.* to version 0.4+
if len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
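            # Permute from the checkpoint's spconv 1.x layout back to the
            # spconv 2.x layout, e.g. a hypothetical (3, 3, 3, 16, 32)
            # kernel becomes (32, 3, 3, 3, 16); for 1-D tensors such as
            # biases this permutation is a no-op.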
dims = [len(input_param.shape) - 1] + list(
range(len(input_param.shape) - 1))
input_param = input_param.permute(*dims)
            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append(
                    f'size mismatch for {key}: copying a param with '
                    f'shape {input_param.shape} from checkpoint, '
                    f'the shape in current model is {param.shape}.')
continue
if isinstance(input_param, Parameter):
# backwards compatibility for serialized parameters
input_param = input_param.data
try:
param.copy_(input_param)
            except Exception as e:
                error_msgs.append(
                    f'While copying the parameter named "{key}", whose '
                    f'dimensions in the model are {param.size()} and whose '
                    f'dimensions in the checkpoint are {input_param.size()}, '
                    f'an exception occurred: {e}.')
elif strict:
missing_keys.append(key)
if strict:
for key, input_param in state_dict.items():
if key.startswith(prefix):
input_name = key[len(prefix):]
input_name = input_name.split(
'.', 1)[0] # get the name of param/buffer/child
if input_name not in self._modules \
and input_name not in local_state:
unexpected_keys.append(key)
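

# A minimal self-check sketch: it verifies, on a hypothetical 3D kernel, that
# the save-side and load-side permutations above invert each other. It only
# assumes torch is available; no spconv installation is needed here.
if __name__ == '__main__':
    import torch

    # spconv 2.x stores kernels as (out_channels, D, H, W, in_channels).
    weight_v2 = torch.randn(32, 3, 3, 3, 16)

    # _save_to_state_dict moves the output channel to the last axis, giving
    # the MMCV spconv 1.x layout (D, H, W, in_channels, out_channels).
    save_dims = list(range(1, weight_v2.dim())) + [0]
    weight_v1 = weight_v2.permute(*save_dims)
    assert weight_v1.shape == (3, 3, 3, 16, 32)

    # _load_from_state_dict applies the inverse permutation on load.
    load_dims = [weight_v1.dim() - 1] + list(range(weight_v1.dim() - 1))
    assert weight_v1.permute(*load_dims).shape == weight_v2.shape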