Skip to content

Commit

Permalink
Merge pull request #144 from tjgalvin/weightstrim
Browse files Browse the repository at this point in the history
Trim the weights image from linmos
  • Loading branch information
tjgalvin authored Jul 17, 2024
2 parents 4bead77 + 0f6d114 commit ed39fae
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Adaptive colour bar scaling in the rms validation plot
- Create multiple linmos images if `--fixed-beam-shape` specified, one at an optimal common resolution and another at the specified resolution
- Dump the `FieldOptions` to the output science directory
- Weights produced by `linmos` are also trimmed in the same way as the corresponding image

## 0.2.4

Expand Down
57 changes: 51 additions & 6 deletions flint/coadd/linmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from argparse import ArgumentParser
from pathlib import Path
from typing import Collection, List, NamedTuple, Optional
from typing import Collection, List, NamedTuple, Optional, Tuple

import numpy as np
from astropy.io import fits
Expand Down Expand Up @@ -36,6 +36,8 @@ class BoundingBox(NamedTuple):
"""Minimum y pixel"""
ymax: int
"""Maximum y pixel"""
original_shape: Tuple[int, int]
"""The original shape of the image"""


def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> BoundingBox:
Expand Down Expand Up @@ -65,15 +67,29 @@ def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> Boundin
xmin, xmax = np.where(x_valid)[0][[0, -1]]
ymin, ymax = np.where(y_valid)[0][[0, -1]]

return BoundingBox(xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)
return BoundingBox(
xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, original_shape=image_data.shape
)


class TrimImageResult(NamedTuple):
"""The constructed path and the bounding box"""

path: Path
"""The path to the trimmed image"""
bounding_box: BoundingBox
"""The bounding box that was applied to the image"""

def trim_fits_image(image_path: Path) -> Path:

def trim_fits_image(
image_path: Path, bounding_box: Optional[BoundingBox] = None
) -> TrimImageResult:
"""Trim the FITS image produces by linmos to remove as many empty pixels around
the border of the image as possible. This is an inplace operation.
Args:
image_path (Path): The FITS image that will have its border trimmed
bounding_box (Optional[BoundingBox], optional): The bounding box that will be applied to the image. If None it is computed. Defaults to None.
Returns:
Path: Path of the FITS image that had its border trimmed
Expand All @@ -83,7 +99,18 @@ def trim_fits_image(image_path: Path) -> Path:
data = fits_image[0].data
logger.info(f"Original data shape: {data.shape}")

bounding_box = create_bound_box(image_data=np.squeeze(data), is_masked=False)
image_shape = (data.shape[-2], data.shape[-1])
logger.info(f"The image dimensions are: {image_shape}")

if not bounding_box:
bounding_box = create_bound_box(
image_data=np.squeeze(data), is_masked=False
)
else:
if image_shape != bounding_box.original_shape:
raise ValueError(
f"Bounding box constructed against {bounding_box.original_shape}, but being applied to {image_shape=}"
)

data = data[
...,
Expand All @@ -99,7 +126,7 @@ def trim_fits_image(image_path: Path) -> Path:

fits.writeto(filename=image_path, data=data, header=header, overwrite=True)

return image_path
return TrimImageResult(path=image_path, bounding_box=bounding_box)


def get_image_weight(
Expand Down Expand Up @@ -386,7 +413,11 @@ def linmos_images(
)

# Trim the fits image to remove empty pixels
trim_fits_image(image_path=linmos_names.image_fits)
image_trim_results = trim_fits_image(image_path=linmos_names.image_fits)
trim_fits_image(
image_path=linmos_names.weight_fits,
bounding_box=image_trim_results.bounding_box,
)

return linmos_cmd

Expand Down Expand Up @@ -431,6 +462,13 @@ def get_parser() -> ArgumentParser:
help="Path to the container with yandasoft tools",
)

trim_parser = subparsers.add_parser(
"trim", help="Generate a yandasoft linmos parset"
)

trim_parser.add_argument(
"images", type=Path, nargs="+", help="The images that will be trimmed"
)
return parser


Expand Down Expand Up @@ -458,6 +496,13 @@ def cli() -> None:
holofile=args.holofile,
container=args.yandasoft_container,
)
elif args.mode == "trim":
images = args.images
logger.info(f"Will be trimming {len(images)}")
for image in images:
trim_fits_image(image_path=Path(image))
else:
logger.error(f"Unrecognised mode: {args.mode}")


if __name__ == "__main__":
Expand Down
38 changes: 36 additions & 2 deletions tests/test_linmos_coadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
some of trhe helper functions around it.
"""

from pathlib import Path

import numpy as np
import pytest
from astropy.io import fits

from flint.coadd.linmos import BoundingBox, create_bound_box, trim_fits_image


def create_fits_image(out_path):
data = np.zeros((1000, 1000))
def create_fits_image(out_path, image_size=(1000, 1000)):
data = np.zeros(image_size)
data[10:600, 20:500] = 1
data[data == 0] = np.nan

Expand Down Expand Up @@ -38,6 +41,37 @@ def test_trim_fits(tmp_path):
assert trim_data.shape == (589, 479)


def test_trim_fits_image_matching(tmp_path):
"""See the the bounding box can be passed through for matching to cutout"""

tmp_dir = Path(tmp_path) / "image_bb_match"
tmp_dir.mkdir()

out_fits = tmp_dir / "example.fits"

create_fits_image(out_fits)
og_trim = trim_fits_image(out_fits)

out_fits2 = tmp_dir / "example2.fits"
create_fits_image(out_fits2)
og_hdr = fits.getheader(out_fits2)
assert og_hdr["CRPIX1"] == 10
assert og_hdr["CRPIX2"] == 20

trim_fits_image(image_path=out_fits2, bounding_box=og_trim.bounding_box)
trim_hdr = fits.getheader(out_fits2)
trim_data = fits.getdata(out_fits2)
assert trim_hdr["CRPIX1"] == -10
assert trim_hdr["CRPIX2"] == 10
assert trim_data.shape == (589, 479)

out_fits2 = tmp_dir / "example3.fits"
create_fits_image(out_fits2, image_size=(300, 300))

with pytest.raises(ValueError):
trim_fits_image(image_path=out_fits2, bounding_box=og_trim.bounding_box)


def test_bounding_box():
data = np.zeros((1000, 1000))
data[10:600, 20:500] = 1
Expand Down

0 comments on commit ed39fae

Please sign in to comment.