Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance: Quant Tensor Test #894

Merged
merged 10 commits into from
Apr 10, 2024
107 changes: 107 additions & 0 deletions tests/brevitas/quant_tensor/test_quant_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from enum import Enum

import pytest
import torch

from brevitas.inject.enum import QuantType
from brevitas.nn import QuantIdentity
from brevitas.quant_tensor import QuantTensor


class Operator(Enum):
ADD = 0
SUBTRACT = 1
DIVIDE = 2
MULTIPLY = 3
MATMUL = 4


def to_quant_tensor(input: torch.Tensor) -> QuantTensor:
mod = QuantIdentity(bit_width=8, return_quant_tensor=True)
return mod(input)


def qdq(normal_tensor, quant_tensor):
return (
torch.round(normal_tensor / quant_tensor.scale + quant_tensor.zero_point) -
quant_tensor.zero_point) * quant_tensor.scale


def test_quant_tensor_init():
x = torch.randn(4, 4)
quant_tensor = to_quant_tensor(x)
normal_tensor = torch.Tensor(x)
assert torch.allclose(qdq(normal_tensor, quant_tensor), quant_tensor, rtol=0.01)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

difference between the qdq result and quant tensor is extremely close but some error is creeping in from the quanttensor somewhere so added relative tolerance



@pytest.mark.parametrize(
'op', [Operator.ADD, Operator.SUBTRACT, Operator.DIVIDE, Operator.MULTIPLY, Operator.MATMUL])
def test_quant_tensor_operators(op):
x = torch.randn(4, 4)

a = torch.Tensor(x)
b = torch.Tensor(x)

qa = to_quant_tensor(a)
qb = to_quant_tensor(b)

# to factor in quantisation error
e_a = a - qa
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didn't use qdq approach above as should be covered by the init test, I just need the difference so I can incorporate it into the calculations below

e_b = b - qb

if op == Operator.ADD:
quant = qa + qb
normal = (a - e_a) + (b - e_b)
elif op == Operator.SUBTRACT:
quant = qa - qb
normal = (a - e_a) - (b - e_b)
elif op == Operator.DIVIDE:
quant = qa / qb
normal = (a - e_a) / (b - e_b)
elif op == Operator.MULTIPLY:
quant = qa * qb
normal = (a - e_a) * (b - e_b)
elif op == Operator.MATMUL:
# @ matmul operator not implemented for QuantTensor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between @ and matmul? Also in terms of implementations, what would we need to override to implement @?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe there is a difference so its probably something we should create an issue to implement

quant = torch.matmul(qa, qb)
normal = (a - e_a) @ (b - e_b)
else:
# unrecognised operator
assert False

# tolerance set to a high value as there is considerable loss of precision
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment is outdated I believe

assert torch.allclose(normal, quant)


def test_quant_tensor_div_by_zero():
a = to_quant_tensor(torch.ones(4, 4))
b = to_quant_tensor(torch.zeros(4, 4))
assert torch.isinf(a / b).all().item()


def test_quant_tensor_div_by_fraction():
a = to_quant_tensor(torch.ones(4, 4))
b = to_quant_tensor(torch.ones(4, 4) * 0.5)
assert torch.allclose(a / b, torch.ones(4, 4) * 2, atol=0.1)


# TODO: need to deal with quant metadata
def test_quant_tensor_transpose():
x = torch.ones(4, 4).tril()
a = x.clone()
b = to_quant_tensor(x)
assert torch.allclose(a.transpose(0, 1), b.transpose(0, 1), atol=0.01)


# TODO: need to deal with quant metadata
def test_quant_tensor_view():
Copy link
Collaborator

@Giuseppe5 Giuseppe5 Mar 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

View and transpose open the discussion to a broader topic regarding how to deal with quant metadata views and transpose, especially in the case where we are doing per channel or finer granularity quantizations.

For now, I would add a TODO in both test case that says that we need to deal with quant metadata and test it

x = torch.ones(4, 4)
a = to_quant_tensor(x)
b = torch.Tensor(x)

assert torch.allclose(a.view(-1), b.view(-1), atol=0.01)
assert torch.allclose(a.view(2, -1), b.view(2, -1), atol=0.01)
assert torch.allclose(a.view(16, -1), b.view(16, -1), atol=0.01)
assert torch.allclose(a.view(8, 2), b.view(8, 2), atol=0.01)
Loading