Skip to content

Commit

Permalink
Make showing progress bars a global setting.
Browse files Browse the repository at this point in the history
  • Loading branch information
DinoBektesevic committed Aug 24, 2024
1 parent 4d68bd8 commit 56442f0
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 33 deletions.
16 changes: 15 additions & 1 deletion src/kbmod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,21 @@
# Import the rest of the package
from kbmod.search import Logging

PROGRESS_BAR = bool(int(os.environ.get("KB_PROGRESS_BARS", 1)))
KB_INTERACTIVE_MODE = bool(int(os.environ.get("KB_INTERACTIVE_MODE", 1)))

def is_interactive():
"""Returns the KBMOD use-mode.
In interactive mode, displays progress bars and user-friendly
progress output.
Returns
------
mode : `bool`
`True` when in interactive mode.
"""
global KBM_INTERACTIVE_MODE
return KB_INTERACTIVE_MODE

# there are ways for this to go to a file, but is it worth it?
# Then we have to roll a whole logging.config_from_shared_config thing
Expand Down
46 changes: 25 additions & 21 deletions src/kbmod/reprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from astropy.wcs import WCS
from tqdm.asyncio import tqdm

from kbmod import is_interactive
from kbmod.search import KB_NO_DATA, PSF, ImageStack, LayeredImage, RawImage
from kbmod.work_unit import WorkUnit
from kbmod.tqdm_utils import TQDMUtils
from kbmod.wcs_utils import append_wcs_to_hdu_header
from kbmod import PROGRESS_BAR
from astropy.io import fits
import os
from copy import copy
Expand Down Expand Up @@ -81,7 +81,7 @@ def reproject_work_unit(
write_output=False,
directory=None,
filename=None,
progress=PROGRESS_BAR,
show_progress=None,
):
"""Given a WorkUnit and a WCS, reproject all of the images in the ImageStack
into a common WCS.
Expand Down Expand Up @@ -110,14 +110,16 @@ def reproject_work_unit(
The directory where output will be written if `write_output` is set to True.
filename : `str`
The base filename where output will be written if `write_output` is set to True.
progress : `bool`
Whether or not to enable the `tqdm` progress bar.
show_progress : `bool` or `None`, optional
If `None` use default settings, when a boolean forces the progress bar to be
displayed or hidden.
Returns
----------
A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the case
where `write_output` is set to True.
"""
show_progress = is_interactive() if show_progress is None else show_progress
if (work_unit.lazy or write_output) and (directory is None or filename is None):
raise ValueError("can't write output to sharded fits without directory and filename provided.")
if work_unit.lazy:
Expand All @@ -128,7 +130,7 @@ def reproject_work_unit(
max_parallel_processes=max_parallel_processes,
directory=directory,
filename=filename,
progress=progress,
show_progress=show_progress,
)
if parallelize:
return _reproject_work_unit_in_parallel(
Expand All @@ -139,7 +141,7 @@ def reproject_work_unit(
write_output=write_output,
directory=directory,
filename=filename,
progress=progress,
show_progress=show_progress,
)
else:
return _reproject_work_unit(
Expand All @@ -149,7 +151,7 @@ def reproject_work_unit(
write_output=write_output,
directory=directory,
filename=filename,
progress=progress,
show_progress=show_progress,
)


Expand All @@ -160,7 +162,7 @@ def _reproject_work_unit(
write_output=False,
directory=None,
filename=None,
progress=PROGRESS_BAR,
show_progress=False,
):
"""Given a WorkUnit and a WCS, reproject all of the images in the ImageStack
into a common WCS.
Expand All @@ -182,8 +184,8 @@ def _reproject_work_unit(
The directory where output will be written if `write_output` is set to True.
filename : `str`
The base filename where output will be written if `write_output` is set to True.
disable_progress : `bool`
Whether or not to disable the `tqdm` progress bar.
disable_show_progress : `bool`
Whether or not to disable the `tqdm` show_progress bar.
Returns
----------
A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the case
Expand All @@ -197,7 +199,7 @@ def _reproject_work_unit(
enumerate(zip(unique_obstimes, unique_obstime_indices)),
bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT,
desc="Reprojecting",
disable=not progress,
disable=not show_progress,
):
time, indices = o_i
science_add = np.zeros(common_wcs.array_shape, dtype=np.float32)
Expand Down Expand Up @@ -311,7 +313,7 @@ def _reproject_work_unit_in_parallel(
write_output=False,
directory=None,
filename=None,
progress=PROGRESS_BAR,
show_progress=False,
):
"""Given a WorkUnit and a WCS, reproject all of the images in the ImageStack
into a common WCS. This function uses multiprocessing to reproject the images
Expand All @@ -338,8 +340,8 @@ def _reproject_work_unit_in_parallel(
The directory where output will be written if `write_output` is set to True.
filename : `str`
The base filename where output will be written if `write_output` is set to True.
progress : `bool`
Whether or not to enable the `tqdm` progress bar.
show_progress : `bool`
Whether or not to enable the `tqdm` show_progress bar.
Returns
----------
Expand Down Expand Up @@ -399,14 +401,14 @@ def _reproject_work_unit_in_parallel(
original_wcs=original_wcs,
)
)
# Need to consume the generator producted by tqdm to update the progress bar so we instantiate a list
# Need to consume the generator producted by tqdm to update the show_progress bar so we instantiate a list
list(
tqdm(
concurrent.futures.as_completed(future_reprojections),
total=len(future_reprojections),
bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT,
desc="Reprojecting",
disable=not progress,
disable=not show_progress,
)
)

Expand Down Expand Up @@ -467,7 +469,7 @@ def reproject_lazy_work_unit(
filename,
frame="original",
max_parallel_processes=MAX_PROCESSES,
progress=PROGRESS_BAR,
show_progress=None,
):
"""Given a WorkUnit and a WCS, reproject all of the images in the ImageStack
into a common WCS. This function is used with lazily evaluated `WorkUnit`s and
Expand Down Expand Up @@ -496,9 +498,11 @@ def reproject_lazy_work_unit(
The maximum number of parallel processes to use when reprojecting.
Default is 8. For more see `concurrent.futures.ProcessPoolExecutor` in
the Python docs.
progress : `bool`
Whether or not to enable the `tqdm` progress bar.
show_progress : `bool` or `None`, optional
If `None` use default settings, when a boolean forces the progress bar to be
displayed or hidden.
"""
show_progress = is_interactive() if show_progress is None else show_progress
if not work_unit.lazy:
raise ValueError("WorkUnit must be lazily loaded.")

Expand Down Expand Up @@ -529,14 +533,14 @@ def reproject_lazy_work_unit(
)
)

# Need to consume the generator producted by tqdm to update the progress bar so we instantiate a list
# Need to consume the generator producted by tqdm to update the show_progress bar so we instantiate a list
list(
tqdm(
concurrent.futures.as_completed(future_reprojections),
total=len(future_reprojections),
bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT,
desc="Reprojecting",
disable=not progress,
disable=not show_progress,
)
)

Expand Down
22 changes: 11 additions & 11 deletions src/kbmod/work_unit.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import math
import os
import warnings
from pathlib import Path

from astropy.io import fits
from astropy.table import Table
from astropy.utils.exceptions import AstropyWarning
from astropy.wcs import WCS
from astropy.wcs.utils import skycoord_to_pixel
from astropy.time import Time
from astropy.coordinates import SkyCoord, EarthLocation
import numpy as np
from pathlib import Path
import warnings
from yaml import dump, safe_load
from tqdm import tqdm

from kbmod import is_interactive
from kbmod.configuration import SearchConfiguration
from kbmod.search import ImageStack, LayeredImage, PSF, RawImage, Logging
from kbmod.wcs_utils import (
Expand All @@ -25,7 +23,6 @@
)
from kbmod.reprojection_utils import invert_correct_parallax
from kbmod.tqdm_utils import TQDMUtils
from kbmod import PROGRESS_BAR


logger = Logging.getLogger(__name__)
Expand Down Expand Up @@ -226,7 +223,7 @@ def get_num_images(self):
return len(self._per_image_indices)

@classmethod
def from_fits(cls, filename, progress=PROGRESS_BAR):
def from_fits(cls, filename, show_progress=None):
"""Create a WorkUnit from a single FITS file.
The FITS file will have at least the following extensions:
Expand All @@ -242,14 +239,16 @@ def from_fits(cls, filename, progress=PROGRESS_BAR):
----------
filename : `str`
The file to load.
progress : `bool`
Whether or not to enable the `tqdm` progress bar.
show_progress : `bool` or `None`, optional
If `None` use default settings, when a boolean forces the progress bar to be
displayed or hidden.
Returns
-------
result : `WorkUnit`
The loaded WorkUnit.
"""
show_progress = is_interactive() if show_progress is None else show_progress
logger.info(f"Loading WorkUnit from FITS file {filename}.")
if not Path(filename).is_file():
raise ValueError(f"WorkUnit file {filename} not found.")
Expand Down Expand Up @@ -293,7 +292,7 @@ def from_fits(cls, filename, progress=PROGRESS_BAR):
range(num_images),
bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT,
desc="Loading images",
disable=not progress,
disable=not show_progress,
):
sci_hdu = hdul[f"SCI_{i}"]

Expand Down Expand Up @@ -322,7 +321,7 @@ def from_fits(cls, filename, progress=PROGRESS_BAR):
range(n_constituents),
bar_format=TQDMUtils.DEFAULT_TQDM_BAR_FORMAT,
desc="Loading WCS",
disable=not progress,
disable=not show_progress,
):
# Extract the per-image WCS if one exists.
per_image_wcs.append(extract_wcs_from_hdu_header(hdul[f"WCS_{i}"].header))
Expand Down Expand Up @@ -942,6 +941,7 @@ def image_positions_to_original_icrs(
con_image = self.constituent_images[j]
con_wcs = self._per_image_wcs[j]
height, width = con_wcs.array_shape
breakpoint()
x, y = skycoord_to_pixel(coord, con_wcs)
x, y = float(x), float(y)
if output_format == "xy":
Expand Down

0 comments on commit 56442f0

Please sign in to comment.