Skip to content

Commit

Permalink
remove unneeded data handling from WorkUnit (#702)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwest-uw authored Sep 11, 2024
1 parent abe04ca commit 32fb22f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 285 deletions.
200 changes: 4 additions & 196 deletions src/kbmod/work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,10 @@ def get_wcs(self, img_num):
return per_img

def get_all_obstimes(self):
"""Return a list of the observation times."""
"""Return a list of the observation times.
If the `WorkUnit` was lazily loaded, then the obstimes
have already been preloaded. Otherwise, grab them from
the `ImageStack`."""
if self._obstimes is not None:
return self._obstimes

Expand Down Expand Up @@ -342,153 +345,6 @@ def from_fits(cls, filename, show_progress=None):
)
return result

@classmethod
def from_dict(cls, workunit_dict):
    """Create a WorkUnit from a combined dictionary.

    Accepts either raw python/numpy values (nested lists, arrays, dicts)
    or the already-constructed wrapper objects (``RawImage``, ``PSF``,
    ``SearchConfiguration``, WCS objects) for each entry.

    Parameters
    ----------
    workunit_dict : `dict`
        The dictionary of information.

    Returns
    -------
    `WorkUnit`

    Raises
    ------
    Raises a ``ValueError`` for any invalid parameters.
    """
    num_images = workunit_dict["num_images"]
    logger.debug(f"Creating WorkUnit from dictionary with {num_images} images.")

    width = workunit_dict["width"]
    height = workunit_dict["height"]
    if width <= 0 or height <= 0:
        raise ValueError(f"Illegal image dimensions width={width}, height={height}")

    # Load the configuration supporting both dictionary and SearchConfiguration.
    if type(workunit_dict["config"]) is dict:
        config = SearchConfiguration.from_dict(workunit_dict["config"])
    elif type(workunit_dict["config"]) is SearchConfiguration:
        config = workunit_dict["config"]
    else:
        raise ValueError("Unrecognized type for WorkUnit config parameter.")

    # Load the global WCS if one exists.
    if "wcs" in workunit_dict:
        if type(workunit_dict["wcs"]) is dict:
            global_wcs = wcs_from_dict(workunit_dict["wcs"])
        else:
            global_wcs = workunit_dict["wcs"]
    else:
        global_wcs = None

    constituent_images = workunit_dict["constituent_images"]
    heliocentric_distance = workunit_dict["heliocentric_distance"]
    geocentric_distances = workunit_dict["geocentric_distances"]
    reprojected = workunit_dict["reprojected"]
    per_image_indices = workunit_dict["per_image_indices"]

    imgs = []
    per_image_wcs = []
    per_image_ebd_wcs = []
    for i in range(num_images):
        obs_time = workunit_dict["times"][i]

        # Science layer: accept a RawImage or coerce raw data into one.
        if type(workunit_dict["sci_imgs"][i]) is RawImage:
            sci_img = workunit_dict["sci_imgs"][i]
        else:
            sci_arr = np.array(workunit_dict["sci_imgs"][i], dtype=np.float32).reshape(height, width)
            sci_img = RawImage(img=sci_arr, obs_time=obs_time)

        # Variance layer: same coercion rules as the science layer.
        if type(workunit_dict["var_imgs"][i]) is RawImage:
            var_img = workunit_dict["var_imgs"][i]
        else:
            var_arr = np.array(workunit_dict["var_imgs"][i], dtype=np.float32).reshape(height, width)
            var_img = RawImage(img=var_arr, obs_time=obs_time)

        # Masks are optional; a missing mask means "nothing masked".
        if workunit_dict["msk_imgs"][i] is None:
            # Bug fix: np.zeros takes the shape as a single tuple argument.
            # The previous call np.zeros(height, width) passed `width` as the
            # dtype and raised a TypeError. Use float32 for consistency with
            # the other layers.
            msk_arr = np.zeros((height, width), dtype=np.float32)
            msk_img = RawImage(img=msk_arr, obs_time=obs_time)
        elif type(workunit_dict["msk_imgs"][i]) is RawImage:
            msk_img = workunit_dict["msk_imgs"][i]
        else:
            msk_arr = np.array(workunit_dict["msk_imgs"][i], dtype=np.float32).reshape(height, width)
            msk_img = RawImage(img=msk_arr, obs_time=obs_time)

        # PSFs are optional; a missing PSF falls back to the default kernel.
        if workunit_dict["psfs"][i] is None:
            p = PSF()
        elif type(workunit_dict["psfs"][i]) is PSF:
            p = workunit_dict["psfs"][i]
        else:
            p = PSF(np.array(workunit_dict["psfs"][i], dtype=np.float32))

        imgs.append(LayeredImage(sci_img, var_img, msk_img, p))

    # Per-constituent WCS data may arrive as dicts or WCS objects.
    n_constituents = len(constituent_images)
    for i in range(n_constituents):
        current_wcs = workunit_dict["per_image_wcs"][i]
        if type(current_wcs) is dict:
            current_wcs = wcs_from_dict(current_wcs)
        per_image_wcs.append(current_wcs)

        current_ebd = workunit_dict["per_image_ebd_wcs"][i]
        if type(current_ebd) is dict:
            current_ebd = wcs_from_dict(current_ebd)
        per_image_ebd_wcs.append(current_ebd)

    im_stack = ImageStack(imgs)
    return WorkUnit(
        im_stack=im_stack,
        config=config,
        wcs=global_wcs,
        constituent_images=constituent_images,
        per_image_wcs=per_image_wcs,
        per_image_ebd_wcs=per_image_ebd_wcs,
        heliocentric_distance=heliocentric_distance,
        geocentric_distances=geocentric_distances,
        reprojected=reprojected,
        per_image_indices=per_image_indices,
    )

