-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance: Quant Tensor Test #894
Changes from all commits
67f63ab
ca962fd
7e80aac
0740f06
c7e6d60
12c5108
367e906
1b9b62a
261831e
7fc19bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
from enum import Enum | ||
|
||
import pytest | ||
import torch | ||
|
||
from brevitas.inject.enum import QuantType | ||
from brevitas.nn import QuantIdentity | ||
from brevitas.quant_tensor import QuantTensor | ||
|
||
|
||
class Operator(Enum): | ||
ADD = 0 | ||
SUBTRACT = 1 | ||
DIVIDE = 2 | ||
MULTIPLY = 3 | ||
MATMUL = 4 | ||
|
||
|
||
def to_quant_tensor(input: torch.Tensor) -> QuantTensor: | ||
mod = QuantIdentity(bit_width=8, return_quant_tensor=True) | ||
return mod(input) | ||
|
||
|
||
def qdq(normal_tensor, quant_tensor): | ||
return ( | ||
torch.round(normal_tensor / quant_tensor.scale + quant_tensor.zero_point) - | ||
quant_tensor.zero_point) * quant_tensor.scale | ||
|
||
|
||
def test_quant_tensor_init(): | ||
x = torch.randn(4, 4) | ||
quant_tensor = to_quant_tensor(x) | ||
normal_tensor = torch.Tensor(x) | ||
assert torch.allclose(qdq(normal_tensor, quant_tensor), quant_tensor, rtol=0.01) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'op', [Operator.ADD, Operator.SUBTRACT, Operator.DIVIDE, Operator.MULTIPLY, Operator.MATMUL]) | ||
def test_quant_tensor_operators(op): | ||
x = torch.randn(4, 4) | ||
|
||
a = torch.Tensor(x) | ||
b = torch.Tensor(x) | ||
|
||
qa = to_quant_tensor(a) | ||
qb = to_quant_tensor(b) | ||
|
||
# to factor in quantisation error | ||
e_a = a - qa | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. didn't use qdq approach above as should be covered by the init test, I just need the difference so I can incorporate it into the calculations below |
||
e_b = b - qb | ||
|
||
if op == Operator.ADD: | ||
quant = qa + qb | ||
normal = (a - e_a) + (b - e_b) | ||
elif op == Operator.SUBTRACT: | ||
quant = qa - qb | ||
normal = (a - e_a) - (b - e_b) | ||
elif op == Operator.DIVIDE: | ||
quant = qa / qb | ||
normal = (a - e_a) / (b - e_b) | ||
elif op == Operator.MULTIPLY: | ||
quant = qa * qb | ||
normal = (a - e_a) * (b - e_b) | ||
elif op == Operator.MATMUL: | ||
# @ matmul operator not implemented for QuantTensor | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the difference between There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't believe there is a difference so its probably something we should create an issue to implement |
||
quant = torch.matmul(qa, qb) | ||
normal = (a - e_a) @ (b - e_b) | ||
else: | ||
# unrecognised operator | ||
assert False | ||
|
||
assert torch.allclose(normal, quant) | ||
|
||
|
||
def test_quant_tensor_div_by_zero(): | ||
a = to_quant_tensor(torch.ones(4, 4)) | ||
b = to_quant_tensor(torch.zeros(4, 4)) | ||
assert torch.isinf(a / b).all().item() | ||
|
||
|
||
def test_quant_tensor_div_by_fraction(): | ||
a = to_quant_tensor(torch.ones(4, 4)) | ||
b = to_quant_tensor(torch.ones(4, 4) * 0.5) | ||
assert torch.allclose(a / b, torch.ones(4, 4) * 2, atol=0.1) | ||
|
||
|
||
# TODO: need to deal with quant metadata | ||
def test_quant_tensor_transpose(): | ||
x = torch.ones(4, 4).tril() | ||
a = x.clone() | ||
b = to_quant_tensor(x) | ||
assert torch.allclose(a.transpose(0, 1), b.transpose(0, 1), atol=0.01) | ||
|
||
|
||
# TODO: need to deal with quant metadata | ||
def test_quant_tensor_view(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. View and transpose open the discussion to a broader topic regarding how to deal with quant metadata views and transpose, especially in the case where we are doing per channel or finer granularity quantizations. For now, I would add a TODO in both test case that says that we need to deal with quant metadata and test it |
||
x = torch.ones(4, 4) | ||
a = to_quant_tensor(x) | ||
b = torch.Tensor(x) | ||
|
||
assert torch.allclose(a.view(-1), b.view(-1), atol=0.01) | ||
assert torch.allclose(a.view(2, -1), b.view(2, -1), atol=0.01) | ||
assert torch.allclose(a.view(16, -1), b.view(16, -1), atol=0.01) | ||
assert torch.allclose(a.view(8, 2), b.view(8, 2), atol=0.01) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
difference between the qdq result and quant tensor is extremely close but some error is creeping in from the quanttensor somewhere so added relative tolerance