From 56442f0e3cafc5cad7d91a87a68a2c72c5b24423 Mon Sep 17 00:00:00 2001 From: DinoBektesevic Date: Fri, 23 Aug 2024 17:29:12 -0700 Subject: [PATCH] Make showing progress bars a global setting. --- src/kbmod/__init__.py | 16 +++++++++++++- src/kbmod/reprojection.py | 46 +++++++++++++++++++++------------------ src/kbmod/work_unit.py | 22 +++++++++---------- 3 files changed, 51 insertions(+), 33 deletions(-) diff --git a/src/kbmod/__init__.py b/src/kbmod/__init__.py index 9fdfcbcea..e9dfb0319 100644 --- a/src/kbmod/__init__.py +++ b/src/kbmod/__init__.py @@ -13,7 +13,21 @@ # Import the rest of the package from kbmod.search import Logging -PROGRESS_BAR = bool(int(os.environ.get("KB_PROGRESS_BARS", 1))) +KB_INTERACTIVE_MODE = bool(int(os.environ.get("KB_INTERACTIVE_MODE", 1))) + +def is_interactive(): + """Returns the KBMOD use-mode. + + In interactive mode, displays progress bars and user-friendly + progress output. + + Returns + ------ + mode : `bool` + `True` when in interactive mode. + """ + global KBM_INTERACTIVE_MODE + return KB_INTERACTIVE_MODE # there are ways for this to go to a file, but is it worth it? # Then we have to roll a whole logging.config_from_shared_config thing diff --git a/src/kbmod/reprojection.py b/src/kbmod/reprojection.py index 6dcc6a0e1..4c2083cd7 100644 --- a/src/kbmod/reprojection.py +++ b/src/kbmod/reprojection.py @@ -5,11 +5,11 @@ from astropy.wcs import WCS from tqdm.asyncio import tqdm +from kbmod import is_interactive from kbmod.search import KB_NO_DATA, PSF, ImageStack, LayeredImage, RawImage from kbmod.work_unit import WorkUnit from kbmod.tqdm_utils import TQDMUtils from kbmod.wcs_utils import append_wcs_to_hdu_header -from kbmod import PROGRESS_BAR from astropy.io import fits import os from copy import copy @@ -81,7 +81,7 @@ def reproject_work_unit( write_output=False, directory=None, filename=None, - progress=PROGRESS_BAR, + show_progress=None, ): """Given a WorkUnit and a WCS, reproject all of the images in the ImageStack into a common WCS. @@ -110,14 +110,16 @@ def reproject_work_unit( The directory where output will be written if `write_output` is set to True. filename : `str` The base filename where output will be written if `write_output` is set to True. - progress : `bool` - Whether or not to enable the `tqdm` progress bar. + show_progress : `bool` or `None`, optional + If `None` use default settings, when a boolean forces the progress bar to be + displayed or hidden. Returns ---------- A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the case where `write_output` is set to True. """ + show_progress = is_interactive() if show_progress is None else show_progress if (work_unit.lazy or write_output) and (directory is None or filename is None): raise ValueError("can't write output to sharded fits without directory and filename provided.") if work_unit.lazy: @@ -128,7 +130,7 @@ def reproject_work_unit( max_parallel_processes=max_parallel_processes, directory=directory, filename=filename, - progress=progress, + show_progress=show_progress, ) if parallelize: return _reproject_work_unit_in_parallel( @@ -139,7 +141,7 @@ def reproject_work_unit( write_output=write_output, directory=directory, filename=filename, - progress=progress, + show_progress=show_progress, ) else: return _reproject_work_unit( @@ -149,7 +151,7 @@ def reproject_work_unit( write_output=write_output, directory=directory, filename=filename, - progress=progress, + show_progress=show_progress, ) @@ -160,7 +162,7 @@ def _reproject_work_unit( write_output=False, directory=None, filename=None, - progress=PROGRESS_BAR, + show_progress=False, ): """Given a WorkUnit and a WCS, reproject all of the images in the ImageStack into a common WCS. @@ -182,8 +184,8 @@ def _reproject_work_unit( The directory where output will be written if `write_output` is set to True. filename : `str` The base filename where output will be written if `write_output` is set to True. - disable_progress : `bool` - Whether or not to disable the `tqdm` progress bar. + disable_show_progress : `bool` + Whether or not to disable the `tqdm` show_progress bar. Returns ---------- A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the case @@ -197,7 +199,7 @@ def _reproject_work_unit( enumerate(zip(unique_obstimes, unique_obstime_indices)), bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT, desc="Reprojecting", - disable=not progress, + disable=not show_progress, ): time, indices = o_i science_add = np.zeros(common_wcs.array_shape, dtype=np.float32) @@ -311,7 +313,7 @@ def _reproject_work_unit_in_parallel( write_output=False, directory=None, filename=None, - progress=PROGRESS_BAR, + show_progress=False, ): """Given a WorkUnit and a WCS, reproject all of the images in the ImageStack into a common WCS. This function uses multiprocessing to reproject the images @@ -338,8 +340,8 @@ def _reproject_work_unit_in_parallel( The directory where output will be written if `write_output` is set to True. filename : `str` The base filename where output will be written if `write_output` is set to True. - progress : `bool` - Whether or not to enable the `tqdm` progress bar. + show_progress : `bool` + Whether or not to enable the `tqdm` show_progress bar. Returns ---------- @@ -399,14 +401,14 @@ def _reproject_work_unit_in_parallel( original_wcs=original_wcs, ) ) - # Need to consume the generator producted by tqdm to update the progress bar so we instantiate a list + # Need to consume the generator producted by tqdm to update the show_progress bar so we instantiate a list list( tqdm( concurrent.futures.as_completed(future_reprojections), total=len(future_reprojections), bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT, desc="Reprojecting", - disable=not progress, + disable=not show_progress, ) ) @@ -467,7 +469,7 @@ def reproject_lazy_work_unit( filename, frame="original", max_parallel_processes=MAX_PROCESSES, - progress=PROGRESS_BAR, + show_progress=None, ): """Given a WorkUnit and a WCS, reproject all of the images in the ImageStack into a common WCS. This function is used with lazily evaluated `WorkUnit`s and @@ -496,9 +498,11 @@ def reproject_lazy_work_unit( The maximum number of parallel processes to use when reprojecting. Default is 8. For more see `concurrent.futures.ProcessPoolExecutor` in the Python docs. - progress : `bool` - Whether or not to enable the `tqdm` progress bar. + show_progress : `bool` or `None`, optional + If `None` use default settings, when a boolean forces the progress bar to be + displayed or hidden. """ + show_progress = is_interactive() if show_progress is None else show_progress if not work_unit.lazy: raise ValueError("WorkUnit must be lazily loaded.") @@ -529,14 +533,14 @@ def reproject_lazy_work_unit( ) ) - # Need to consume the generator producted by tqdm to update the progress bar so we instantiate a list + # Need to consume the generator producted by tqdm to update the show_progress bar so we instantiate a list list( tqdm( concurrent.futures.as_completed(future_reprojections), total=len(future_reprojections), bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT, desc="Reprojecting", - disable=not progress, + disable=not show_progress, ) ) diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py index 5a980edcc..c2bfabb9f 100644 --- a/src/kbmod/work_unit.py +++ b/src/kbmod/work_unit.py @@ -1,19 +1,17 @@ -import math import os +import warnings +from pathlib import Path from astropy.io import fits -from astropy.table import Table from astropy.utils.exceptions import AstropyWarning -from astropy.wcs import WCS from astropy.wcs.utils import skycoord_to_pixel from astropy.time import Time from astropy.coordinates import SkyCoord, EarthLocation import numpy as np -from pathlib import Path -import warnings from yaml import dump, safe_load from tqdm import tqdm +from kbmod import is_interactive from kbmod.configuration import SearchConfiguration from kbmod.search import ImageStack, LayeredImage, PSF, RawImage, Logging from kbmod.wcs_utils import ( @@ -25,7 +23,6 @@ ) from kbmod.reprojection_utils import invert_correct_parallax from kbmod.tqdm_utils import TQDMUtils -from kbmod import PROGRESS_BAR logger = Logging.getLogger(__name__) @@ -226,7 +223,7 @@ def get_num_images(self): return len(self._per_image_indices) @classmethod - def from_fits(cls, filename, progress=PROGRESS_BAR): + def from_fits(cls, filename, show_progress=None): """Create a WorkUnit from a single FITS file. The FITS file will have at least the following extensions: @@ -242,14 +239,16 @@ def from_fits(cls, filename, progress=PROGRESS_BAR): ---------- filename : `str` The file to load. - progress : `bool` - Whether or not to enable the `tqdm` progress bar. + show_progress : `bool` or `None`, optional + If `None` use default settings, when a boolean forces the progress bar to be + displayed or hidden. Returns ------- result : `WorkUnit` The loaded WorkUnit. """ + show_progress = is_interactive() if show_progress is None else show_progress logger.info(f"Loading WorkUnit from FITS file {filename}.") if not Path(filename).is_file(): raise ValueError(f"WorkUnit file {filename} not found.") @@ -293,7 +292,7 @@ def from_fits(cls, filename, progress=PROGRESS_BAR): range(num_images), bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT, desc="Loading images", - disable=not progress, + disable=not show_progress, ): sci_hdu = hdul[f"SCI_{i}"] @@ -322,7 +321,7 @@ def from_fits(cls, filename, progress=PROGRESS_BAR): range(n_constituents), bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT, desc="Loading WCS", - disable=not progress, + disable=not show_progress, ): # Extract the per-image WCS if one exists. per_image_wcs.append(extract_wcs_from_hdu_header(hdul[f"WCS_{i}"].header)) @@ -942,6 +941,7 @@ def image_positions_to_original_icrs( con_image = self.constituent_images[j] con_wcs = self._per_image_wcs[j] height, width = con_wcs.array_shape + breakpoint() x, y = skycoord_to_pixel(coord, con_wcs) x, y = float(x), float(y) if output_format == "xy":