Skip to content

Commit

Permalink
Improve error message for array_equal asserts. (rapidsai#4973)
Browse files Browse the repository at this point in the history
Authors:
  - Carl Simon Adorf (https://github.com/csadorf)

Approvers:
  - William Hicks (https://github.com/wphicks)

URL: rapidsai#4973
  • Loading branch information
csadorf authored Nov 9, 2022
1 parent f2724e2 commit 034a625
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 10 deletions.
109 changes: 99 additions & 10 deletions python/cuml/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from textwrap import dedent, indent

import cupy as cp
import numpy as np
Expand Down Expand Up @@ -39,25 +40,113 @@
import pytest


def array_equal(a, b, unit_tol=1e-4, total_tol=1e-4, with_sign=True):
def array_difference(a, b, with_sign=True):
"""
Utility function to compare 2 numpy arrays. Two individual elements
are assumed equal if they are within `unit_tol` of each other, and two
arrays are considered equal if less than `total_tol` percentage of
elements are different.
Utility function to compute the difference between 2 arrays.
"""

a = to_nparray(a)
b = to_nparray(b)

if len(a) == 0 and len(b) == 0:
return True
return 0

if not with_sign:
a, b = np.abs(a), np.abs(b)
res = (np.sum(np.abs(a - b) > unit_tol)) / a.size < total_tol
return res
return np.sum(np.abs(a - b))


class array_equal:
"""
Utility functor to compare 2 numpy arrays and optionally show a meaningful
error message in case they are not. Two individual elements are assumed
equal if they are within `unit_tol` of each other, and two arrays are
considered equal if less than `total_tol` percentage of elements are
different.
"""

def __init__(self, a, b, unit_tol=1e-4, total_tol=1e-4, with_sign=True):
self.a = to_nparray(a)
self.b = to_nparray(b)
self.unit_tol = unit_tol
self.total_tol = total_tol
self.with_sign = with_sign

def compute_difference(self):
return array_difference(self.a, self.b, with_sign=self.with_sign)

def __bool__(self):
if len(self.a) == len(self.b) == 0:
return True

if self.with_sign:
a, b = self.a, self.b
else:
a, b = np.abs(self.a), np.abs(self.b)

res = (np.sum(np.abs(a - b) > self.unit_tol)) / a.size < self.total_tol
return bool(res)

def __eq__(self, other):
if isinstance(other, bool):
return bool(self) == other
return super().__eq__(other)

def _repr(self, threshold=None):
name = self.__class__.__name__

return [f"<{name}: "] \
+ f"{np.array2string(self.a, threshold=threshold)} ".splitlines()\
+ f"{np.array2string(self.b, threshold=threshold)} ".splitlines()\
+ [
f"unit_tol={self.unit_tol} ",
f"total_tol={self.total_tol} ",
f"with_sign={self.with_sign}",
">"
]

def __repr__(self):
return "".join(self._repr(threshold=5))

def __str__(self):
tokens = self._repr(threshold=1000)
return "\n".join(
f"{' ' if 0 < n < len(tokens) - 1 else ''}{token}"
for n, token in enumerate(tokens)
)


def assert_array_equal(a, b, unit_tol=1e-4, total_tol=1e-4, with_sign=True):
"""
Raises an AssertionError if arrays are not considered equal.
Uses the same arguments as array_equal(), but raises an AssertionError in
case that the test considers the arrays to not be equal.
This function will generate a nicer error message in the context of pytest
compared to a plain `assert array_equal(...)`.
"""
# Determine array equality.
equal = array_equal(
a, b, unit_tol=unit_tol, total_tol=total_tol, with_sign=with_sign)
if not equal:
# Generate indented array string representation ...
str_a = indent(np.array2string(a), " ").splitlines()
str_b = indent(np.array2string(b), " ").splitlines()
# ... and add labels
str_a[0] = f"a: {str_a[0][3:]}"
str_b[0] = f"b: {str_b[0][3:]}"

# Create assertion error message and raise exception.
assertion_error_msg = dedent(f"""
Arrays are not equal
unit_tol: {unit_tol}
total_tol: {total_tol}
with_sign: {with_sign}
""") + "\n".join(str_a + str_b)
raise AssertionError(assertion_error_msg)


def get_pattern(name, n_samples):
Expand Down
95 changes: 95 additions & 0 deletions python/cuml/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
import numpy as np
from cuml.testing.utils import array_equal, assert_array_equal
from hypothesis import given, note
from hypothesis import strategies as st
from hypothesis import target
from hypothesis.extra.numpy import (array_shapes, arrays, floating_dtypes,
integer_dtypes)


@given(
arrays(
dtype=st.one_of(floating_dtypes(), integer_dtypes()),
shape=array_shapes(),
),
st.floats(1e-4, 1.0))
@pytest.mark.filterwarnings("ignore:invalid value encountered in subtract")
def test_array_equal_same_array(array, tol):
equal = array_equal(array, array, tol)
note(equal)
difference = equal.compute_difference()
if np.isfinite(difference):
target(float(np.abs(difference)))
assert equal
assert equal == True # noqa: E712
assert bool(equal) is True
assert_array_equal(array, array, tol)


@given(
arrays=array_shapes().flatmap(
lambda shape:
st.tuples(
arrays(
dtype=st.one_of(floating_dtypes(), integer_dtypes()),
shape=shape,
),
arrays(
dtype=st.one_of(floating_dtypes(), integer_dtypes()),
shape=shape,
),
)
),
unit_tol=st.floats(1e-4, 1.0),
with_sign=st.booleans(),
)
@pytest.mark.filterwarnings("ignore:invalid value encountered in subtract")
def test_array_equal_two_arrays(arrays, unit_tol, with_sign):
array_a, array_b = arrays
equal = array_equal(array_a, array_b, unit_tol, with_sign=with_sign)
equal_flipped = array_equal(
array_b, array_a, unit_tol, with_sign=with_sign)
note(equal)
difference = equal.compute_difference()
a, b = (array_a, array_b) if with_sign else \
(np.abs(array_a), np.abs(array_b))
expect_equal = np.sum(np.abs(a - b) > unit_tol) / array_a.size < 1e-4
if expect_equal:
assert_array_equal(array_a, array_b, unit_tol, with_sign=with_sign)
assert equal
assert bool(equal) is True
assert equal == True # noqa: E712
assert True == equal # noqa: E712
assert equal != False # noqa: E712
assert False != equal # noqa: E712
assert equal_flipped
assert bool(equal_flipped) is True
assert equal_flipped == True # noqa: E712
assert True == equal_flipped # noqa: E712
assert equal_flipped != False # noqa: E712
assert False != equal_flipped # noqa: E712
else:
with pytest.raises(AssertionError):
assert_array_equal(array_a, array_b, unit_tol, with_sign=with_sign)
assert not equal
assert bool(equal) is not True
assert equal != True # noqa: E712
assert True != equal # noqa: E712
assert equal == False # noqa: E712
assert False == equal # noqa: E712
assert difference != 0

0 comments on commit 034a625

Please sign in to comment.