Commit b9d9a48
add ut and fix bug
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 committed Jul 11, 2023
1 parent 0e5b843 commit b9d9a48
Showing 5 changed files with 207 additions and 191 deletions.
178 changes: 17 additions & 161 deletions neural_compressor/adaptor/onnxrt.py
@@ -1442,18 +1442,12 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):

quant_config = self._cfg_to_quantize_config(tune_cfg)
algos = set([item["weight"]["algorithm"] for key, item in quant_config.items() if isinstance(item, dict)])
#if "GPTQ" in algos:
# model = gptq_quantize(model, quant_config, data_loader)
#if "AWQ" in algos:
# model = awq_quantize(model, quant_config, data_loader)
#elif "RTN" in algos:
# model = rtn_quantize(model, quant_config)
model = awq_quantize(model, quant_config, data_loader)
#model = rtn_quantize(model, quant_config)

#model = rtn_quantize(model, quant_config)
#model.q_config = copy.deepcopy(self.tune_cfg)
#self._dump_model_op_stats(model, self.tune_cfg)
if "AWQ" in algos:
model = awq_quantize(model, quant_config, data_loader)
elif "RTN" in algos:
model = rtn_quantize(model, quant_config)
model.q_config = copy.deepcopy(quant_config)
self._dump_model_op_stats(model, tune_cfg)
model.topological_sort()
return model
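
For orientation, a minimal sketch (standalone, hypothetical helper name, not the adaptor's API) of the dispatch the new lines implement: the per-op quant_config produced by _cfg_to_quantize_config carries a weight-only algorithm per node, and AWQ takes precedence over RTN.

    # Sketch only; the names below are illustrative.
    def pick_algorithm(quant_config):
        """Return 'AWQ', 'RTN', or None from the per-op weight configs."""
        algos = {
            item["weight"]["algorithm"]
            for item in quant_config.values()
            if isinstance(item, dict)
        }
        if "AWQ" in algos:
            return "AWQ"
        if "RTN" in algos:
            return "RTN"
        return None

    # Example: one MatMul configured for AWQ routes the whole model through awq_quantize.
    cfg = {"/layer.0/MatMul": {"weight": {"algorithm": "AWQ", "bits": 4, "group_size": 32}}}
    assert pick_algorithm(cfg) == "AWQ"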

@@ -1464,7 +1458,7 @@ def _dump_model_op_stats(self, model, tune_cfg):
for op, config in tune_cfg['op'].items():
op_type = op[1]
if not config['weight']['dtype'] == 'fp32':
num_bits = config['weight']['bit']
num_bits = config['weight']['bits']
group_size = config['weight']['group_size']
dtype_str = "A32W{}G{}".format(num_bits, group_size)
dtype_set.add(dtype_str)
@@ -1481,7 +1475,7 @@ def _dump_model_op_stats(self, model, tune_cfg):
if config['weight']['dtype'] == 'fp32':
res[op_type]['FP32'] += 1
else:
num_bits = config['weight']['bit']
num_bits = config['weight']['bits']
group_size = config['weight']['group_size']
dtype_str = "A32W{}G{}".format(num_bits, group_size)
res[op_type][dtype_str] += 1
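
As a side note on the corrected key ('bits' rather than 'bit'), the statistics label has the form A32W<bits>G<group_size>; a minimal sketch assuming the same per-op config shape:

    # Sketch only: activations stay FP32, the label encodes weight bits and group size.
    def dtype_label(weight_cfg):
        if weight_cfg["dtype"] == "fp32":
            return "FP32"
        return "A32W{}G{}".format(weight_cfg["bits"], weight_cfg["group_size"])

    assert dtype_label({"dtype": "int", "bits": 4, "group_size": 32}) == "A32W4G32"
    assert dtype_label({"dtype": "fp32"}) == "FP32"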
@@ -1528,20 +1522,7 @@ def query_fw_capability(self, model):
"""
# optype_wise and op_wise capability
self._pre_optimize(model)
recipes_ops = {}
recipes_ops['first_conv_or_matmul_quantization'] = []
recipes_ops['last_conv_or_matmul_quantization'] = []
recipes_ops['pre_post_process_quantization'] = []
exclude_first_quantizable_op = True if 'first_conv_or_matmul_quantization' in \
self.recipes and not self.recipes['first_conv_or_matmul_quantization'] \
else False
exclude_last_quantizable_op = True if 'last_conv_or_matmul_quantization' in \
self.recipes and not self.recipes['last_conv_or_matmul_quantization'] \
else False
exclude_pre_post_process = True if 'pre_post_process_quantization' in \
self.recipes and not self.recipes['pre_post_process_quantization'] \
else False


quantizable_optype = set([i.op_type for i in self.pre_optimized_model.nodes()])
optype_wise = OrderedDict()
op_wise = OrderedDict()
@@ -1578,140 +1559,15 @@ def query_fw_capability(self, model):
elif op_capability not in optype_wise[op]:
optype_wise[op].append(op_capability)

first_quantizable_node = []
last_quantizable_node = []
all_conv_matmul = []
attention_matmul = []
for _, node in enumerate(self.pre_optimized_model.nodes()):
if node.op_type in ['Conv', 'MatMul', 'Attention']:
# get first Conv or MatMul node
if len(first_quantizable_node) == 0:
first_quantizable_node.append(node)

# get last Conv or MatMul node
if len(last_quantizable_node) != 0:
last_quantizable_node.pop()
last_quantizable_node.append(node)

all_conv_matmul.append(node)
if node.op_type != 'Conv':
attention_matmul.append(node)

if len(first_quantizable_node) != 0:
recipes_ops['first_conv_or_matmul_quantization'] = [(first_quantizable_node[0].name,
first_quantizable_node[0].op_type)]
if len(last_quantizable_node) != 0:
recipes_ops['last_conv_or_matmul_quantization'] = [(last_quantizable_node[0].name,
last_quantizable_node[0].op_type)]


ffn_matmul = []
attention_matmul_optype = [node.op_type for node in attention_matmul]
# find MatMul ops in the feed-forward network (FFN) structure, which mainly appears in transformer-based NLP models
if len(attention_matmul) > 0 and 'Attention' in attention_matmul_optype:
# model is optimized and Attention is fused,
# index of Attention is used as split to find FFN MatMul
first_attention_index = attention_matmul_optype.index('Attention')
attention_matmul_optype = attention_matmul_optype[first_attention_index:]
attention_matmul = attention_matmul[first_attention_index:]
attention_index = list(np.where(np.array(attention_matmul_optype) == 'Attention')[0])
block_len = attention_index[1] - attention_index[0] if len(attention_index) > 2 else 4
for idx in range(len(attention_index)):
if idx != len(attention_index) - 1:
index = attention_index[idx + 1]
if index - 2 >= 0 and index - 1 >= 0:
ffn_matmul.append([attention_matmul[index - 2],
attention_matmul[index - 1]])
else:
index = attention_index[idx]
if index + block_len - 2 < len(attention_matmul) and \
index + block_len - 1 < len(attention_matmul):
ffn_matmul.append([attention_matmul[index + block_len - 2],
attention_matmul[index + block_len - 1]])
else:
# model is not optimized or Attention isn't fused,
# query MatMul, key MatMul and value MatMul are used as split to find FFN MatMul
qkv = self.pre_optimized_model.find_qkv_in_attention(find_all=True)
if len(qkv) != 0:
attention_starts = [nodes[0] for nodes in qkv]
attention_index = [np.where(np.array([n.name for n in attention_matmul]) \
== attention_start)[0].tolist()[0] \
for attention_start in attention_starts]
block_len = attention_index[1] - attention_index[0] if len(attention_index) > 2 else 4
for idx in range(len(attention_index)):
if idx != len(attention_index) - 1:
index = attention_index[idx + 1]
if index - 2 >= 0 and index - 1 >= 0:
ffn_matmul.append([attention_matmul[index - 2],
attention_matmul[index - 1]])
else:
index = attention_index[idx]
if index + block_len - 2 < len(attention_matmul) and \
index + block_len - 1 < len(attention_matmul):
ffn_matmul.append([attention_matmul[index + block_len - 2],
attention_matmul[index + block_len - 1]])

block_wise = []
for block in reversed(ffn_matmul):
node_info = []
for node in block:
node_info.append((node.name, node.op_type))
if len(node_info) != 0:
block_wise.append(node_info)

for parent, nodes in self.pre_optimized_model.get_absorb_pairs(["MatMul", "Attention"]).items():
for node in nodes:
if node.op_type in optype_wise:
if (exclude_first_quantizable_op and node in first_quantizable_node) \
or (exclude_last_quantizable_op and node in last_quantizable_node):
tmp_cfg = copy.deepcopy(optype_wise[node.op_type])
tmp_cfg = list(filter(lambda x:'quant_mode' not in x['activation'], tmp_cfg))
op_wise.update({(node.name, node.op_type): tmp_cfg})
continue
op_wise.update(
{(node.name, node.op_type): copy.deepcopy(optype_wise[node.op_type])})

# only when first and last quantizable nodes are found and they are not the same,
# fallback pre/postprocess ops
if len(first_quantizable_node) != 0 and \
len(last_quantizable_node) != 0 and \
first_quantizable_node[0].name != last_quantizable_node[0].name:
# get backbone nodes
from collections import deque

# get nodes between first quantizable node and last quantizable node
backbone_queue = deque(last_quantizable_node)
backbone_nodes = self.pre_optimized_model.get_nodes_chain(backbone_queue, first_quantizable_node)

# get extra Conv or MatMul nodes not between first quantizable node and last quantizable node
backbone_queue_extra = deque()
for conv_or_matmul in all_conv_matmul:
if conv_or_matmul.name not in backbone_nodes:
backbone_queue_extra.append(conv_or_matmul)
backbone_nodes = self.pre_optimized_model.get_nodes_chain(backbone_queue_extra,
first_quantizable_node, backbone_nodes)
backbone_nodes += [i.name for i in first_quantizable_node]

for _, node in enumerate(self.pre_optimized_model.nodes()):
if node.name not in backbone_nodes and node.op_type in optype_wise:
recipes_ops['pre_post_process_quantization'].append((node.name, node.op_type))
if exclude_pre_post_process:
for _, node in enumerate(self.pre_optimized_model.nodes()):
if node.op_type in optype_wise:
# nodes not in backbone are not quantized
if node.name not in backbone_nodes:
tmp_cfg = copy.deepcopy(optype_wise[node.op_type])
tmp_cfg = list(filter(lambda x:'quant_mode' not in x['activation'], tmp_cfg))
op_wise.update({(node.name, node.op_type): tmp_cfg})
continue
if (node.name, node.op_type) in op_wise:
op_wise.update(
{(node.name, node.op_type): copy.deepcopy(op_wise[(node.name, node.op_type)])})
else: # pragma: no cover
op_wise.update(
{(node.name, node.op_type): copy.deepcopy(optype_wise[node.op_type])})
for node in self.pre_optimized_model.nodes():
if node.op_type in ['MatMul', 'Attention'] and model.get_initializer(node.input[1]) is None:
op_wise.update(
{(node.name, node.op_type): [{'weight': {'dtype': 'fp32'}, 'activation': {'dtype': 'fp32'}}]})
if node.op_type in optype_wise:
op_wise.update(
{(node.name, node.op_type): copy.deepcopy(optype_wise[node.op_type])})

return {'optypewise': optype_wise, 'opwise': op_wise, 'recipes_ops': recipes_ops, 'block_wise': block_wise}
return {'optypewise': optype_wise, 'opwise': op_wise, 'recipes_ops': {}, 'block_wise': []}
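
To summarize what the trimmed query_fw_capability now reports: recipes_ops and block_wise are returned empty for this weight-only adaptor, and each node's capability follows a single rule, sketched below with a hypothetical helper and simplified arguments.

    import copy

    # Sketch only: MatMul/Attention nodes whose weight is not an initializer are pinned to
    # fp32; every other supported op inherits the optype-wise weight-only options.
    def op_capability(op_type, has_weight_initializer, optype_wise):
        if op_type in ("MatMul", "Attention") and not has_weight_initializer:
            return [{"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32"}}]
        return copy.deepcopy(optype_wise.get(op_type, []))

    optype_wise = {"MatMul": [{"weight": {"dtype": ["int"], "bits": [4]}, "activation": {"dtype": ["fp32"]}}]}
    assert op_capability("MatMul", False, optype_wise)[0]["weight"]["dtype"] == "fp32"
    assert op_capability("MatMul", True, optype_wise) == optype_wise["MatMul"]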


@adaptor_registry
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/onnxrt.yaml
@@ -23,7 +23,7 @@
'dtype': ['int'], # no need to care uint
'bits': [4, 3, 8], # [1-8]
'granularity':['per_group', 'per_channel'],
'group_size': [32, 1, 16, 64, 128, 256, 512, 1024], # [1-inf]
'group_size': [32, -1, 1, 16, 64, 128, 256, 512, 1024], # [1-inf]
'scheme': ['sym', 'asym'], # sym, no ZP
'algorithm': ['RTN'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order
},
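
On the newly listed -1 value: a common convention (and an assumption here) is that -1 disables sub-grouping, so a single scale covers the whole input dimension of each output channel; a minimal sketch of how such a group size could resolve:

    # Sketch only, under the assumption stated above.
    def resolve_group_size(group_size, in_features):
        return in_features if group_size == -1 else min(group_size, in_features)

    assert resolve_group_size(-1, 768) == 768
    assert resolve_group_size(32, 768) == 32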
68 changes: 39 additions & 29 deletions neural_compressor/adaptor/ox_utils/weight_only.py
@@ -57,7 +57,7 @@ def qdq_tensor(data, config, ratio=1):
def rtn_quantize(model, tune_cfg, ratios={}):
model = model if isinstance(model, BaseModel) else ONNXModel(model)
for node in model.nodes():
if node.name in tune_cfg and tune_cfg[node.name]["weight"]["dtype"] != "fp32":
if node.name in tune_cfg and tune_cfg[node.name] != "fp32":
if model.get_initializer(node.input[1]) is None:
continue
weight = numpy_helper.to_array(
@@ -136,59 +136,69 @@ def apply_awq_scale(model, tune_cfg, absorb_pairs, output_dicts):
best_scale = scales

for node in nodes:
tensor = numpy_helper.to_array(model.get_initializer(node.input[1]))
new_tensor = tensor * best_scale
model.set_initializer(node.input[1], new_tensor.astype(tensor.dtype), raw=True)
output_dicts[node.input[0]] = output_dicts[node.input[0]] / np.reshape(best_scale, (1, -1))
if node.name in tune_cfg and tune_cfg[node.name] != "fp32":
tensor = numpy_helper.to_array(model.get_initializer(node.input[1]))
new_tensor = tensor * best_scale
model.set_initializer(node.input[1], new_tensor.astype(tensor.dtype), raw=True)
output_dicts[node.input[0]] = output_dicts[node.input[0]] / np.reshape(best_scale, (1, -1))

parent = model.get_node(parent)
if parent.name in updated_nodes:
continue

if parent.op_type in ["LayerNormalization", "BatchNormalization", "InstanceNormalization"]:
if parent.op_type in ["LayerNormalization", "BatchNormalization", "InstanceNormalization"] and \
all([node.name in tune_cfg and tune_cfg[node.name] != "fp32" for node in nodes]):
for idx in [1, 2]:
tensor = numpy_helper.to_array(model.get_initializer(parent.input[idx]),
os.path.dirname(model.model_path))
new_tensor = tensor / best_scale
new_tensor = tensor / np.reshape(best_scale, (1, -1))
model.set_initializer(parent.input[idx], new_tensor.astype(tensor.dtype), raw=True)
updated_nodes.append(parent.name)
output_dicts[parent.output[0]] = output_dicts[parent.output[0]] / np.reshape(best_scale, (1, -1))

elif parent.op_type in ["SimplifiedLayerNormalization", "MatMul", "Gemm", "Mul"] and \
not all([model.get_initializer(inp) is None for inp in parent.input]):
not all([model.get_initializer(inp) is None for inp in parent.input]) and \
all([node.name in tune_cfg and tune_cfg[node.name] != "fp32" for node in nodes]):
for inp in parent.input:
if model.get_initializer(inp) is not None:
tensor = numpy_helper.to_array(model.get_initializer(inp),
os.path.dirname(model.model_path))
new_tensor = tensor / np.reshape(best_scale, (1, -1))
model.set_initializer(inp, new_tensor.astype(tensor.dtype), raw=True)
updated_nodes.append(parent.name)
output_dicts[parent.output[0]] = output_dicts[parent.output[0]] / np.reshape(best_scale, (1, -1))

elif parent.op_type in ["Conv", "FusedConv"]:
elif parent.op_type in ["Conv", "FusedConv"] and \
all([node.name in tune_cfg and tune_cfg[node.name] != "fp32" for node in nodes]):
tensor = numpy_helper.to_array(model.get_initializer(parent.input[2]),
os.path.dirname(model.model_path))
new_tensor = tensor / best_scale
new_tensor = tensor / np.reshape(best_scale, (1, -1))
model.set_initializer(parent.input[2], new_tensor.astype(tensor.dtype), raw=True)
updated_nodes.append(parent.name)
output_dicts[parent.output[0]] = output_dicts[parent.output[0]] / np.reshape(best_scale, (1, -1))

else:
# insert mul
scale_tensor = helper.make_tensor(
name=parent.output[0] + "_weight_only_scale",
data_type=onnx_proto.TensorProto.FLOAT,
dims=best_scale.shape,
vals=best_scale.flatten().tolist())
new_init_tensors.append(scale_tensor)
mul_output_name = parent.output[0] + "_weight_only_out"
mul_node = helper.make_node(
"Mul",
inputs=[node.input[0], scale_tensor.name],
outputs=[mul_output_name],
name=node.input[0] + "_weight_only_mul"
)
new_added_mul_nodes.append(mul_node)
replace_input.append([node, node.input[0], mul_node.output[0]])
updated_nodes.append(parent.name)
output_dicts[mul_node.output[0]] = output_dicts[mul_node.input[0]]
q_nodes = [node for node in nodes if node.name in tune_cfg and tune_cfg[node.name] != "fp32"]
if len(q_nodes) > 0:
scale_tensor = helper.make_tensor(
name=parent.output[0] + "_weight_only_scale",
data_type=onnx_proto.TensorProto.FLOAT,
dims=best_scale.shape,
vals=(1. / best_scale).flatten().tolist())
new_init_tensors.append(scale_tensor)
mul_output_name = parent.output[0] + "_weight_only_out"
mul_node = helper.make_node(
"Mul",
inputs=[q_nodes[0].input[0], scale_tensor.name],
outputs=[mul_output_name],
name=q_nodes[0].input[0] + "_weight_only_mul"
)
new_added_mul_nodes.append(mul_node)
for node in q_nodes:
replace_input.append([node, node.input[0], mul_node.output[0]])
updated_nodes.append(parent.name)
output_dicts[mul_node.output[0]] = output_dicts[mul_node.input[0]] / np.reshape(best_scale, (1, -1))

model.add_nodes(new_added_mul_nodes)
model.add_initializers(new_init_tensors)
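
For the inserted-Mul branch, the intent is that scaling moves from the activation into the weight, so the float result is unchanged while the weight becomes easier to quantize; a small numpy check under assumed shapes (not the library code itself):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((8, 16)).astype(np.float32)       # parent output / MatMul input
    w = rng.standard_normal((16, 4)).astype(np.float32)       # MatMul weight initializer
    scale = rng.uniform(0.5, 2.0, size=16).astype(np.float32)

    x_scaled = x * (1.0 / scale)     # what the inserted "weight_only_mul" node computes
    w_scaled = w * scale[:, None]    # what set_initializer stores for the quantized node
    np.testing.assert_allclose(x_scaled @ w_scaled, x @ w, rtol=1e-3, atol=1e-4)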
@@ -207,7 +217,7 @@ def apply_awq_clip(model, tune_cfg, absorb_pairs, output_dicts):
inp = np.concatenate(output_dicts[nodes[0].input[0]], axis=0)

for node in nodes:
if node.name in tune_cfg and tune_cfg[node.name]["weight"]["dtype"] != "fp32":
if node.name in tune_cfg and tune_cfg[node.name] != "fp32":

group_size = tune_cfg[node.name]["weight"]["group_size"]
config = tune_cfg[node.name]
@@ -255,7 +265,7 @@ def awq_quantize(model,
dataloader,
[],
white_nodes=white_nodes,
iterations=list(range(0, n_samples)))
iterations=list(range(0, math.ceil(n_samples / dataloader.batch_size))))

augment.augment_graph(activation_only=True, weight_only=False)
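
The iterations change converts the calibration sample budget into a number of dataloader iterations instead of treating every sample as one iteration; the arithmetic, assuming a fixed batch size:

    import math

    # Sketch only: 13 batches of 10 cover at least 128 samples.
    n_samples, batch_size = 128, 10
    assert math.ceil(n_samples / batch_size) == 13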
