diff --git a/click_spinner/__init__.py b/click_spinner/__init__.py index e32616d..aeec089 100644 --- a/click_spinner/__init__.py +++ b/click_spinner/__init__.py @@ -7,17 +7,18 @@ class Spinner(object): spinner_cycle = itertools.cycle(['-', '/', '|', '\\']) - def __init__(self, beep=False, disable=False, force=False): + def __init__(self, beep=False, disable=False, force=False, stream=sys.stdout): self.disable = disable self.beep = beep self.force = force + self.stream = stream self.stop_running = None self.spin_thread = None def start(self): if self.disable: return - if sys.stdout.isatty() or self.force: + if self.stream.isatty() or self.force: self.stop_running = threading.Event() self.spin_thread = threading.Thread(target=self.init_spin) self.spin_thread.start() @@ -29,11 +30,11 @@ def stop(self): def init_spin(self): while not self.stop_running.is_set(): - sys.stdout.write(next(self.spinner_cycle)) - sys.stdout.flush() + self.stream.write(next(self.spinner_cycle)) + self.stream.flush() time.sleep(0.25) - sys.stdout.write('\b') - sys.stdout.flush() + self.stream.write('\b') + self.stream.flush() def __enter__(self): self.start() @@ -44,12 +45,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False self.stop() if self.beep: - sys.stdout.write('\7') - sys.stdout.flush() + self.stream.write('\7') + self.stream.flush() return False -def spinner(beep=False, disable=False, force=False): +def spinner(beep=False, disable=False, force=False, stream=sys.stdout): """This function creates a context manager that is used to display a spinner on stdout as long as the context has not exited. @@ -73,7 +74,7 @@ def spinner(beep=False, disable=False, force=False): do_something_else() """ - return Spinner(beep, disable, force) + return Spinner(beep, disable, force, stream) from ._version import get_versions diff --git a/tests/test_spinner.py b/tests/test_spinner.py index a04f25a..3f6d1ed 100644 --- a/tests/test_spinner.py +++ b/tests/test_spinner.py @@ -63,13 +63,10 @@ def test_spinner_redirect_force(): @click.command() def cli(): stdout_io = StringIO() - saved_stdout = sys.stdout - sys.stdout = stdout_io # redirect stdout to a string buffer - spinner = click_spinner.Spinner(force=True) + spinner = click_spinner.Spinner(force=True, stream=stdout_io) spinner.start() time.sleep(1) # allow time for a few spins spinner.stop() - sys.stdout = saved_stdout stdout_io.flush() stdout_str = stdout_io.getvalue() assert len(stdout_str) > 0