-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2333 from casperdcl/tqdm
change progress bar backend to tqdm
- Loading branch information
Showing
21 changed files
with
369 additions
and
562 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ innosetup/config.ini | |
*.exe | ||
|
||
.coverage | ||
.coverage.* | ||
|
||
*.swp | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Paweł Redzyński <[email protected]> | ||
Dmitry Petrov <[email protected]> | ||
Earl Hathaway <[email protected]> | ||
Nabanita Dash <[email protected]> | ||
Kurian Benoy <[email protected]> | ||
Sritanu Chakraborty <[email protected]> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,154 +1,94 @@ | ||
"""Manages progress bars for dvc repo.""" | ||
|
||
from __future__ import print_function | ||
from __future__ import unicode_literals | ||
|
||
from dvc.utils.compat import str | ||
|
||
import sys | ||
import threading | ||
import logging | ||
from tqdm import tqdm | ||
from copy import deepcopy | ||
from concurrent.futures import ThreadPoolExecutor | ||
|
||
CLEARLINE_PATTERN = "\r\x1b[K" | ||
|
||
|
||
class Progress(object): | ||
class TqdmThreadPoolExecutor(ThreadPoolExecutor): | ||
""" | ||
Simple multi-target progress bar. | ||
Ensure worker progressbars are cleared away properly. | ||
""" | ||
|
||
def __init__(self): | ||
self._n_total = 0 | ||
self._n_finished = 0 | ||
self._lock = threading.Lock() | ||
self._line = None | ||
|
||
def set_n_total(self, total): | ||
"""Sets total number of targets.""" | ||
self._n_total = total | ||
self._n_finished = 0 | ||
|
||
@property | ||
def is_finished(self): | ||
"""Returns if all targets have finished.""" | ||
return self._n_total == self._n_finished | ||
|
||
def clearln(self): | ||
self._print(CLEARLINE_PATTERN, end="") | ||
|
||
def _writeln(self, line): | ||
self.clearln() | ||
self._print(line, end="") | ||
sys.stdout.flush() | ||
|
||
def reset(self): | ||
with self._lock: | ||
self._n_total = 0 | ||
self._n_finished = 0 | ||
self._line = None | ||
|
||
def refresh(self, line=None): | ||
"""Refreshes progress bar.""" | ||
# Just go away if it is locked. Will update next time | ||
if not self._lock.acquire(False): | ||
return | ||
|
||
if line is None: | ||
line = self._line | ||
|
||
if sys.stdout.isatty() and line is not None: | ||
self._writeln(line) | ||
self._line = line | ||
|
||
self._lock.release() | ||
|
||
def update_target(self, name, current, total): | ||
"""Updates progress bar for a specified target.""" | ||
self.refresh(self._bar(name, current, total)) | ||
|
||
def finish_target(self, name): | ||
"""Finishes progress bar for a specified target.""" | ||
# We have to write a msg about finished target | ||
with self._lock: | ||
pbar = self._bar(name, 100, 100) | ||
|
||
if sys.stdout.isatty(): | ||
self.clearln() | ||
|
||
self._print(pbar) | ||
|
||
self._n_finished += 1 | ||
self._line = None | ||
|
||
def _bar(self, target_name, current, total): | ||
def __enter__(self): | ||
""" | ||
Make a progress bar out of info, which looks like: | ||
(1/2): [########################################] 100% master.zip | ||
Creates a blank initial dummy progress bar if needed so that workers | ||
are forced to create "nested" bars. | ||
""" | ||
bar_len = 30 | ||
|
||
if total is None: | ||
state = 0 | ||
percent = "?% " | ||
else: | ||
total = int(total) | ||
state = int((100 * current) / total) if current < total else 100 | ||
percent = str(state) + "% " | ||
|
||
if self._n_total > 1: | ||
num = "({}/{}): ".format(self._n_finished + 1, self._n_total) | ||
else: | ||
num = "" | ||
blank_bar = Tqdm(bar_format="Multi-Threaded:", leave=False) | ||
if blank_bar.pos > 0: | ||
# already nested - don't need a placeholder bar | ||
blank_bar.close() | ||
self.bar = blank_bar | ||
return super(TqdmThreadPoolExecutor, self).__enter__() | ||
|
||
n_sh = int((state * bar_len) / 100) | ||
n_sp = bar_len - n_sh | ||
pbar = "[" + "#" * n_sh + " " * n_sp + "] " | ||
def __exit__(self, *a, **k): | ||
super(TqdmThreadPoolExecutor, self).__exit__(*a, **k) | ||
self.bar.close() | ||
|
||
return num + pbar + percent + target_name | ||
|
||
@staticmethod | ||
def _print(*args, **kwargs): | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
if logger.getEffectiveLevel() == logging.CRITICAL: | ||
return | ||
|
||
print(*args, **kwargs) | ||
|
||
def __enter__(self): | ||
self._lock.acquire(True) | ||
if self._line is not None: | ||
self.clearln() | ||
|
||
def __exit__(self, typ, value, tbck): | ||
if self._line is not None: | ||
self.refresh() | ||
self._lock.release() | ||
|
||
def __call__(self, seq, name="", total=None): | ||
if total is None: | ||
total = len(seq) | ||
|
||
self.update_target(name, 0, total) | ||
for done, item in enumerate(seq, start=1): | ||
yield item | ||
self.update_target(name, done, total) | ||
self.finish_target(name) | ||
|
||
|
||
class ProgressCallback(object): | ||
def __init__(self, total): | ||
self.total = total | ||
self.current = 0 | ||
progress.reset() | ||
|
||
def update(self, name, progress_to_add=1): | ||
self.current += progress_to_add | ||
progress.update_target(name, self.current, self.total) | ||
|
||
def finish(self, name): | ||
progress.finish_target(name) | ||
|
||
class Tqdm(tqdm): | ||
""" | ||
maximum-compatibility tqdm-based progressbars | ||
""" | ||
|
||
progress = Progress() # pylint: disable=invalid-name | ||
def __init__( | ||
self, | ||
iterable=None, | ||
disable=None, | ||
bytes=False, # pylint: disable=W0622 | ||
desc_truncate=None, | ||
leave=None, | ||
**kwargs | ||
): | ||
""" | ||
bytes : shortcut for | ||
`unit='B', unit_scale=True, unit_divisor=1024, miniters=1` | ||
desc_truncate : like `desc` but will truncate to 10 chars | ||
kwargs : anything accepted by `tqdm.tqdm()` | ||
""" | ||
kwargs = deepcopy(kwargs) | ||
if bytes: | ||
for k, v in dict( | ||
unit="B", unit_scale=True, unit_divisor=1024, miniters=1 | ||
).items(): | ||
kwargs.setdefault(k, v) | ||
if desc_truncate is not None: | ||
kwargs.setdefault("desc", self.truncate(desc_truncate)) | ||
if disable is None: | ||
disable = ( | ||
logging.getLogger(__name__).getEffectiveLevel() | ||
>= logging.CRITICAL | ||
) | ||
super(Tqdm, self).__init__( | ||
iterable=iterable, disable=disable, leave=leave, **kwargs | ||
) | ||
|
||
def update_desc(self, desc, n=1, truncate=True): | ||
""" | ||
Calls `set_description(truncate(desc))` and `update(n)` | ||
""" | ||
self.set_description( | ||
self.truncate(desc) if truncate else desc, refresh=False | ||
) | ||
self.update(n) | ||
|
||
def update_to(self, current, total=None): | ||
if total: | ||
self.total = total # pylint: disable=W0613,W0201 | ||
self.update(current - self.n) | ||
|
||
@classmethod | ||
def truncate(cls, s, max_len=25, end=True, fill="..."): | ||
""" | ||
Guarantee len(output) < max_lenself. | ||
>>> truncate("hello", 4) | ||
'...o' | ||
""" | ||
if len(s) <= max_len: | ||
return s | ||
if len(fill) > max_len: | ||
return fill[-max_len:] if end else fill[:max_len] | ||
i = max_len - len(fill) | ||
return (fill + s[-i:]) if end else (s[:i] + fill) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.