Skip to content

Commit

Permalink
Merge pull request #424 from dirac-institute/masking2
Browse files Browse the repository at this point in the history
Clean up masking logic
  • Loading branch information
jeremykubica authored Jan 12, 2024
2 parents 3a22534 + 53fb617 commit ecef117
Show file tree
Hide file tree
Showing 17 changed files with 368 additions and 696 deletions.
263 changes: 62 additions & 201 deletions src/kbmod/masking.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,12 @@
"""Classes for performing masking on images from FITS files.
ImageMasker provides an abstract base class that can be overridden to define masking
algorithms for specific studies, instruments, or FITS headers. Specific masking classes
are provided to support common masking operations including: masking based on a bit vector,
masking based on a dictionary, masking based on a threshold, and growing a current mask.
"""Functions for performing masking operations as specified in the configuration.
"""

import abc

import kbmod.search as kb


def apply_mask_operations(stack, mask_list):
"""Apply a series of masking operations defined by a list of
ImageMasker objects.
Parameters
----------
stack : `kbmod.ImageStack`
The stack before the masks have been applied.
mask_list : `list`
A list of mask_list objects.
Returns
-------
stack : `kbmod.ImageStack`
The same stack object to allow chaining.
"""
for mask in mask_list:
stack = mask.apply_mask(stack)
return stack


class ImageMasker(abc.ABC):
"""The base class for masking operations."""

def __init__(self, *args, **kwargs):
pass

@abc.abstractmethod
def apply_mask(self, stack):
"""Apply the mask to an image stack.
Parameters
----------
stack : `kbmod.ImageStack`
The stack before the masks have been applied.
Returns
-------
stack : `kbmod.ImageStack`
The same stack object to allow chaining.
"""
pass


class BitVectorMasker(ImageMasker):
"""Apply a mask given a bit vector of masking flags to use
and vector of bit vectors to ignore.
Attributes
----------
flags : `int`
A bit vector of masking flags to apply.
"""

def __init__(self, flags, *args, **kwargs):
super().__init__(*args, **kwargs)
self.flags = flags

def apply_mask(self, stack):
"""Apply the mask to an image stack.
Parameters
----------
stack : `kbmod.ImageStack`
The stack before the masks have been applied.
Returns
-------
stack : `kbmod.ImageStack`
The same stack object to allow chaining.
"""
if self.flags != 0:
stack.apply_mask_flags(self.flags)
return stack


class DictionaryMasker(BitVectorMasker):
"""Apply a mask given a dictionary of masking condition to key
and a list of masking conditions to use.
def mask_flags_from_dict(mask_bits_dict, flag_keys):
"""Generate a bitmask integer of flag keys from a dictionary
of masking reasons and a list of reasons to use.
Attributes
----------
Expand All @@ -98,126 +15,70 @@ class DictionaryMasker(BitVectorMasker):
number in the masking bit vector.
flag_keys : `list`
A list of masking keys to use.
"""

def __init__(self, mask_bits_dict, flag_keys, *args, **kwargs):
self.mask_bits_dict = mask_bits_dict
self.flag_keys = flag_keys

# Convert the dictionary into a bit vector.
bitvector = 0
for bit in self.flag_keys:
bitvector += 2 ** self.mask_bits_dict[bit]

# Initialize the BitVectorMasker parameters.
super().__init__(bitvector, *args, **kwargs)

class GlobalDictionaryMasker(ImageMasker):
"""Apply a mask given a dictionary of masking condition to key
and a list of masking conditions to use. Masks pixels in every image
if they are masked in *multiple* images in the stack.
Attributes
----------
mask_bits_dict : `dict`
A dictionary mapping a masking key (string) to the bit
number in the masking bit vector.
global_flag_keys : `list`
A list of masking keys to use.
mask_num_images : `int`
The number of images that need to be masked in the stack
to apply the mask to all images.
Returns
-------
bitmask : `int`
The bitmask to use for masking operations.
"""
bitmask = 0
for bit in flag_keys:
bitmask += 2 ** mask_bits_dict[bit]
return bitmask

def __init__(self, mask_bits_dict, global_flag_keys, mask_num_images, *args, **kwargs):
super().__init__(*args, **kwargs)

