Skip to content

Commit

Permalink
Merge pull request #349 from dirac-institute/remove_img_info
Browse files Browse the repository at this point in the history
Remove ImageInfo
  • Loading branch information
jeremykubica authored Sep 27, 2023
2 parents 95701ed + 656916e commit af3c35a
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 713 deletions.
1 change: 0 additions & 1 deletion src/kbmod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
analysis_utils,
file_utils,
filters,
image_info,
jointfit_functions,
result_list,
run_search,
Expand Down
75 changes: 41 additions & 34 deletions src/kbmod/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import os
import time

from astropy.io import fits
from astropy.wcs import WCS
import numpy as np
from scipy.special import erfinv # import mpmath

Expand All @@ -10,7 +12,6 @@
from .file_utils import *
from .filters.clustering_filters import DBSCANFilter
from .filters.stats_filters import *
from .image_info import *
from .result_list import ResultList, ResultRow


Expand Down Expand Up @@ -54,8 +55,10 @@ def load_images(
-------
stack : `kbmod.ImageStack`
The stack of images loaded.
img_info : `ImageInfo`
The information for the images loaded.
wcs_list : `list`
A list of `astropy.wcs.WCS` objects for each image.
visit_times : `list`
A list of MJD times.
"""
print("---------------------------------------")
print("Loading Images")
Expand All @@ -77,9 +80,9 @@ def load_images(
patch_visits = sorted(os.listdir(im_filepath))

# Load the images themselves.
img_info = ImageInfoSet()
images = []
visit_times = []
wcs_list = []
for visit_file in np.sort(patch_visits):
# Skip non-fits files.
if not ".fits" in visit_file:
Expand All @@ -90,25 +93,40 @@ def load_images(
# Compute the full file path for loading.
full_file_path = os.path.join(im_filepath, visit_file)

# Load the image info from the FITS header.
header_info = ImageInfo()
header_info.populate_from_fits_file(full_file_path)
# Try loading information from the FITS header.
visit_id = None
with fits.open(full_file_path) as hdu_list:
curr_wcs = WCS(hdu_list[1].header)

# If the visit ID is in header (using Rubin tags), use for the visit ID.
# Otherwise extract it from the filename.
if "IDNUM" in hdu_list[0].header:
visit_id = str(hdu_list[0].header["IDNUM"])
else:
name = os.path.split(full_file_path)[-1]
visit_id = FileUtils.visit_from_file_name(name)

# Skip files without a valid visit ID.
if header_info.visit_id is None:
if visit_id is None:
if verbose:
print(f"WARNING: Unable to extract visit ID for {visit_file}.")
continue

# Compute the time stamp as a MJD float. If there is an entry in the
# timestamp file, defer to that. Otherwise use the value from the header.
time_stamp = -1.0
if header_info.visit_id in image_time_dict:
time_stamp = image_time_dict[header_info.visit_id]
else:
time_obj = header_info.get_epoch(none_if_unset=True)
if time_obj is not None:
time_stamp = time_obj.mjd
# Check if the image has a specific PSF.
psf = default_psf
if visit_id in image_psf_dict:
psf = kb.PSF(image_psf_dict[visit_id])

# Load the image file and set its time.
if verbose:
print(f"Loading file: {full_file_path}")
img = kb.LayeredImage(full_file_path, psf)
time_stamp = img.get_obstime()

# Overload the header's time stamp if needed.
if visit_id in image_time_dict:
time_stamp = image_time_dict[visit_id]
img.set_obstime(time_stamp)

if time_stamp <= 0.0:
if verbose:
Expand All @@ -121,32 +139,21 @@ def load_images(
print(f"Pruning file {visit_file} by timestamp={time_stamp}.")
continue

# Check if the image has a specific PSF.
psf = default_psf
if header_info.visit_id in image_psf_dict:
psf = kb.PSF(image_psf_dict[header_info.visit_id])

# Load the image file and set its time.
if verbose:
print(f"Loading file: {full_file_path}")
img = kb.LayeredImage(full_file_path, psf)
img.set_obstime(time_stamp)

# Save the file, time, and image information.
img_info.append(header_info)
# Save image, time, and WCS information.
visit_times.append(time_stamp)
images.append(img)
wcs_list.append(curr_wcs)

print(f"Loaded {len(images)} images")
stack = kb.ImageStack(images)

# Create a list of visit times and visit times shifted to 0.0.
img_info.set_times_mjd(np.array(visit_times))
times = img_info.get_zero_shifted_times()
stack.set_times(times)
min_time = min(visit_times)
zero_shifted = [(t - min_time) for t in visit_times]
stack.set_times(zero_shifted)
print("Times set", flush=True)

return (stack, img_info)
return (stack, wcs_list, visit_times)


class PostProcess:
Expand Down
Loading

0 comments on commit af3c35a

Please sign in to comment.