@classmethod
def from_yaml(cls, work_unit, strict=False):
    """Load a configuration from a YAML string.

    Parameters
    ----------
    work_unit : `str` or `_io.TextIOWrapper`
        The serialized YAML data.
    strict : `bool`
        Raise an error if the file is not a WorkUnit.

    Returns
    -------
    result : `WorkUnit` or `None`
        Returns the extracted WorkUnit. If the file did not contain a WorkUnit and
        strict=False the function will return None.

    Raises
    ------
    Raises a ``ValueError`` for any invalid parameters.
    """
    parsed = safe_load(work_unit)

    # A WorkUnit yaml file must carry every one of these keys; anything
    # missing means this is some other kind of yaml document.
    for name in ("config", "height", "num_images", "sci_imgs", "times", "var_imgs", "width"):
        if name in parsed:
            continue
        if strict:
            raise ValueError(f"Missing required field {name}")
        return None

    return WorkUnit.from_dict(parsed)

def to_fits(self, filename, overwrite=False):
"""Write the WorkUnit to a single FITS file.
Expand Down Expand Up @@ -804,54 +660,6 @@ def append_all_wcs(self, hdul):
ebd_hdu.name = f"EBD_{i}"
hdul.append(ebd_hdu)

def to_yaml(self):
    """Serialize the WorkUnit as a YAML string.

    Returns
    -------
    result : `str`
        The serialized YAML string.
    """
    num_images = self.im_stack.img_count()

    # Gather the per-layer data for every image up front.
    times = []
    sci_imgs = []
    var_imgs = []
    msk_imgs = []
    psfs = []
    for idx in range(num_images):
        layered = self.im_stack.get_single_image(idx)
        times.append(layered.get_obstime())
        sci_imgs.append(layered.get_science().image.tolist())
        var_imgs.append(layered.get_variance().image.tolist())
        msk_imgs.append(layered.get_mask().image.tolist())

        # Flatten the PSF kernel into a square nested list.
        p = layered.get_psf()
        kernel = np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim()))
        psfs.append(kernel.tolist())

    # Serialize the per-image WCS pairs, indexed off the primary WCS list.
    wcs_count = len(self._per_image_wcs)
    per_image_wcs = [wcs_to_dict(self._per_image_wcs[j]) for j in range(wcs_count)]
    per_image_ebd_wcs = [wcs_to_dict(self._per_image_ebd_wcs[j]) for j in range(wcs_count)]

    workunit_dict = {
        "num_images": num_images,
        "width": self.im_stack.get_width(),
        "height": self.im_stack.get_height(),
        "config": self.config._params,
        "wcs": wcs_to_dict(self.wcs),
        # Per image data
        "times": times,
        "sci_imgs": sci_imgs,
        "var_imgs": var_imgs,
        "msk_imgs": msk_imgs,
        "psfs": psfs,
        "constituent_images": self.constituent_images,
        "per_image_wcs": per_image_wcs,
        "per_image_ebd_wcs": per_image_ebd_wcs,
        "heliocentric_distance": self.heliocentric_distance,
        "geocentric_distances": self.geocentric_distances,
        "reprojected": self.reprojected,
        "per_image_indices": self._per_image_indices,
    }

    return dump(workunit_dict)

def image_positions_to_original_icrs(
self, image_indices, positions, input_format="xy", output_format="xy", filter_in_frame=True
):
Expand Down
89 changes: 0 additions & 89 deletions tests/test_work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,72 +150,6 @@ def test_create(self):
self.assertIsNotNone(work3.get_wcs(i))
self.assertTrue(wcs_fits_equal(work3.get_wcs(i), self.diff_wcs[i]))

def test_create_from_dict(self):
    """Round-trip a WorkUnit through `from_dict`, once with plain
    python/numpy values and once with the kbmod wrapper objects."""
    for use_python_types in [True, False]:
        num = self.num_images

        # Keys whose representation is identical in both variants.
        work_unit_dict = {
            "num_images": num,
            "width": self.width,
            "height": self.height,
            "times": [self.images[i].get_obstime() for i in range(num)],
            "per_image_wcs": self.diff_wcs,
            "per_image_ebd_wcs": [None] * num,
            "heliocentric_distance": None,
            "geocentric_distances": [None] * num,
            "reprojected": False,
            "wcs": None,
            "constituent_images": [f"img_{i}" for i in range(num)],
            "per_image_indices": [[i] for i in range(num)],
        }

        if use_python_types:
            # Raw dict / array representations.
            work_unit_dict["config"] = self.config._params
            work_unit_dict["sci_imgs"] = [self.images[i].get_science().image for i in range(num)]
            work_unit_dict["var_imgs"] = [self.images[i].get_variance().image for i in range(num)]
            work_unit_dict["msk_imgs"] = [self.images[i].get_mask().image for i in range(num)]
            work_unit_dict["psfs"] = [
                np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim())) for p in self.p
            ]
        else:
            # Pre-constructed kbmod object representations.
            work_unit_dict["config"] = self.config
            work_unit_dict["sci_imgs"] = [self.images[i].get_science() for i in range(num)]
            work_unit_dict["var_imgs"] = [self.images[i].get_variance() for i in range(num)]
            work_unit_dict["msk_imgs"] = [self.images[i].get_mask() for i in range(num)]
            work_unit_dict["psfs"] = self.p

        with self.subTest(i=use_python_types):
            work = WorkUnit.from_dict(work_unit_dict)
            self.assertEqual(work.im_stack.img_count(), num)
            self.assertEqual(work.im_stack.get_width(), self.width)
            self.assertEqual(work.im_stack.get_height(), self.height)
            self.assertIsNone(work.wcs)
            self.assertFalse(work.has_common_wcs())

            for i in range(num):
                layered1 = work.im_stack.get_single_image(i)
                layered2 = self.im_stack.get_single_image(i)

                # Pixel data and observation times survive the round trip.
                self.assertTrue(layered1.get_science().l2_allclose(layered2.get_science(), 0.01))
                self.assertTrue(layered1.get_variance().l2_allclose(layered2.get_variance(), 0.01))
                self.assertTrue(layered1.get_mask().l2_allclose(layered2.get_mask(), 0.01))
                self.assertEqual(layered1.get_obstime(), layered2.get_obstime())

                self.assertIsNotNone(work.get_wcs(i))
                self.assertTrue(wcs_fits_equal(work.get_wcs(i), self.diff_wcs[i]))

            # The configuration is always normalized to a SearchConfiguration.
            self.assertTrue(type(work.config) is SearchConfiguration)
            self.assertEqual(work.config["im_filepath"], "Here")
            self.assertEqual(work.config["num_obs"], 5)
def test_save_and_load_fits(self):
with tempfile.TemporaryDirectory() as dir_name:
file_path = os.path.join(dir_name, "test_workunit.fits")
Expand Down Expand Up @@ -383,29 +317,6 @@ def test_save_and_load_fits_global_wcs(self):
self.assertIsNotNone(work2.get_wcs(i))
self.assertTrue(wcs_fits_equal(work2.get_wcs(i), self.wcs))

def test_to_from_yaml(self):
    """Serialize a WorkUnit (global WCS only) to YAML and load it back."""
    original = WorkUnit(self.im_stack, self.config, self.wcs, None)
    serialized = original.to_yaml()

    restored = WorkUnit.from_yaml(serialized)
    self.assertEqual(restored.im_stack.img_count(), self.num_images)
    self.assertEqual(restored.im_stack.get_width(), self.width)
    self.assertEqual(restored.im_stack.get_height(), self.height)
    self.assertIsNotNone(restored.wcs)

    for idx in range(self.num_images):
        loaded_img = restored.im_stack.get_single_image(idx)
        source_img = self.im_stack.get_single_image(idx)

        # Every layer and the obstime must survive the round trip.
        self.assertTrue(loaded_img.get_science().l2_allclose(source_img.get_science(), 0.01))
        self.assertTrue(loaded_img.get_variance().l2_allclose(source_img.get_variance(), 0.01))
        self.assertTrue(loaded_img.get_mask().l2_allclose(source_img.get_mask(), 0.01))
        self.assertAlmostEqual(loaded_img.get_obstime(), source_img.get_obstime())

    # Check that we read in the configuration values correctly.
    self.assertEqual(restored.config["im_filepath"], "Here")
    self.assertEqual(restored.config["num_obs"], self.num_images)

def test_image_positions_to_original_icrs_invalid_format(self):
work = WorkUnit(
im_stack=self.im_stack,
Expand Down

0 comments on commit 32fb22f

Please sign in to comment.