Merge pull request #524 from dirac-institute/proposals/logging
Add Logging module to the code.
DinoBektesevic authored Mar 18, 2024
2 parents f3882ff + d51c8c0 commit 4fec2d1
Showing 17 changed files with 417 additions and 100 deletions.
File renamed without changes.
59 changes: 54 additions & 5 deletions src/kbmod/__init__.py
@@ -1,15 +1,18 @@
import warnings

try:
from ._version import version as __version__
from ._version import version as __version__ # noqa: F401
except ImportError:
warnings.warn("Unable to determine the package version. " "This is likely a broken installation.")

# This is needed for compatibility with some older compilers: erfinv needs to be
# imported before other packages, though I'm not sure why.
from scipy.special import erfinv
import os
import time
import logging as _logging
from logging import config as _config

from . import (
# Import the rest of the package
from kbmod.search import Logging
from . import ( # noqa: F401
analysis,
data_interface,
file_utils,
@@ -22,3 +25,49 @@
from .search import PSF, RawImage, LayeredImage, ImageStack, StackSearch
from .standardizers import Standardizer, StandardizerConfig
from .image_collection import ImageCollection


# there are ways for this to go to a file, but is it worth it?
# Then we have to roll a whole logging.config_from_shared_config thing
_SHARED_LOGGING_CONFIG = {
"level": os.environ.get("KB_LOG_LEVEL", "WARNING"),
"format": "[%(asctime)s %(levelname)s %(name)s] %(message)s",
"datefmt": "%Y-%m-%dT%H:%M:%SZ",
"converter": "gmtime",
}

# Declare our own root logger, so that we don't start printing DEBUG
# messages from every package we import
__PY_LOGGING_CONFIG = {
"version": 1.0,
"formatters": {
"standard": {
"format": _SHARED_LOGGING_CONFIG["format"],
},
},
"handlers": {
"default": {
"level": _SHARED_LOGGING_CONFIG["level"],
"formatter": "standard",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
}
},
"loggers": {
"kbmod": {
"handlers": ["default"],
"level": _SHARED_LOGGING_CONFIG["level"],
}
},
}

# The timezone converter cannot be configured via the config submodule,
# only directly; it must be set after loading the dictConfig.
_config.dictConfig(__PY_LOGGING_CONFIG)
if _SHARED_LOGGING_CONFIG["converter"] == "gmtime":
_logging.Formatter.converter = time.gmtime
else:
_logging.Formatter.converter = time.localtime

# Configure the CPP logging wrapper with the same setup
Logging().setConfig(_SHARED_LOGGING_CONFIG)
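
The block above gives every logger under the "kbmod" hierarchy a stderr handler, a shared message format, and a level taken from the KB_LOG_LEVEL environment variable, then pushes the same settings into the C++ side through the Logging wrapper. A minimal usage sketch follows; the module name, messages, and the DEBUG value are illustrative, not part of this diff:

# Sketch: obtaining a configured logger from anywhere in the package.
# KB_LOG_LEVEL is read once, when kbmod/__init__.py runs, so it must be
# set before the first kbmod import; "DEBUG" here is just an example.
import os
os.environ.setdefault("KB_LOG_LEVEL", "DEBUG")

from kbmod.search import Logging

logger = Logging.getLogger("kbmod.example")  # hypothetical module name
logger.debug("emitted only when the configured level is DEBUG or lower")
logger.warning("emitted at the default WARNING level as well")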
7 changes: 5 additions & 2 deletions src/kbmod/analysis/create_stamps.py
@@ -6,9 +6,13 @@
import numpy as np
from astropy.io import fits

from kbmod.search import Logging
from kbmod.file_utils import *


logger = Logging.getLogger(__name__)


class CreateStamps(object):
def __init__(self):
return
@@ -117,8 +121,7 @@ def max_value_stamp_filter(self, stamps, center_thresh, verbose=True):
An np array of stamp indices to keep.
"""
keep_stamps = np.where(np.max(stamps, axis=1) > center_thresh)[0]
if verbose:
print("Center filtering keeps %i out of %i stamps." % (len(keep_stamps), len(stamps)))
logger.info(f"Center filtering keeps {len(keep_stamps)} out of {len(stamps)} stamps.")
return keep_stamps

def load_results(self, res_filename):
9 changes: 6 additions & 3 deletions src/kbmod/configuration.py
@@ -6,6 +6,10 @@
from pathlib import Path
import yaml
from yaml import dump, safe_load
from kbmod.search import Logging


logger = Logging.getLogger(__name__)


class SearchConfiguration:
@@ -120,8 +124,7 @@ def set(self, param, value, strict=True):
if param not in self._params:
if strict:
raise KeyError(f"Invalid parameter: {param}")
else:
print(f"Ignoring invalid parameter: {param}")
logger.warning(f"Ignoring invalid parameter: {param}")
else:
self._params[param] = value

@@ -281,7 +284,7 @@ def to_file(self, filename, overwrite=False):
Indicates whether to overwrite an existing file.
"""
if Path(filename).is_file() and not overwrite:
print(f"Warning: Configuration file {filename} already exists.")
logger.warning(f"Configuration file {filename} already exists.")
return

with open(filename, "w") as file:
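
With this change, an unknown parameter either raises or is logged, depending on strict, and to_file warns through the module logger instead of printing. A short behavioral sketch, assuming SearchConfiguration can be default-constructed (the parameter name below is made up):

from kbmod.configuration import SearchConfiguration

config = SearchConfiguration()

# Logged via the "kbmod.configuration" logger, e.g.:
# [2024-03-18T12:00:00Z WARNING kbmod.configuration] Ignoring invalid parameter: bogus_param
config.set("bogus_param", 42, strict=False)

# Raises KeyError: "Invalid parameter: bogus_param"
config.set("bogus_param", 42, strict=True)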
40 changes: 16 additions & 24 deletions src/kbmod/data_interface.py
@@ -7,16 +7,14 @@

from kbmod.configuration import SearchConfiguration
from kbmod.file_utils import *
from kbmod.search import (
ImageStack,
LayeredImage,
PSF,
RawImage,
)
from kbmod.search import ImageStack, LayeredImage, PSF, RawImage, Logging
from kbmod.wcs_utils import append_wcs_to_hdu_header
from kbmod.work_unit import WorkUnit, raw_image_to_hdu


logger = Logging.getLogger(__name__)


def load_deccam_layered_image(filename, psf):
"""Load a layered image from the legacy deccam format.
@@ -110,6 +108,9 @@ def save_deccam_layered_image(img, filename, wcs=None, overwrite=True):
hdul.writeto(filename, overwrite=overwrite)


def load_input_from_individual_files(
im_filepath,
time_file,
@@ -146,21 +147,17 @@ def load_input_from_individual_files(
visit_times : `list`
A list of MJD times.
"""
print("---------------------------------------")
print("Loading Images")
print("---------------------------------------")
logger.info("Loading Images")

# Load a mapping from visit numbers to the visit times. This dictionary stays
# empty if no time file is specified.
image_time_dict = FileUtils.load_time_dictionary(time_file)
if verbose:
print(f"Loaded {len(image_time_dict)} time stamps.")
logger.info(f"Loaded {len(image_time_dict)} time stamps.")

# Load a mapping from visit numbers to PSFs. This dictionary stays
# empty if no PSF file is specified.
image_psf_dict = FileUtils.load_psf_dictionary(psf_file)
if verbose:
print(f"Loaded {len(image_psf_dict)} image PSFs stamps.")
logger.info(f"Loaded {len(image_psf_dict)} image PSFs stamps.")

# Retrieve the list of visits (file names) in the data directory.
patch_visits = sorted(os.listdir(im_filepath))
@@ -172,8 +169,7 @@
for visit_file in np.sort(patch_visits):
# Skip non-fits files.
if not ".fits" in visit_file:
if verbose:
print(f"Skipping non-FITS file {visit_file}")
logger.info(f"Skipping non-FITS file {visit_file}")
continue

# Compute the full file path for loading.
@@ -194,8 +190,7 @@

# Skip files without a valid visit ID.
if visit_id is None:
if verbose:
print(f"WARNING: Unable to extract visit ID for {visit_file}.")
logger.warning(f"WARNING: Unable to extract visit ID for {visit_file}.")
continue

# Check if the image has a specific PSF.
@@ -204,8 +199,7 @@
psf = PSF(image_psf_dict[visit_id])

# Load the image file and set its time.
if verbose:
print(f"Loading file: {full_file_path}")
logger.info(f"Loading file: {full_file_path}")
img = load_deccam_layered_image(full_file_path, psf)
time_stamp = img.get_obstime()

@@ -215,22 +209,20 @@
img.set_obstime(time_stamp)

if time_stamp <= 0.0:
if verbose:
print(f"WARNING: No valid timestamp provided for {visit_file}.")
logger.warning(f"WARNING: No valid timestamp provided for {visit_file}.")
continue

# Check if we should filter the record based on the time bounds.
if mjd_lims is not None and (time_stamp < mjd_lims[0] or time_stamp > mjd_lims[1]):
if verbose:
print(f"Pruning file {visit_file} by timestamp={time_stamp}.")
logger.info(f"Pruning file {visit_file} by timestamp={time_stamp}.")
continue

# Save image, time, and WCS information.
visit_times.append(time_stamp)
images.append(img)
wcs_list.append(curr_wcs)

print(f"Loaded {len(images)} images")
logger.info(f"Loaded {len(images)} images")
stack = ImageStack(images)

return (stack, wcs_list, visit_times)
3 changes: 2 additions & 1 deletion src/kbmod/fake_data/fake_data_creator.py
@@ -17,6 +17,7 @@
from kbmod.data_interface import save_deccam_layered_image
from kbmod.file_utils import *
from kbmod.search import *
from kbmod.search import Logging
from kbmod.wcs_utils import append_wcs_to_hdu_header
from kbmod.work_unit import WorkUnit

logger = Logging.getLogger(__name__)  # module-level logger; used by save_fake_data_to_dir below

@@ -279,7 +280,7 @@ def save_fake_data_to_dir(self, data_dir):
# Make the subdirectory if needed.
dir_path = Path(data_dir)
if not dir_path.is_dir():
print("Directory '%s' does not exist. Creating." % data_dir)
logger.info(f"Directory {data_dir} does not exist. Creating.")
os.mkdir(data_dir)

# Save each of the image files.
27 changes: 13 additions & 14 deletions src/kbmod/filters/stamp_filters.py
@@ -17,9 +17,13 @@
StampCreator,
StampParameters,
StampType,
Logging,
)


logger = Logging.getLogger(__name__)


class BaseStampFilter(abc.ABC):
"""The base class for the various stamp filters.
@@ -327,16 +331,12 @@ def get_coadds_and_filter(result_list, im_stack, stamp_params, chunk_size=100000
if type(stamp_params) is SearchConfiguration:
stamp_params = extract_search_parameters_from_config(stamp_params)

if debug:
print("---------------------------------------")
print("Applying Stamp Filtering")
print("---------------------------------------")
if result_list.num_results() <= 0:
print("Skipping. Nothing to filter.")
else:
print(f"Stamp filtering {result_list.num_results()} results.")
print(stamp_params)
print(f"Using chunksize = {chunk_size}")
if result_list.num_results() <= 0:
logger.debug("Stamp Filtering : skipping, othing to filter.")
else:
logger.debug(f"Stamp filtering {result_list.num_results()} results.")
logger.debug(f"Using filtering params: {stamp_params}")
logger.debug(f"Using chunksize = {chunk_size}")

# Run the stamp creation and filtering in batches of chunk_size.
start_time = time.time()
@@ -382,10 +382,9 @@ def get_coadds_and_filter(result_list, im_stack, stamp_params, chunk_size=100000

# Do the actual filtering of results
result_list.filter_results(all_valid_inds)
if debug:
print("Keeping %i results" % result_list.num_results(), flush=True)
time_elapsed = time.time() - start_time
print("{:.2f}s elapsed".format(time_elapsed))

logger.debug(f"Keeping {result_list.num_results()} results")
logger.debug("{:.2f}s elapsed".format(time.time() - start_time))


def append_all_stamps(result_list, im_stack, stamp_radius):
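
The removed debug flag is superseded by the logger level: the logger.debug calls above only produce output when the "kbmod" logger sits at DEBUG, so callers toggle verbosity through configuration rather than a parameter. A sketch of the pattern (the logger name, chunk size, and timing message are illustrative):

import time

from kbmod.search import Logging

logger = Logging.getLogger("kbmod.filters.stamp_filters")

start_time = time.time()
# ... create coadded stamps and filter results in chunks ...
logger.debug("Using chunksize = 1000000")
logger.debug(f"{time.time() - start_time:.2f}s elapsed")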