self.mask_bits_dict = mask_bits_dict
self.global_flag_keys = global_flag_keys
self.mask_num_images = mask_num_images
def apply_mask_operations(config, stack):
"""Perform all the masking operations based on the search's configuration parameters.
# Convert the dictionary into a bit vector.
self.global_flags = 0
for bit in self.global_flag_keys:
self.global_flags += 2 ** self.mask_bits_dict[bit]

def apply_mask(self, stack):
"""Apply the mask to an image stack.
Parameters
----------
stack : `kbmod.ImageStack`
The stack before the masks have been applied.
Returns
-------
stack : `kbmod.ImageStack`
The same stack object to allow chaining.
"""
if self.global_flags != 0:
stack.apply_global_mask(self.global_flags, self.mask_num_images)
return stack


class ThresholdMask(ImageMasker):
"""Mask pixels over a given value.
Attributes
Parameters
----------
mask_threshold : `float`
The flux threshold for a pixel.
"""

def __init__(self, mask_threshold, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mask_threshold = mask_threshold

def apply_mask(self, stack):
"""Apply the mask to an image stack.
Parameters
----------
stack : `kbmod.ImageStack`
The stack before the masks have been applied.
config : `SearchConfiguration`
The configuration parameters
stack : `ImageStack`
The stack before the masks have been applied. Modified in-place.
Returns
-------
stack : `kbmod.ImageStack`
The same stack object to allow chaining.
"""
stack.apply_mask_threshold(self.mask_threshold)
return stack


class GrowMask(ImageMasker):
"""Apply a mask that grows the current max out a given number of pixels.
Attributes
----------
num_pixels : `int`
The number of pixels to extend the mask.
Returns
-------
stack : `ImageStack`
The stack after the masks have been applied.
"""
# Generate the global mask before we start modifying the individual masks.
if config["repeated_flag_keys"] and len(config["repeated_flag_keys"]) > 0:
global_flags = mask_flags_from_dict(config["mask_bits_dict"], config["repeated_flag_keys"])
global_binary_mask = stack.make_global_mask(global_flags, config["mask_num_images"])
else:
global_binary_mask = None

# Start by creating a binary mask out of the primary flag values. Prioritize
# the config's mask_bit_vector over the dictionary based version.
if config["mask_bit_vector"]:
mask_flags = config["mask_bit_vector"]
elif config["flag_keys"] and len(config["flag_keys"]) > 0:
mask_flags = mask_flags_from_dict(config["mask_bits_dict"], config["flag_keys"])
else:
mask_flags = 0

# Apply the primary mask.
for i in range(stack.img_count()):
stack.get_single_image(i).binarize_mask(mask_flags)

# If the threshold is set, mask those pixels.
if config["mask_threshold"]:
for i in range(stack.img_count()):
stack.get_single_image(i).union_threshold_masking(config["mask_threshold"])

# Union in the global masking if there was one.
if global_binary_mask is not None:
for i in range(stack.img_count()):
stack.get_single_image(i).union_masks(global_binary_mask)

# Grow the masks.
if config["mask_grow"] and config["mask_grow"] > 0:
for i in range(stack.img_count()):
stack.get_single_image(i).grow_mask(config["mask_grow"])

# Apply the masks to the images. Use 0xFFFFFF to apply all active masking bits.
for i in range(stack.img_count()):
stack.get_single_image(i).apply_mask(0xFFFFFF)

def __init__(self, num_pixels, *args, **kwargs):
super().__init__(*args, **kwargs)

if num_pixels <= 0:
raise ValueError(f"Invalid num_pixels={num_pixels} for GrowMask")
self.num_pixels = num_pixels

def apply_mask(self, stack):
"""Apply the mask to an image stack.
Parameters
----------
stack : `kbmod.ImageStack`
The stack before the masks have been applied.
Returns
-------
stack : `kbmod.ImageStack`
The same stack object to allow chaining.
"""
stack.grow_mask(self.num_pixels)
return stack
return stack
57 changes: 2 additions & 55 deletions src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,7 @@
from .analysis_utils import PostProcess
from .data_interface import load_input_from_config
from .configuration import SearchConfiguration
from .masking import (
BitVectorMasker,
DictionaryMasker,
GlobalDictionaryMasker,
GrowMask,
ThresholdMask,
apply_mask_operations,
)
from .masking import apply_mask_operations
from .result_list import *
from .filters.sigma_g_filter import SigmaGClipping
from .work_unit import WorkUnit
Expand All @@ -34,52 +27,6 @@ class SearchRunner:
def __init__(self):
pass

def do_masking(self, config, stack):
"""Perform the masking based on the search's configuration parameters.
Parameters
----------
config : `SearchConfiguration`
The configuration parameters
stack : `ImageStack`
The stack before the masks have been applied. Modified in-place.
Returns
-------
stack : `ImageStack`
The stack after the masks have been applied.
"""
mask_steps = []

# Prioritize the mask_bit_vector over the dictionary based version.
if config["mask_bit_vector"]:
mask_steps.append(BitVectorMasker(config["mask_bit_vector"]))
elif config["flag_keys"] and len(config["flag_keys"]) > 0:
mask_steps.append(DictionaryMasker(config["mask_bits_dict"], config["flag_keys"]))

# Add the threshold mask if it is set.
if config["mask_threshold"]:
mask_steps.append(ThresholdMask(config["mask_threshold"]))

# Add the global masking if it is set.
if config["repeated_flag_keys"] and len(config["repeated_flag_keys"]) > 0:
mask_steps.append(
GlobalDictionaryMasker(
config["mask_bits_dict"],
config["repeated_flag_keys"],
config["mask_num_images"],
)
)

# Grow the mask.
if config["mask_grow"] and config["mask_grow"] > 0:
mask_steps.append(GrowMask(config["mask_grow"]))

# Apply the masks.
stack = apply_mask_operations(stack, mask_steps)

return stack

def get_angle_limits(self, config):
"""Compute the angle limits based on the configuration information.
Expand Down Expand Up @@ -199,7 +146,7 @@ def run_search(self, config, stack):

# Apply the mask to the images.
if config["do_mask"]:
stack = self.do_masking(config, stack)
stack = apply_mask_operations(config, stack)

# Perform the actual search.
search = kb.StackSearch(stack)
Expand Down
Loading

0 comments on commit ecef117

Please sign in to comment.