From fa17b1ddc0a87c2973b1ac816f6f5f06f19c40c7 Mon Sep 17 00:00:00 2001 From: gfyoung Date: Fri, 21 Apr 2017 22:53:07 -0400 Subject: [PATCH] MAINT: Refactor _AssertRaisesContextManager Rewrite _AssertRaisesContextManager with more documentation and remove vestigial assertRaises. Follow-up to gh-16089. --- pandas/tests/test_base.py | 5 ++- pandas/util/testing.py | 83 ++++++++++++++++++++++++++++----------- 2 files changed, 64 insertions(+), 24 deletions(-) diff --git a/pandas/tests/test_base.py b/pandas/tests/test_base.py index d91aab6bc3cebb..5814ae3494b448 100644 --- a/pandas/tests/test_base.py +++ b/pandas/tests/test_base.py @@ -45,8 +45,8 @@ class CheckImmutable(object): mutable_regex = re.compile('does not support mutable operations') def check_mutable_error(self, *args, **kwargs): - # pass whatever functions you normally would to assertRaises (after the - # Exception kind) + # Pass whatever function you normally would to assertRaisesRegexp + # (after the Exception kind). tm.assertRaisesRegexp(TypeError, self.mutable_regex, *args, **kwargs) def test_no_mutable_funcs(self): @@ -70,6 +70,7 @@ def delslice(): self.check_mutable_error(delslice) mutable_methods = getattr(self, "mutable_methods", []) + for meth in mutable_methods: self.check_mutable_error(getattr(self.container, meth)) diff --git a/pandas/util/testing.py b/pandas/util/testing.py index 7f62d319aa096c..27df39f43fc44b 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -2500,40 +2500,79 @@ def assertRaisesRegexp(_exception, _regexp, _callable=None, *args, **kwargs): class _AssertRaisesContextmanager(object): """ - Handles the behind the scenes work - for assertRaises and assertRaisesRegexp + Context manager behind assertRaisesRegexp. """ - def __init__(self, exception, regexp=None, *args, **kwargs): + def __init__(self, exception, regexp=None): + """ + Initialize an _AssertRaisesContextManager instance. + + Parameters + ---------- + exception : class + The expected Exception class. + regexp : str, default None + The regex to compare against the Exception message. + """ + self.exception = exception + if regexp is not None and not hasattr(regexp, "search"): regexp = re.compile(regexp, re.DOTALL) + self.regexp = regexp def __enter__(self): return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, trace_back): expected = self.exception - if not exc_type: - name = getattr(expected, "__name__", str(expected)) - raise AssertionError("{0} not raised.".format(name)) - if issubclass(exc_type, expected): - return self.handle_success(exc_type, exc_value, traceback) - return self.handle_failure(exc_type, exc_value, traceback) - - def handle_failure(*args, **kwargs): - # Failed, so allow Exception to bubble up - return False - def handle_success(self, exc_type, exc_value, traceback): - if self.regexp is not None: - val = str(exc_value) - if not self.regexp.search(val): - e = AssertionError('"%s" does not match "%s"' % - (self.regexp.pattern, str(val))) - raise_with_traceback(e, traceback) - return True + if not exc_type: + exp_name = getattr(expected, "__name__", str(expected)) + raise AssertionError("{0} not raised.".format(exp_name)) + + return self.exception_matches(exc_type, exc_value, trace_back) + + def exception_matches(self, exc_type, exc_value, trace_back): + """ + Check that the Exception raised matches the expected Exception + and expected error message regular expression. + + Parameters + ---------- + exc_type : class + The type of Exception raised. + exc_value : Exception + The instance of `exc_type` raised. + trace_back : stack trace object + The traceback object associated with `exc_value`. + + Returns + ------- + is_matched : bool + Whether or not the Exception raised matches the expected + Exception class and expected error message regular expression. + + Raises + ------ + AssertionError : The error message provided does not match + the expected error message regular expression. + """ + + if issubclass(exc_type, self.exception): + if self.regexp is not None: + val = str(exc_value) + + if not self.regexp.search(val): + e = AssertionError('"%s" does not match "%s"' % + (self.regexp.pattern, str(val))) + raise_with_traceback(e, trace_back) + + return True + else: + # Failed, so allow Exception to bubble up. + return False @contextmanager