diff --git a/notebooks/Kbmod_Reference.ipynb b/notebooks/Kbmod_Reference.ipynb index 40193c8dd..4ffed5ee0 100644 --- a/notebooks/Kbmod_Reference.ipynb +++ b/notebooks/Kbmod_Reference.ipynb @@ -181,7 +181,7 @@ "metadata": {}, "outputs": [], "source": [ - "from kbmod.data_interface import load_deccam_layered_image\n", + "from kbmod.file_utils import load_deccam_layered_image\n", "\n", "im = load_deccam_layered_image(im_file, p)\n", "print(f\"Loaded a {im.get_width()} by {im.get_height()} image at time {im.get_obstime()}\")" @@ -526,13 +526,6 @@ "# These top_results are all be duplicating searches on the same bright object we added.\n", "top_results[:20]" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/notebooks/kbmod_search_results_for_fakes.ipynb b/notebooks/kbmod_search_results_for_fakes.ipynb index a2b17f8f3..a7b2c52db 100644 --- a/notebooks/kbmod_search_results_for_fakes.ipynb +++ b/notebooks/kbmod_search_results_for_fakes.ipynb @@ -32,7 +32,6 @@ "import os\n", "\n", "from kbmod.analysis.plotting import *\n", - "from kbmod.data_interface import load_deccam_layered_image\n", "from kbmod.search import ImageStack, PSF, StampCreator, Trajectory\n", "from kbmod.results import Results\n", "from kbmod.work_unit import WorkUnit\n", @@ -260,9 +259,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Wilson’s KBMOD Analysis", + "display_name": "Jeremy's KBMOD", "language": "python", - "name": "wbeebe_kbmod_analysis" + "name": "kbmod_jk" }, "language_info": { "codemirror_mode": { @@ -274,7 +273,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/notebooks/kbmod_visualize.ipynb b/notebooks/kbmod_visualize.ipynb index 6ac638985..fb7078402 100644 --- a/notebooks/kbmod_visualize.ipynb +++ b/notebooks/kbmod_visualize.ipynb @@ -30,7 +30,7 @@ "import os\n", "\n", "from kbmod.analysis.plotting import *\n", - "from kbmod.data_interface import load_deccam_layered_image\n", + "from kbmod.file_utils import load_deccam_layered_image\n", "from kbmod.search import ImageStack, PSF, StampCreator, Trajectory\n", "from kbmod.results import Results\n", "\n", diff --git a/src/kbmod/__init__.py b/src/kbmod/__init__.py index 36b76a347..9775f5324 100644 --- a/src/kbmod/__init__.py +++ b/src/kbmod/__init__.py @@ -87,7 +87,6 @@ def is_interactive(): from . import ( # noqa: F401 analysis, - data_interface, file_utils, filters, jointfit_functions, diff --git a/src/kbmod/data_interface.py b/src/kbmod/data_interface.py deleted file mode 100644 index 82c634832..000000000 --- a/src/kbmod/data_interface.py +++ /dev/null @@ -1,287 +0,0 @@ -import os - -from astropy.io import fits -from astropy.time import Time -from astropy.wcs import WCS -from itertools import product -import numpy as np -from pathlib import Path - -from kbmod.configuration import SearchConfiguration -from kbmod.file_utils import * -from kbmod.search import ImageStack, LayeredImage, PSF, RawImage, Logging -from kbmod.wcs_utils import append_wcs_to_hdu_header -from kbmod.work_unit import WorkUnit, raw_image_to_hdu - - -logger = Logging.getLogger(__name__) - - -def visit_from_file_name(filename): - """Automatically extract the visit ID from the file name. - - Uses the heuristic that the visit ID is the first numeric - string of at least length 5 digits in the file name. 
- - Parameters - ---------- - filename : str - The file name - - Returns - ------- - result : str - The visit ID string or None if there is no match. - """ - expr = re.compile(r"\d{4}(?:\d+)") - res = expr.search(filename) - if res is None: - return None - return res.group() - - -def load_deccam_layered_image(filename, psf): - """Load a layered image from the legacy deccam format. - - Parameters - ---------- - filename : `str` - The name of the file to load. - psf : `PSF` - The PSF to use for the image. - - Returns - ------- - img : `LayeredImage` - The loaded image. - - Raises - ------ - Raises a ``FileNotFoundError`` if the file does not exist. - Raises a ``ValueError`` if any of the validation checks fail. - """ - if not Path(filename).is_file(): - raise FileNotFoundError(f"{filename} not found") - - img = None - with fits.open(filename) as hdul: - if len(hdul) < 4: - raise ValueError("Not enough extensions for legacy deccam format") - - # Extract the obstime trying from a few keys and a few extensions. - obstime = -1.0 - for key, ext in product(["MJD", "DATE-AVG", "MJD-OBS"], [0, 1]): - if key in hdul[ext].header: - value = hdul[ext].header[key] - if type(value) is float: - obstime = value - break - if type(value) is str: - timesys = hdul[ext].header.get("TIMESYS", "UTC").lower() - obstime = Time(value, scale=timesys).mjd - break - - img = LayeredImage( - hdul[1].data.astype(np.float32), # Science - hdul[3].data.astype(np.float32), # Variance - hdul[2].data.astype(np.float32), # Mask - psf, - obstime, - ) - - return img - - -def load_input_from_individual_files( - im_filepath, - time_file, - psf_file, - mjd_lims, - default_psf, - verbose=False, -): - """This function loads images and ingests them into an ImageStack. - - Parameters - ---------- - im_filepath : `str` - Image file path from which to load images. - time_file : `str` - File name containing image times. - psf_file : `str` - File name containing the image-specific PSFs. - If set to None the code will use the provided default psf for - all images. - mjd_lims : `list` of ints - Optional MJD limits on the images to search. - default_psf : `PSF` - The default PSF in case no image-specific PSF is provided. - verbose : `bool` - Use verbose output (mainly for debugging). - - Returns - ------- - stack : `kbmod.ImageStack` - The stack of images loaded. - wcs_list : `list` - A list of `astropy.wcs.WCS` objects for each image. - visit_times : `list` - A list of MJD times. - """ - logger.info("Loading Images") - - # Load a mapping from visit numbers to the visit times. This dictionary stays - # empty if no time file is specified. - image_time_dict = FileUtils.load_time_dictionary(time_file) - logger.info(f"Loaded {len(image_time_dict)} time stamps.") - - # Load a mapping from visit numbers to PSFs. This dictionary stays - # empty if no time file is specified. - image_psf_dict = FileUtils.load_psf_dictionary(psf_file) - logger.info(f"Loaded {len(image_psf_dict)} image PSFs stamps.") - - # Retrieve the list of visits (file names) in the data directory. - patch_visits = sorted(os.listdir(im_filepath)) - - # Load the images themselves. - stack = ImageStack() - visit_times = [] - wcs_list = [] - for visit_file in np.sort(patch_visits): - # Skip non-fits files. - if not ".fits" in visit_file: - logger.info(f"Skipping non-FITS file {visit_file}") - continue - - # Compute the full file path for loading. - full_file_path = os.path.join(im_filepath, visit_file) - - # Try loading information from the FITS header. 
- visit_id = None - with fits.open(full_file_path) as hdu_list: - curr_wcs = WCS(hdu_list[1].header) - - # If the visit ID is in header (using Rubin tags), use for the visit ID. - # Otherwise extract it from the filename. - if "IDNUM" in hdu_list[0].header: - visit_id = str(hdu_list[0].header["IDNUM"]) - else: - name = os.path.split(full_file_path)[-1] - visit_id = visit_from_file_name(name) - - # Skip files without a valid visit ID. - if visit_id is None: - logger.warning(f"WARNING: Unable to extract visit ID for {visit_file}.") - continue - - # Check if the image has a specific PSF. - psf = default_psf - if visit_id in image_psf_dict: - psf = PSF(image_psf_dict[visit_id]) - - # Load the image file and set its time. - logger.info(f"Loading file: {full_file_path}") - img = load_deccam_layered_image(full_file_path, psf) - time_stamp = img.get_obstime() - - # Overload the header's time stamp if needed. - if visit_id in image_time_dict: - time_stamp = image_time_dict[visit_id] - img.set_obstime(time_stamp) - - if time_stamp <= 0.0: - logger.warning(f"WARNING: No valid timestamp provided for {visit_file}.") - continue - - # Check if we should filter the record based on the time bounds. - if mjd_lims is not None and (time_stamp < mjd_lims[0] or time_stamp > mjd_lims[1]): - logger.info(f"Pruning file {visit_file} by timestamp={time_stamp}.") - continue - - # Save image, time, and WCS information. The force move destroys img, so we should - # not use it after that point. - visit_times.append(time_stamp) - stack.append_image(img, force_move=True) - wcs_list.append(curr_wcs) - - logger.info(f"Loaded {len(stack)} images") - - return (stack, wcs_list, visit_times) - - -def load_input_from_config(config, verbose=False): - """This function loads images and ingests them into a WorkUnit. - - Parameters - ---------- - config : `SearchConfiguration` - The configuration with the individual file information. - verbose : `bool`, optional - Use verbose output (mainly for debugging). - - Returns - ------- - result : `kbmod.WorkUnit` - The input data as a ``WorkUnit``. - """ - stack, wcs_list, _ = load_input_from_individual_files( - config["im_filepath"], - config["time_file"], - config["psf_file"], - config["mjd_lims"], - PSF(config["psf_val"]), # Default PSF. - verbose=verbose, - ) - return WorkUnit(stack, config, None, None, wcs_list) - - -def load_input_from_file(filename, overrides=None): - """Build a WorkUnit from a single filename which could point to a WorkUnit - or configuration file. - - Parameters - ---------- - filename : `str` - The path and file name of the data to load. - overrides : `dict`, optional - A dictionary of configuration parameters to override. For testing. - - Returns - ------- - result : `kbmod.WorkUnit` - The input data as a ``WorkUnit``. - - Raises - ------ - ``ValueError`` if unable to read the data. - """ - path_var = Path(filename) - if not path_var.is_file(): - raise ValueError(f"File {filename} not found.") - - work = None - - path_suffix = path_var.suffix - if path_suffix == ".yml" or path_suffix == ".yaml": - # Try loading as a WorkUnit first. - with open(filename) as ff: - work = WorkUnit.from_yaml(ff.read()) - - # If that load did not work, try loading the file as a configuration - # and then using that to load the data files. 
- if work is None: - config = SearchConfiguration.from_file(filename) - if overrides is not None: - config.set_multiple(overrides) - if config["im_filepath"] is not None: - return load_input_from_config(config) - elif ".fits" in filename: - work = WorkUnit.from_fits(filename) - - # None of the load paths worked. - if work is None: - raise ValueError(f"Could not interprete {filename}.") - - if overrides is not None: - work.config.set_multiple(overrides) - return work diff --git a/src/kbmod/file_utils.py b/src/kbmod/file_utils.py index 2a44fc723..86e59b453 100644 --- a/src/kbmod/file_utils.py +++ b/src/kbmod/file_utils.py @@ -3,18 +3,73 @@ import csv import re from collections import OrderedDict +from itertools import product from math import copysign from pathlib import Path import astropy.units as u import numpy as np from astropy.coordinates import * +from astropy.io import fits from astropy.time import Time import kbmod.search as kb +from kbmod.search import LayeredImage from kbmod.trajectory_utils import trajectory_from_np_object +def load_deccam_layered_image(filename, psf): + """Load a layered image from the legacy deccam format. + + Parameters + ---------- + filename : `str` + The name of the file to load. + psf : `PSF` + The PSF to use for the image. + + Returns + ------- + img : `LayeredImage` + The loaded image. + + Raises + ------ + Raises a ``FileNotFoundError`` if the file does not exist. + Raises a ``ValueError`` if any of the validation checks fail. + """ + if not Path(filename).is_file(): + raise FileNotFoundError(f"{filename} not found") + + img = None + with fits.open(filename) as hdul: + if len(hdul) < 4: + raise ValueError("Not enough extensions for legacy deccam format") + + # Extract the obstime trying from a few keys and a few extensions. + obstime = -1.0 + for key, ext in product(["MJD", "DATE-AVG", "MJD-OBS"], [0, 1]): + if key in hdul[ext].header: + value = hdul[ext].header[key] + if type(value) is float: + obstime = value + break + if type(value) is str: + timesys = hdul[ext].header.get("TIMESYS", "UTC").lower() + obstime = Time(value, scale=timesys).mjd + break + + img = LayeredImage( + hdul[1].data.astype(np.float32), # Science + hdul[3].data.astype(np.float32), # Variance + hdul[2].data.astype(np.float32), # Mask + psf, + obstime, + ) + + return img + + class FileUtils: """A class of static methods for working with KBMOD files. @@ -81,78 +136,6 @@ def load_csv_to_list(file_name, use_dtype=None, none_if_missing=False): data.append(np.array(row, dtype=use_dtype)) return data - @staticmethod - def load_time_dictionary(time_file): - """Load a OrderedDict mapping ``visit_id`` to time stamp. - - Parameters - ---------- - time_file : str - The path and name of the time file. - - Returns - ------- - image_time_dict : OrderedDict - A mapping of visit ID to time stamp. - """ - # Load a mapping from visit numbers to the visit times. This dictionary stays - # empty if no time file is specified. - image_time_dict = OrderedDict() - if time_file is None or len(time_file) == 0: - return image_time_dict - - with open(time_file, "r") as csvfile: - reader = csv.reader(csvfile, delimiter=" ") - for row in reader: - if len(row[0]) < 2 or row[0][0] == "#": - continue - image_time_dict[row[0]] = float(row[1]) - return image_time_dict - - @staticmethod - def save_time_dictionary(time_file_name, time_mapping): - """Save the mapping of visit_id -> time stamp to a file. - - Parameters - ---------- - time_file_name : str - The path and name of the time file. 
- time_mapping : dict or OrderedDict - The mapping of visit ID to time stamp. - """ - with open(time_file_name, "w") as file: - file.write("# visit_id mean_julian_date\n") - for k in time_mapping.keys(): - file.write(f"{k} {time_mapping[k]}\n") - - @staticmethod - def load_psf_dictionary(psf_file): - """Load a OrderedDict mapping ``visit_id`` to PSF. - - Parameters - ---------- - psf_file : str - The path and name of the PSF file. - - Returns - ------- - psf_dict : OrderedDict - A mapping of visit ID to psf value. - """ - # Load a mapping from visit numbers to the visit times. This dictionary stays - # empty if no time file is specified. - psf_dict = OrderedDict() - if psf_file is None or len(psf_file) == 0: - return psf_dict - - with open(psf_file, "r") as csvfile: - reader = csv.reader(csvfile, delimiter=" ") - for row in reader: - if len(row[0]) < 2 or row[0][0] == "#": - continue - psf_dict[row[0]] = float(row[1]) - return psf_dict - @staticmethod def save_results_file(filename, results): """Save the result trajectories to a file. @@ -205,114 +188,3 @@ def load_results_file_as_trajectories(filename): np_results = FileUtils.load_results_file(filename) results = [trajectory_from_np_object(x) for x in np_results] return results - - @staticmethod - def mpc_reader(filename): - """Read in a file with observations in MPC format and return the coordinates. - - Parameters - ---------- - filename: str - The name of the file with the MPC-formatted observations. - - Returns - ------- - coords: astropy SkyCoord object - A SkyCoord object with the ra, dec of the observations. - times: astropy Time object - Times of the observations - """ - iso_times = [] - time_frac = [] - ra = [] - dec = [] - - with open(filename, "r") as f: - for line in f: - year = str(line[15:19]) - month = str(line[20:22]) - day = str(line[23:25]) - iso_times.append(str("%s-%s-%s" % (year, month, day))) - time_frac.append(str(line[25:31])) - ra.append(str(line[32:44])) - dec.append(str(line[44:56])) - - coords = SkyCoord(ra, dec, unit=(u.hourangle, u.deg)) - t = Time(iso_times) - t_obs = [] - for t_i, frac in zip(t, time_frac): - t_obs.append(t_i.mjd + float(frac)) - obs_times = Time(t_obs, format="mjd") - - return coords, obs_times - - @staticmethod - def format_result_mpc(coords, t, observatory="X05"): - """ - This method will take a single result in and return a corresponding - MPC formatted string. - - Parameters - ---------- - coords : SkyCoord - The sky coordinates of the observation. - t : Time - The time of the observation as an astropy Time object. - observatory : string - The three digit observatory code to use. - - Returns - ------- - mpc_line: string - An MPC-formatted string of the observation - """ - mjd_frac = t.mjd % 1.0 - ra_hms = coords.ra.hms - dec_dms = coords.dec.dms - - if dec_dms.d == 0: - if copysign(1, dec_dms.d) == -1.0: - dec_dms_d = "-00" - else: - dec_dms_d = "+00" - else: - dec_dms_d = "%+03i" % dec_dms.d - - mpc_line = " c111112 c%4i %02i %08.5f %02i %02i %06.3f%s %02i %05.2f %s" % ( - t.datetime.year, - t.datetime.month, - t.datetime.day + mjd_frac, - ra_hms.h, - ra_hms.m, - ra_hms.s, - dec_dms_d, - np.abs(dec_dms.m), - np.abs(dec_dms.s), - observatory, - ) - return mpc_line - - @staticmethod - def save_results_mpc(file_out, coords, times, observatory="X05"): - """ - Save the MPC-formatted observations to file. - - Parameters - ---------- - file_out: str - The output filename with the MPC-formatted observations - of the KBMOD search result. 
- coords : list of SkyCoord - A list of sky coordinates (SkyCoord objects) of the observation. - t : list of Time - A list of times for each observation. - observatory : string - The three digit observatory code to use. - """ - if len(times) != len(coords): - raise ValueError(f"Unequal lists {len(times)} != {len(coords)}") - - with open(file_out, "w") as f: - for i in range(len(times)): - mpc_line = FileUtils.format_result_mpc(coords[i], times[i], observatory) - f.write(mpc_line + "\n") diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index 2278b9177..31478eb62 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -7,7 +7,6 @@ import kbmod.search as kb from .configuration import SearchConfiguration -from .data_interface import load_input_from_config, load_input_from_file from .filters.clustering_filters import apply_clustering from .filters.sigma_g_filter import apply_clipped_sigma_g, SigmaGClipping from .filters.stamp_filters import append_all_stamps, append_coadds, get_coadds_and_filter_results @@ -328,45 +327,6 @@ def run_search_from_work_unit(self, work): # Run the search. return self.run_search(work.config, work.im_stack) - def run_search_from_config(self, config): - """Run a KBMOD search from a SearchConfiguration object - (or corresponding dictionary). - - Parameters - ---------- - config : `SearchConfiguration` or `dict` - The configuration object with all the information for the run. - - Returns - ------- - keep : `Results` - The results. - """ - if type(config) is dict: - config = SearchConfiguration.from_dict(config) - - # Load the data. - work = load_input_from_config(config) - return self.run_search_from_work_unit(work) - - def run_search_from_file(self, filename, overrides=None): - """Run a KBMOD search from a configuration or WorkUnit file. - - Parameters - ---------- - filename : `str` - The name of the input file. - overrides : `dict`, optional - A dictionary of configuration parameters to override. For testing. - - Returns - ------- - keep : `Results` - The results. - """ - work = load_input_from_file(filename, overrides) - return self.run_search_from_work_unit(work) - def _count_known_matches(self, result_list, search): """Look up the known objects that overlap the images and count how many are found among the results. 
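
For readers tracking the import change in the notebooks above: `load_deccam_layered_image` keeps its signature and behavior, only its home module moves from `kbmod.data_interface` to `kbmod.file_utils`. A minimal sketch of a caller after this patch, mirroring the updated notebook cells (the FITS path here is hypothetical):

    from kbmod.file_utils import load_deccam_layered_image
    from kbmod.search import PSF

    p = PSF(1.0)  # default PSF, as used in the notebooks
    im = load_deccam_layered_image("data/demo_image.fits", p)  # hypothetical path
    print(f"Loaded a {im.get_width()} by {im.get_height()} image at time {im.get_obstime()}")
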
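With `run_search_from_config` and `run_search_from_file` deleted, `run_search_from_work_unit` is the remaining high-level entry point in this hunk: callers now build a `WorkUnit` themselves rather than handing the runner a config or file path. A sketch of the migration path, assuming the enclosing class is `SearchRunner` (its name is not visible in this diff) and using `WorkUnit.from_fits` as exercised by the removed tests:

    from kbmod.run_search import SearchRunner
    from kbmod.work_unit import WorkUnit

    # Build the WorkUnit up front instead of passing a config or file to the runner.
    work = WorkUnit.from_fits("test_workunit.fits")  # hypothetical saved WorkUnit
    results = SearchRunner().run_search_from_work_unit(work)
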
diff --git a/tests/data/fake_psfs.dat b/tests/data/fake_psfs.dat deleted file mode 100644 index 8c50c5304..000000000 --- a/tests/data/fake_psfs.dat +++ /dev/null @@ -1,3 +0,0 @@ -# visit_id psf_val -000002 1.3 -000012 1.5 \ No newline at end of file diff --git a/tests/data/fake_times.dat b/tests/data/fake_times.dat deleted file mode 100644 index c2c100642..000000000 --- a/tests/data/fake_times.dat +++ /dev/null @@ -1,4 +0,0 @@ -# visit_id mean_julian_date -000003 57162.0 -000005 57172.0 -010006 100000.0 \ No newline at end of file diff --git a/tests/data/mpcs.txt b/tests/data/mpcs.txt deleted file mode 100644 index a08506175..000000000 --- a/tests/data/mpcs.txt +++ /dev/null @@ -1,3 +0,0 @@ - Fake1 C2001 05 05.03500 18 45 47.64 -24 27 20.3 16.6 R 706 - Fake2 2C1995 01 20.70000 23 29 42.54 -02 59 15.0 15.5 V 706 - Fake3 C2005 10 10.28000 01 45 13.20 +08 07 21.4 r 706 \ No newline at end of file diff --git a/tests/test_data_interface.py b/tests/test_data_interface.py deleted file mode 100644 index c545b16d0..000000000 --- a/tests/test_data_interface.py +++ /dev/null @@ -1,180 +0,0 @@ -from astropy.wcs import WCS -import logging -import os -import tempfile -import unittest -from yaml import dump - -from kbmod.configuration import SearchConfiguration -from kbmod.data_interface import ( - load_input_from_config, - load_input_from_file, - load_input_from_individual_files, - visit_from_file_name, -) -from kbmod.fake_data.fake_data_creator import create_fake_times, FakeDataSet -from kbmod.search import * -from kbmod.work_unit import WorkUnit -from utils.utils_for_tests import get_absolute_data_path - - -class test_data_interface(unittest.TestCase): - def setUp(self): - # Turn off WARNING-level logging for these tests since they will always - # generate a warning about a bad file (wrong_filename.fits) that has - # intentionally been inserted into the data as part of the test. - logging.basicConfig(level=logging.CRITICAL) - - def tearDown(self): - # Re-enable the WARNING-level logging. - logging.basicConfig(level=logging.WARNING) - - def test_visit_from_file_name(self): - visit = visit_from_file_name("m00005.fits") - self.assertEqual(visit, "00005") - - visit = visit_from_file_name("m654321.fits") - self.assertEqual(visit, "654321") - - # Too few digits - visit = visit_from_file_name("m005.fits") - self.assertIsNone(visit) - - # Nonsequential digits - visit = visit_from_file_name("m123x45.fits") - self.assertIsNone(visit) - - def test_file_load_basic(self): - stack, wcs_list, mjds = load_input_from_individual_files( - get_absolute_data_path("fake_images"), - None, - None, - [0, 157130.2], - PSF(1.0), - verbose=False, - ) - self.assertEqual(stack.img_count(), 4) - - # Check that each image loaded corrected. - true_times = [57130.2, 57130.21, 57130.22, 57131.2] - for i in range(stack.img_count()): - img = stack.get_single_image(i) - self.assertEqual(img.get_width(), 64) - self.assertEqual(img.get_height(), 64) - self.assertAlmostEqual(img.get_obstime(), true_times[i], delta=0.005) - self.assertAlmostEqual(1.0, img.get_psf().get_std()) - - def test_file_load_extra(self): - p = PSF(1.0) - - stack, wcs_list, mjds = load_input_from_individual_files( - get_absolute_data_path("fake_images"), - get_absolute_data_path("fake_times.dat"), - get_absolute_data_path("fake_psfs.dat"), - [0, 157130.2], - p, - verbose=False, - ) - self.assertEqual(stack.img_count(), 4) - - # Check that each image loaded corrected. 
- true_times = [57130.2, 57130.21, 57130.22, 57162.0] - psfs_std = [1.0, 1.0, 1.3, 1.0] - for i in range(stack.img_count()): - img = stack.get_single_image(i) - self.assertEqual(img.get_width(), 64) - self.assertEqual(img.get_height(), 64) - self.assertAlmostEqual(img.get_obstime(), true_times[i], delta=0.005) - self.assertAlmostEqual(psfs_std[i], img.get_psf().get_std()) - - def test_file_load_config(self): - config = SearchConfiguration() - config.set("im_filepath", get_absolute_data_path("fake_images")), - config.set("time_file", get_absolute_data_path("fake_times.dat")), - config.set("psf_file", get_absolute_data_path("fake_psfs.dat")), - config.set("psf_val", 1.0) - - worku = load_input_from_config(config, verbose=False) - - # Check that each image loaded corrected. - true_times = [57130.2, 57130.21, 57130.22, 57162.0] - psfs_std = [1.0, 1.0, 1.3, 1.0] - for i in range(worku.im_stack.img_count()): - img = worku.im_stack.get_single_image(i) - self.assertEqual(img.get_width(), 64) - self.assertEqual(img.get_height(), 64) - self.assertAlmostEqual(img.get_obstime(), true_times[i], delta=0.005) - self.assertAlmostEqual(psfs_std[i], img.get_psf().get_std()) - - # Try writing the configuration to a YAML file and loading. - with tempfile.TemporaryDirectory() as dir_name: - yaml_file_path = os.path.join(dir_name, "test_config.yml") - - with self.assertRaises(ValueError): - work_fits = load_input_from_file(yaml_file_path) - - config.to_file(yaml_file_path) - - work_yml = load_input_from_file(yaml_file_path) - self.assertIsNotNone(work_yml) - self.assertEqual(work_yml.im_stack.img_count(), 4) - - def test_file_load_workunit(self): - # Create a fake WCS - fake_wcs = WCS( - { - "WCSAXES": 2, - "CTYPE1": "RA---TAN-SIP", - "CTYPE2": "DEC--TAN-SIP", - "CRVAL1": 200.614997245422, - "CRVAL2": -7.78878863332778, - "CRPIX1": 1033.934327, - "CRPIX2": 2043.548284, - "CTYPE1A": "LINEAR ", - "CTYPE2A": "LINEAR ", - "CUNIT1A": "PIXEL ", - "CUNIT2A": "PIXEL ", - } - ) - fake_config = SearchConfiguration() - fake_times = create_fake_times(11, 57130.2, 10, 0.01, 1) - fake_data = FakeDataSet(64, 64, fake_times, use_seed=True) - work = WorkUnit(fake_data.stack, fake_config, fake_wcs, None) - - with tempfile.TemporaryDirectory() as dir_name: - # Save and load as FITS - fits_file_path = os.path.join(dir_name, "test_workunit.fits") - - with self.assertRaises(ValueError): - work_fits = load_input_from_file(fits_file_path) - - work.to_fits(fits_file_path) - - work_fits = load_input_from_file(fits_file_path) - self.assertIsNotNone(work_fits) - self.assertEqual(work_fits.im_stack.img_count(), 11) - - # Save and load as YAML - yaml_file_path = os.path.join(dir_name, "test_workunit.yml") - with open(yaml_file_path, "w") as file: - file.write(work.to_yaml()) - - work_yml = load_input_from_file(yaml_file_path) - self.assertIsNotNone(work_yml) - self.assertEqual(work_yml.im_stack.img_count(), 11) - - def test_file_load_invalid(self): - # Create a YAML file that is neither a configuration nor a WorkUnit. 
- yaml_str = dump({"Field1": 1, "Field2": False}) - - with tempfile.TemporaryDirectory() as dir_name: - yaml_file_path = os.path.join(dir_name, "test_invalid.yml") - with open(yaml_file_path, "w") as file: - file.write(yaml_str) - - with self.assertRaises(ValueError): - work = load_input_from_file(yaml_file_path) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index a75c5b243..0743d8cb2 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -43,36 +43,6 @@ def test_save_load_csv(self): with self.assertRaises(ValueError): FileUtils.save_csv_from_list(file_name, data2) - def test_load_times(self): - times = FileUtils.load_time_dictionary(get_absolute_data_path("fake_times.dat")) - self.assertEqual(len(times), 3) - self.assertTrue("000003" in times) - self.assertTrue("000005" in times) - self.assertTrue("010006" in times) - self.assertEqual(times["000003"], 57162.0) - self.assertEqual(times["000005"], 57172.0) - self.assertEqual(times["010006"], 100000.0) - - def test_save_times(self): - mapping = {"0001": 100.0, "0002": 110.0, "0003": 111.0} - with tempfile.TemporaryDirectory() as dir_name: - file_name = os.path.join(dir_name, "times.dat") - FileUtils.save_time_dictionary(file_name, mapping) - self.assertTrue(Path(file_name).is_file()) - - loaded = FileUtils.load_time_dictionary(file_name) - self.assertEqual(len(loaded), len(mapping)) - for k in mapping.keys(): - self.assertEqual(loaded[k], mapping[k]) - - def test_load_psfs(self): - psfs = FileUtils.load_psf_dictionary(get_absolute_data_path("fake_psfs.dat")) - self.assertEqual(len(psfs), 2) - self.assertTrue("000002" in psfs) - self.assertTrue("000012" in psfs) - self.assertEqual(psfs["000002"], 1.3) - self.assertEqual(psfs["000012"], 1.5) - def test_load_results(self): np_results = FileUtils.load_results_file(get_absolute_data_path("fake_results.txt")) self.assertEqual(len(np_results), 2) @@ -131,77 +101,6 @@ def test_save_and_load_single_result(self): self.assertEqual(loaded_trjs[0].vx, trj.vx) self.assertEqual(loaded_trjs[0].vy, trj.vy) - def test_load_mpc(self): - coords, obs_times = FileUtils.mpc_reader(get_absolute_data_path("mpcs.txt")) - - # Check the coordinates - self.assertEqual(len(coords), 3) - self.assertAlmostEqual(coords[0].ra.degree, 281.4485, delta=1e-4) - self.assertAlmostEqual(coords[0].dec.degree, -24.45564, delta=1e-4) - self.assertAlmostEqual(coords[1].ra.degree, 352.42725, delta=1e-4) - self.assertAlmostEqual(coords[1].dec.degree, -2.987500, delta=1e-4) - self.assertAlmostEqual(coords[2].ra.degree, 26.305, delta=1e-4) - self.assertAlmostEqual(coords[2].dec.degree, 8.122611, delta=1e-4) - - # Check the times - self.assertEqual(len(obs_times), 3) - self.assertAlmostEqual(obs_times[0].mjd, 52034.035, delta=1e-4) - self.assertAlmostEqual(obs_times[1].mjd, 49737.700, delta=1e-4) - self.assertAlmostEqual(obs_times[2].mjd, 53653.280, delta=1e-4) - - def test_format_mpc(self): - c = SkyCoord(281.4485, -24.45564, unit="deg") - t = Time(52034.035, format="mjd", scale="utc") - res = FileUtils.format_result_mpc(c, t) - self.assertEqual( - res, " c111112 c2001 05 05.03500 18 45 47.640-24 27 20.30 X05" - ) - - c = SkyCoord(352.42725, -0.45564, unit="deg") - t = Time(49737.700, format="mjd", scale="utc") - res = FileUtils.format_result_mpc(c, t) - self.assertEqual( - res, " c111112 c1995 01 20.70000 23 29 42.540-00 27 20.30 X05" - ) - - c = SkyCoord(26.305, 8.122611, unit="deg") - t = Time(53653.280, format="mjd", scale="utc") - res = 
FileUtils.format_result_mpc(c, t, observatory="001") - self.assertEqual( - res, " c111112 c2005 10 10.28000 01 45 13.200+08 07 21.40 001" - ) - - def test_save_and_load_mpcs(self): - coords = [ - SkyCoord(281.4485, -24.45564, unit="deg"), - SkyCoord(352.42725, -0.45564, unit="deg"), - SkyCoord(26.305, 8.122611, unit="deg"), - SkyCoord(0.42725, 0.45564, unit="deg"), - SkyCoord(0.0, 0.0, unit="deg"), - ] - times = [ - Time(52034.035, format="mjd", scale="utc"), - Time(49737.700, format="mjd", scale="utc"), - Time(53653.280, format="mjd", scale="utc"), - Time(52000.000, format="mjd", scale="utc"), - Time(53653.280, format="mjd", scale="utc"), - ] - num_res = len(times) - - with tempfile.TemporaryDirectory() as dir_name: - # Write out the data to a temporary file. - file_out = os.path.join(dir_name, "fake_mpc.txt") - FileUtils.save_results_mpc(file_out, coords, times) - - # Read back in the data. - res2, times2 = FileUtils.mpc_reader(file_out) - self.assertEqual(num_res, len(res2)) - self.assertEqual(num_res, len(times2)) - for i in range(num_res): - self.assertAlmostEqual(coords[i].ra.degree, res2[i].ra.degree, delta=1e-4) - self.assertAlmostEqual(coords[i].dec.degree, res2[i].dec.degree, delta=1e-4) - self.assertAlmostEqual(times[i].mjd, times2[i].mjd, delta=1e-4) - if __name__ == "__main__": unittest.main()
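
The deleted `fake_times.dat` and `fake_psfs.dat` fixtures illustrate the whitespace-delimited `visit_id value` format that the removed `FileUtils.load_time_dictionary` and `FileUtils.load_psf_dictionary` helpers parsed. Nothing in the tree reads that format after this patch; a minimal standalone sketch (hypothetical helper name) for anyone still carrying such files:

    from collections import OrderedDict

    def load_visit_value_file(file_name):
        """Map visit_id -> float value from a '# comment' / 'visit_id value' file."""
        mapping = OrderedDict()
        with open(file_name, "r") as f:
            for line in f:
                tokens = line.split()
                # Skip blank lines, comments, and too-short visit IDs,
                # matching the checks in the removed loaders.
                if len(tokens) < 2 or tokens[0].startswith("#") or len(tokens[0]) < 2:
                    continue
                mapping[tokens[0]] = float(tokens[1])
        return mapping

    times = load_visit_value_file("fake_times.dat")  # e.g. {"000003": 57162.0, ...}
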