From c5c3c7e6840a648da50b05a90d029d6b392d807a Mon Sep 17 00:00:00 2001 From: Giuseppe Date: Tue, 14 Feb 2023 13:46:24 +0000 Subject: [PATCH] Test (graph eq): new test for graph equalization --- tests/brevitas/graph/equalization_fixtures.py | 61 +++++++++++++++++++ tests/brevitas/graph/test_equalization.py | 31 ++++++++++ 2 files changed, 92 insertions(+) create mode 100644 tests/brevitas/graph/equalization_fixtures.py create mode 100644 tests/brevitas/graph/test_equalization.py diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py new file mode 100644 index 000000000..c2446a1e8 --- /dev/null +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -0,0 +1,61 @@ +import pytest_cases +from pytest_cases import fixture_union +import torch +import torch.nn as nn + + +@pytest_cases.fixture +def residual_model(): + class ResidualModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = nn.Conv2d(3, 16, kernel_size=1) + self.conv_0 = nn.Conv2d(16, 3, kernel_size=1) + def forward(self, x): + start = x + x = self.conv(x) + x = self.conv_0(x) + x = start + x + return x + return ResidualModel + +@pytest_cases.fixture +def srcsinkconflict_model(): + """ + In this example, conv_0 is both a src and sink. + """ + class ResidualSrcsAndSinkModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv_start = nn.Conv2d(3, 3, kernel_size=1) + self.conv = nn.Conv2d(3, 3, kernel_size=1) + self.conv_0 = nn.Conv2d(3, 3, kernel_size=1) + def forward(self, x): + start = self.conv_start(x) + x = self.conv_0(start) + x = start + x + x = self.conv(x) + return x + return ResidualSrcsAndSinkModel + + +@pytest_cases.fixture +def mul_model(): + """ + In this example, conv_0 is both a src and sink. + """ + class ResidualSrcsAndSinkModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv_1 = nn.Conv2d(3, 3, kernel_size=1) + self.conv_0 = nn.Conv2d(3, 3, kernel_size=1) + self.conv_end = nn.Conv2d(3, 3, kernel_size=1) + def forward(self, x): + x_0 = self.conv_0(x) + x_1 = self.conv_1(x) + x = x_0 * x_1 + x = self.conv_end(x) + return x + return ResidualSrcsAndSinkModel + +toy_model = fixture_union('toy_model', [ 'residual_model', 'srcsinkconflict_model', 'mul_model']) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py new file mode 100644 index 000000000..b94d7c25f --- /dev/null +++ b/tests/brevitas/graph/test_equalization.py @@ -0,0 +1,31 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + + +from inspect import getfullargspec + +import torch + +from brevitas.fx import value_trace +from brevitas.graph import EqualizeGraph + +from .equalization_fixtures import * + +SEED = 123456 +IN_SIZE = (16,3,224,224) +ATOL = 1e-3 + + +def test_models(toy_model): + model = toy_model() + inp = torch.randn(IN_SIZE) + + input_name = getfullargspec(model.forward)[0][0] + model.eval() + expected_out = model(inp) + model = value_trace(model, {input_name: inp}) + model, regions = EqualizeGraph(3, return_regions=True).apply(model) + + out = model(inp) + assert len(regions) > 0 + assert torch.allclose(expected_out, out, atol=ATOL)