Skip to content

Commit

Permalink
[bug fix] update quantizable add ops detection on IPEX backend (#1456)
Browse files Browse the repository at this point in the history
Signed-off-by: Xin He <[email protected]>
  • Loading branch information
xin3he authored Dec 25, 2023
1 parent 5119fcb commit 4c004d7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 126 deletions.
16 changes: 7 additions & 9 deletions neural_compressor/adaptor/torch_utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,13 @@ def get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_ids
op_info = op_infos_from_cfgs[name]
output_tensors = op_info["output_tensor_infos"]
input_tensors = op_info["input_tensor_infos"]
for input_tensor in input_tensors:
if "inf_dtype" not in input_tensor.keys():
continue
if input_tensor["inf_dtype"] == torch.float32:
pre_op_name = input_tensor_ids_op_name[input_tensor["id"]]
if pre_op_name[1] in ["q_op_infos"]:
print(pre_op_name, "is not the fuse ops first op.")
start = False
continue
start = any(
[
input_tensor["inf_dtype"] != "torch.float32"
for input_tensor in input_tensors
if "inf_dtype" in input_tensor.keys()
]
)
if not start:
continue
# add quantizable ops, include op and fuse ops.
Expand Down
117 changes: 0 additions & 117 deletions test/ipex/test_adaptor_ipex.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,62 +316,6 @@ def forward(self, a):
)
self.assertTrue(isinstance(q_model._model, torch.jit.ScriptModule))

def test_tune_add(self):
    """Quantize a model containing an in-place add with the IPEX backend.

    The model is Conv2d -> flatten -> ``x += x`` -> Linear. Evaluation is
    faked with a fixed list of scores consumed one per call, and the test
    only asserts that the tuned model wraps a ``torch.jit.ScriptModule``.
    """

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 1, 1)
            self.linear = torch.nn.Linear(224 * 224, 5)

        def forward(self, a):
            flat = self.conv(a).view(1, -1)
            # In-place self-add: this is the "add" op the test is named after.
            flat += flat
            return self.linear(flat)

    model = M()
    from neural_compressor import PostTrainingQuantConfig, quantization

    # Scores handed out by the fake evaluator, in call order.
    scores = [1, 0.8, 1.1, 1.2]

    def fake_eval(model):
        return scores.pop(0)

    ptq_config = PostTrainingQuantConfig(backend="ipex", quant_level=0)
    loader = Dataloader()
    tuned = quantization.fit(
        model,
        ptq_config,
        calib_dataloader=loader,
        eval_func=fake_eval,
    )
    self.assertTrue(isinstance(tuned._model, torch.jit.ScriptModule))

def test_tune_add_with_recipe(self):
    """Quantize an add -> flatten -> Linear model with the smooth-quant recipe.

    Uses the IPEX backend with ``smooth_quant`` enabled (alpha 0.5).
    Evaluation is faked with a fixed list of scores consumed one per call,
    and the test only asserts that the tuned model wraps a
    ``torch.jit.ScriptModule``.
    """

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(224 * 224 * 3, 5)

        def forward(self, x):
            # In-place add applied directly to the incoming tensor.
            x += x
            return self.linear(x.view(1, -1))

    model = M()
    from neural_compressor import PostTrainingQuantConfig, quantization

    # Scores handed out by the fake evaluator, in call order.
    scores = [1, 0.8, 1.1, 1.2]

    def fake_eval(model):
        return scores.pop(0)

    recipe_options = {"smooth_quant": True, "smooth_quant_args": {"alpha": 0.5}}
    ptq_config = PostTrainingQuantConfig(backend="ipex", quant_level=0, recipes=recipe_options)
    loader = Dataloader()
    tuned = quantization.fit(
        model,
        ptq_config,
        calib_dataloader=loader,
        eval_func=fake_eval,
    )
    self.assertTrue(isinstance(tuned._model, torch.jit.ScriptModule))

def test_tune_minmax_obs(self):
class M(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -527,67 +471,6 @@ def forward(self, a):
)
self.assertTrue(isinstance(q_model._model, torch.jit.ScriptModule))

def test_tune_add(self):
    """Quantize a model containing an in-place add with IPEX on the XPU device.

    The model is Conv2d -> flatten -> ``x += x`` -> Linear, moved to
    ``"xpu"`` before tuning. Evaluation is faked with a fixed list of scores
    consumed one per call, and the test only asserts that the tuned model
    wraps a ``torch.jit.ScriptModule``.
    """

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 1, 1)
            self.linear = torch.nn.Linear(224 * 224, 5)

        def forward(self, a):
            flat = self.conv(a).view(1, -1)
            # In-place self-add: this is the "add" op the test is named after.
            flat += flat
            return self.linear(flat)

    model = M().to("xpu")
    from neural_compressor import PostTrainingQuantConfig, quantization

    # Scores handed out by the fake evaluator, in call order.
    scores = [1, 0.8, 1.1, 1.2]

    def fake_eval(model):
        return scores.pop(0)

    ptq_config = PostTrainingQuantConfig(backend="ipex", device="xpu", quant_level=0)
    loader = Dataloader(device="xpu")
    tuned = quantization.fit(
        model,
        ptq_config,
        calib_dataloader=loader,
        eval_func=fake_eval,
    )
    self.assertTrue(isinstance(tuned._model, torch.jit.ScriptModule))

def test_tune_add_with_recipe(self):
    """Quantize a conv/add/linear model on XPU with the smooth-quant recipe.

    The model is Conv2d -> flatten -> ``x += x`` -> Linear, moved to
    ``"xpu"``. The IPEX config enables ``smooth_quant`` with alpha 0.5.
    Evaluation is faked with a fixed list of scores consumed one per call,
    and the test only asserts that the tuned model wraps a
    ``torch.jit.ScriptModule``.
    """

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 1, 1)
            self.linear = torch.nn.Linear(224 * 224, 5)

        def forward(self, a):
            flat = self.conv(a).view(1, -1)
            # In-place self-add: this is the "add" op the test is named after.
            flat += flat
            return self.linear(flat)

    model = M().to("xpu")
    from neural_compressor import PostTrainingQuantConfig, quantization

    # Scores handed out by the fake evaluator, in call order.
    scores = [1, 0.8, 1.1, 1.2]

    def fake_eval(model):
        return scores.pop(0)

    recipe_options = {"smooth_quant": True, "smooth_quant_args": {"alpha": 0.5}}
    ptq_config = PostTrainingQuantConfig(
        backend="ipex",
        device="xpu",
        quant_level=0,
        recipes=recipe_options,
    )
    loader = Dataloader(device="xpu")
    tuned = quantization.fit(
        model,
        ptq_config,
        calib_dataloader=loader,
        eval_func=fake_eval,
    )
    self.assertTrue(isinstance(tuned._model, torch.jit.ScriptModule))


class TestMixedPrecision(unittest.TestCase):
@classmethod
Expand Down

0 comments on commit 4c004d7

Please sign in to comment.