add optype_wise control (intel#1032)
mengniwang95 authored Jul 11, 2022
1 parent b3a409f commit 5032f34
Showing 3 changed files with 69 additions and 37 deletions.
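
The commit carries no description beyond its title, so here is a hedged sketch of what the new optype_wise control does before the diffs: it applies one quantization recipe to every node of a given op type, while op_wise keeps targeting individual node names. The snippet mirrors the test added to test/test_adaptor_onnxrt.py below; the import paths for conf and Quantization and the dataloader/model placeholders are assumptions, not part of this commit.

# Hedged usage sketch of the new 'optype_wise' knob (mirrors the new test below).
from neural_compressor import conf                        # assumed import path
from neural_compressor.experimental import Quantization   # assumed import path
from neural_compressor.utils.constant import FP32, INT8_SYM_MINMAX_PERTENSOR, \
    UINT8_ASYM_MINMAX_PERTENSOR

conf.model.framework = 'onnxrt_qlinearops'
conf.quantization.approach = 'post_training_static_quant'
conf.quantization.optype_wise = {'Add': FP32}              # every Add op stays fp32
conf.quantization.op_wise = {'add': {'weight': INT8_SYM_MINMAX_PERTENSOR,
                                     'activation': UINT8_ASYM_MINMAX_PERTENSOR}}

quantizer = Quantization(conf)
quantizer.calib_dataloader = calib_dataloader              # placeholder, user-supplied
quantizer.eval_dataloader = eval_dataloader                # placeholder, user-supplied
quantizer.model = fp32_model                               # placeholder, user-supplied
q_model = quantizer.fit()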
neural_compressor/adaptor/tensorflow.py (5 changes: 3 additions & 2 deletions)
@@ -467,11 +467,12 @@ def tuning_cfg_to_fw(self, tuning_cfg):
                 continue
 
             is_perchannel = False
-            weight_bit = 7.0
+            bit = None
             if 'weight' in tuning_cfg['op'][each_op_info]:
                 is_perchannel = tuning_cfg['op'][each_op_info]['weight'][
                     'granularity'] == 'per_channel'
-                weight_bit = tuning_cfg['op'][each_op_info]['weight']['bit']
+                bit = tuning_cfg['op'][each_op_info]['weight']['bit']
+            weight_bit = bit if bit else 7.0
 
             algorithm = tuning_cfg['op'][each_op_info]['activation']['algorithm']
 
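
In short, weight_bit used to be seeded with 7.0 and silently overwritten; it is now derived from an explicit bit value that stays None when the tuning config supplies no weight bit. A tiny illustrative sketch (the configs are made up):

# Illustrative only: the fallback now lives in one expression.
for op_cfg in [{'weight': {'bit': 6.0}}, {}]:
    bit = op_cfg.get('weight', {}).get('bit')
    weight_bit = bit if bit else 7.0
    print(weight_bit)   # prints 6.0 for the first config, 7.0 for the second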
neural_compressor/conf/config.py (45 changes: 24 additions & 21 deletions)
@@ -196,40 +196,40 @@ def percent_to_float(data):
 
 ops_schema = Schema({
     Optional('weight', default=None): {
-        Optional('granularity', default=None): And(
+        Optional('granularity'): And(
             list,
             lambda s: all(i in ['per_channel', 'per_tensor'] for i in s)),
-        Optional('scheme', default=None): And(
+        Optional('scheme'): And(
             list,
             # asym_float are only for PyTorch framework
             lambda s: all(i in ['asym', 'sym', 'asym_float'] for i in s)),
-        Optional('dtype', default=None): And(
+        Optional('dtype'): And(
             list,
             lambda s: all(i in ['int8', 'uint8', 'fp32', 'bf16', 'fp16'] for i in s)),
-        Optional('algorithm', default=None): And(
+        Optional('algorithm'): And(
             list,
             lambda s: all(i in ['minmax'] for i in s)),
-        Optional('bit', default=None): And(
+        Optional('bit'): And(
             Or(float, list),
             Use(input_to_list_float),
             lambda s: all(0.0 < i <= 7.0 for i in s))
     },
     Optional('activation', default=None): {
-        Optional('granularity', default=None): And(
+        Optional('granularity'): And(
             list,
             lambda s: all(i in ['per_channel', 'per_tensor'] for i in s)),
-        Optional('scheme', default=None): And(
+        Optional('scheme'): And(
             list,
             lambda s: all(i in ['asym', 'sym'] for i in s)),
-        Optional('dtype', default=None): And(
+        Optional('dtype'): And(
             list,
             lambda s: all(i in ['int8', 'uint8', 'fp32', 'bf16', 'fp16'] for i in s)),
         # compute_dtype is only for PyTorch framework
         Optional('compute_dtype', default=['uint8']): And(
             list,
             lambda s: all(i in ['int8', 'uint8', 'fp32', 'bf16', 'None'] for i in s)),
         # placeholder are only for PyTorch framework
-        Optional('algorithm', default=None): And(
+        Optional('algorithm'): And(
             list,
             lambda s: all(i in ['minmax', 'kl', 'placeholder'] for i in s))
     }
@@ -784,6 +784,9 @@ def percent_to_float(data):
                     lambda s: all(i in ['minmax', 'kl', 'placeholder'] for i in s)),
                 }
             },
+            Optional('optype_wise', default=None): {
+                str: ops_schema
+            },
             Optional('op_wise', default=None): {
                 str: ops_schema
             },
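
The new optype_wise section validates exactly like op_wise: a mapping from an op type name to the shared ops_schema shown above. Below is a simplified, self-contained sketch of that relationship; it reproduces only a subset of the real ops_schema fields and is not the library code.

# Simplified sketch of the new entry: {op type name -> ops_schema}, using the same
# `schema` package (Schema/Optional/And) that config.py imports.
from schema import Schema, Optional, And

ops_schema_sketch = Schema({
    Optional('weight'): {
        Optional('granularity'): And(list, lambda s: all(i in ['per_channel', 'per_tensor'] for i in s)),
        Optional('dtype'): And(list, lambda s: all(i in ['int8', 'uint8', 'fp32', 'bf16', 'fp16'] for i in s)),
    },
    Optional('activation'): {
        Optional('dtype'): And(list, lambda s: all(i in ['int8', 'uint8', 'fp32', 'bf16', 'fp16'] for i in s)),
    },
})
optype_wise_sketch = Schema({str: ops_schema_sketch})

# Per-op-type overrides: keep every Add in fp32, force per-channel MatMul weights.
optype_wise_sketch.validate({
    'Add': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}},
    'MatMul': {'weight': {'granularity': ['per_channel']}},
})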
@@ -1281,8 +1284,7 @@ def __init__(self, cfg=None):
     def _merge_dicts(self, src, dst):
         """Helper function to merge src dict into dst dict.
-           If the key in src doesn't exist in dst, then add this key and value
-           pair to dst.
+           If the key in src doesn't exist in dst, then skip
            If the key in src is in dst and the value intersects with the one in
            dst, then override the value in dst with the intersect value.
@@ -1304,9 +1306,6 @@ def _merge_dicts(self, src, dst):
                              if value in dst[key] or isinstance(value, float)]
                 if value != []:
                     dst[key] = value
-            else:
-                if not isinstance(src[key], dict):
-                    dst[key] = src[key]
 
         return dst
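
The docstring change above alters the merge contract, so here is a standalone sketch (not the library code) of the behaviour it now describes: keys missing from dst are skipped rather than copied, and list values are reduced to their intersection with what dst already allows.

def merge_dicts_sketch(src, dst):
    # Standalone illustration of the documented behaviour; not the library code.
    for key, value in src.items():
        if key not in dst:
            continue                              # changed: skip instead of adding
        if isinstance(value, dict) and isinstance(dst[key], dict):
            merge_dicts_sketch(value, dst[key])
        elif isinstance(value, list) and isinstance(dst[key], list):
            intersect = [v for v in value if v in dst[key]]
            if intersect:
                dst[key] = intersect              # override with the intersection
    return dst

print(merge_dicts_sketch({'scheme': ['sym', 'asym'], 'bit': [7.0]},
                         {'scheme': ['sym'], 'dtype': ['int8']}))
# -> {'scheme': ['sym'], 'dtype': ['int8']}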

@@ -1315,8 +1314,14 @@ def modelwise_tune_space(self, model_wise_quant):
 
         self._model_wise_tune_space = OrderedDict()
         for optype in model_wise_quant.keys():
-            self._model_wise_tune_space[optype] = self._merge_dicts(cfg.quantization.model_wise,
-                                                                    model_wise_quant[optype])
+            if cfg.quantization.optype_wise and optype in cfg.quantization.optype_wise:
+                self._model_wise_tune_space[optype] = self._merge_dicts(
+                    cfg.quantization.optype_wise[optype],
+                    model_wise_quant[optype])
+            else:
+                self._model_wise_tune_space[optype] = self._merge_dicts(
+                    cfg.quantization.model_wise,
+                    model_wise_quant[optype])
 
         return self._model_wise_tune_space
 

@@ -1367,19 +1372,17 @@ def _is_regex(pattern):
             return True
 
         opwise = copy.deepcopy(opwise_quant)
-        for k, v in opwise.items():
-            opwise[k] = self._merge_dicts(self._model_wise_tune_space[k[1]], opwise[k])
 
         cfg = self.usr_cfg
         if cfg.quantization.op_wise:
            for k, v in cfg.quantization.op_wise.items():
                 is_regex = _is_regex(k)
                 for k_op, _ in opwise.items():
-                    if not is_regex and k == k_op[0]:
+                    if (not is_regex and k == k_op[0]) or (is_regex and re.match(k, k_op[0])):
                         opwise[k_op] = self._merge_dicts(v, opwise[k_op])
 
-                    if is_regex and re.match(k, k_op[0]):
-                        opwise[k_op] = self._merge_dicts(v, opwise[k_op])
+        for k, v in opwise.items():
+            opwise[k] = self._merge_dicts(self._model_wise_tune_space[k[1]], opwise[k])
 
         self._opwise_tune_space = opwise
         return self._opwise_tune_space
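
The op_wise matching above now folds the exact-name and regex branches into one condition. A small sketch of that condition follows, using op names borrowed from the test model (illustrative only): a regex key such as 'add.*' reaches several nodes, while a plain key matches only itself.

import re

def op_wise_key_matches(key, op_name, is_regex):
    # Mirrors the merged condition: exact compare for plain keys, re.match for regex keys.
    return (not is_regex and key == op_name) or bool(is_regex and re.match(key, op_name))

for name in ['add', 'add2', 'Matmul']:
    print(name, op_wise_key_matches('add.*', name, is_regex=True))
# add True, add2 True, Matmul False
for name in ['add', 'add2', 'Matmul']:
    print(name, op_wise_key_matches('add', name, is_regex=False))
# add True, add2 False, Matmul False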
test/test_adaptor_onnxrt.py (56 changes: 42 additions & 14 deletions)
@@ -324,22 +324,23 @@ def build_ir3_model():
     return model
 
 def build_matmul_model():
-    A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 1, 5, 5])
-    B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 1, 5, 1])
-    C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 1, 5, 1])
-    D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 1, 5, 1])
-    H = helper.make_tensor_value_info('H', TensorProto.FLOAT, [1, 1, 5, 1])
+    A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 5, 5])
+    C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 5, 2])
+    D = helper.make_tensor_value_info('D', TensorProto.FLOAT, [1, 5, 2])
+    H = helper.make_tensor_value_info('H', TensorProto.FLOAT, [1, 5, 2])
 
+    e_value = np.random.randint(2, size=(10)).astype(np.float32)
+    B_init = helper.make_tensor('B', TensorProto.FLOAT, [5, 2], e_value.reshape(10).tolist())
+    E_init = helper.make_tensor('E', TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist())
+
     matmul_node = onnx.helper.make_node('MatMul', ['A', 'B'], ['C'], name='Matmul')
-    e_value = np.random.randint(2, size=(5)).astype(np.float32)
-    E_init = helper.make_tensor('E', TensorProto.FLOAT, [1, 1, 5, 1], e_value.reshape(5).tolist())
     add = onnx.helper.make_node('Add', ['C', 'E'], ['D'], name='add')
 
-    f_value = np.random.randint(2, size=(5)).astype(np.float32)
-    F_init = helper.make_tensor('F', TensorProto.FLOAT, [1, 1, 5, 1], e_value.reshape(5).tolist())
+    f_value = np.random.randint(2, size=(10)).astype(np.float32)
+    F_init = helper.make_tensor('F', TensorProto.FLOAT, [1, 5, 2], e_value.reshape(10).tolist())
     add2 = onnx.helper.make_node('Add', ['D', 'F'], ['H'], name='add2')
 
-    graph = helper.make_graph([matmul_node, add, add2], 'test_graph_1', [A, B], [H], [E_init, F_init])
+    graph = helper.make_graph([matmul_node, add, add2], 'test_graph_1', [A], [H], [B_init, E_init, F_init])
-    model = helper.make_model(graph)
+    model = helper.make_model(graph, **{'opset_imports': [helper.make_opsetid('', 13)]})
     return model
@@ -465,9 +466,8 @@ def __init__(self):
         self.data = []
         self.label = []
         for i in range(3):
-            self.data.append([np.random.randn(1,5,5).astype('float32'),
-                              np.random.randn(1,5,1).astype('float32')])
-            self.label.append(np.random.randn(1,5,1).astype('float32'))
+            self.data.append(np.random.randn(5,5).astype('float32'))
+            self.label.append(np.random.randn(5,1).astype('float32'))
 
     def __getitem__(self, idx):
         return self.data[idx], self.label[idx]
@@ -612,6 +612,34 @@ def test_set_tensor(self):
 
 
     def test_adaptor(self):
+        from neural_compressor.utils.constant import FP32, INT8_SYM_MINMAX_PERTENSOR, UINT8_ASYM_MINMAX_PERTENSOR
+        conf.model.framework = 'onnxrt_qlinearops'
+        conf.quantization.approach = 'post_training_static_quant'
+        conf.quantization.calibration.sampling_size = 1
+        conf.quantization.optype_wise = {'Add': FP32}
+        conf.quantization.op_wise = {'add': {'weight': INT8_SYM_MINMAX_PERTENSOR, 'activation': UINT8_ASYM_MINMAX_PERTENSOR}}
+        conf.evaluation.accuracy.metric = {'MSE': {'compare_label': False}}
+        quantizer = Quantization(conf)
+        quantizer.calib_dataloader = self.matmul_dataloader
+        quantizer.eval_dataloader = self.matmul_dataloader
+        quantizer.model = self.matmul_model
+        q_model = quantizer.fit()
+        self.assertTrue('add2' in [i.name for i in q_model.nodes()])
+        self.assertTrue('add_quant' in [i.name for i in q_model.nodes()])
+
+        conf.quantization.pop('op_wise')
+        conf.quantization.model_wise = {'weight': INT8_SYM_MINMAX_PERTENSOR}
+        conf.quantization.optype_wise = {'MatMul': {'weight': {'granularity': ['per_channel']}}}
+        quantizer = Quantization(conf)
+        quantizer.calib_dataloader = self.matmul_dataloader
+        quantizer.eval_dataloader = self.matmul_dataloader
+        quantizer.model = self.matmul_model
+        q_model = quantizer.fit()
+        self.assertEqual(len([i for i in q_model.initializer() if i.name == 'B_scale'][0].float_data), 2)
+
+        conf.quantization.pop('optype_wise')
+        conf.quantization.pop('model_wise')
+
         conf.model.framework = 'onnxrt_integerops'
         conf.quantization.approach = 'post_training_dynamic_quant'
         conf.quantization.calibration.sampling_size = 1
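
One note on the new assertion: with optype_wise forcing per-channel granularity for MatMul, the B initializer of shape [5, 2] is expected to receive one scale per output column, hence len(...float_data) == 2. An illustrative computation follows; the exact scale formula used by the quantizer may differ.

# Illustrative per-channel scale count for the [5, 2] MatMul initializer 'B'.
import numpy as np

B = np.random.randn(5, 2).astype(np.float32)
per_channel_scales = np.abs(B).max(axis=0) / 127.0   # one symmetric int8 scale per column
print(per_channel_scales.shape)                      # (2,)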
