From dd65871f3f25a523a47a74f2f5306c57048592b0 Mon Sep 17 00:00:00 2001 From: "yizhi.chen" Date: Wed, 16 Dec 2020 11:11:07 +0800 Subject: [PATCH] Support direct calls and its unittest. Support python3 for unittest. --- contexttimer/__init__.py | 47 +++++++++++++++++++--------------------- tests/test_timer.py | 22 ++++++++++++++++--- 2 files changed, 41 insertions(+), 28 deletions(-) diff --git a/contexttimer/__init__.py b/contexttimer/__init__.py index d7190bc..973d96c 100644 --- a/contexttimer/__init__.py +++ b/contexttimer/__init__.py @@ -105,9 +105,9 @@ def elapsed(self): return (self.end - self.start) * self.factor -def timer(logger=None, level=logging.INFO, +def timer(f=None, logger=None, level=logging.INFO, fmt="function %(function_name)s execution time: %(execution_time).3f", - *func_or_func_args, **timer_kwargs): + **timer_kwargs): """ Function decorator displaying the function execution time All kwargs are the arguments taken by the Timer class constructor. @@ -115,30 +115,27 @@ def timer(logger=None, level=logging.INFO, """ # store Timer kwargs in local variable so the namespace isn't polluted # by different level args and kwargs + if f is None: + return functools.partial(timer, logger=logger, level=level, fmt=fmt, **timer_kwargs) + + @functools.wraps(f) + def wrapped(*args, **kwargs): + with Timer(**timer_kwargs) as t: + out = f(*args, **kwargs) + context = { + 'function_name': f.__name__, + 'execution_time': t.elapsed, + } + if logger: + logger.log( + level, + fmt % context, + extra=context) + else: + print(fmt % context) + return out + return wrapped - def wrapped_f(f): - @functools.wraps(f) - def wrapped(*args, **kwargs): - with Timer(**timer_kwargs) as t: - out = f(*args, **kwargs) - context = { - 'function_name': f.__name__, - 'execution_time': t.elapsed, - } - if logger: - logger.log( - level, - fmt % context, - extra=context) - else: - print(fmt % context) - return out - return wrapped - if (len(func_or_func_args) == 1 - and isinstance(func_or_func_args[0], collections.Callable)): - return wrapped_f(func_or_func_args[0]) - else: - return wrapped_f if __name__ == "__main__": import logging diff --git a/tests/test_timer.py b/tests/test_timer.py index 1ea1334..dde5033 100644 --- a/tests/test_timer.py +++ b/tests/test_timer.py @@ -1,13 +1,18 @@ import contexttimer import unittest import mock -from cStringIO import StringIO + +import sys +if sys.version_info.major > 2: + from io import StringIO +else: + from cStringIO import StringIO class ContextTimerTest(unittest.TestCase): def test_timer_print(self): def print_reversed(string): - print " ".join(reversed(string.split())) + print(" ".join(reversed(string.split()))) tests = [ # (kwargs, expected_regex) @@ -29,12 +34,23 @@ def print_reversed(string): self.assertRegexpMatches(output.getvalue(), expected) def test_decorator_print(self): + # Test direct call. + expected = r"function foo execution time: [0-9.]+" + output = StringIO() + with mock.patch('sys.stdout', new=output): + @contexttimer.timer + def foo(): + pass + foo() + self.assertIsNotNone(output) + self.assertRegexpMatches(output.getvalue(), expected) + + # Test calls with 0 or more args. tests = [ ({}, r"function foo execution time: [0-9.]+"), ({'fmt': '%(execution_time)s seconds later...'}, r"[0-9.]+ seconds later..."), ] - for kwargs, expected in tests: output = StringIO() with mock.patch('sys.stdout', new=output):