-
Notifications
You must be signed in to change notification settings - Fork 272
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring of GainSelector #1093
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,173 +1,110 @@ | ||
""" | ||
Algorithms to select correct gain channel | ||
""" | ||
from abc import ABCMeta, abstractclassmethod | ||
|
||
from abc import abstractmethod | ||
from enum import IntEnum | ||
import numpy as np | ||
from ctapipe.core import Component, traits | ||
|
||
from ...core import Component, traits | ||
from ...utils import get_table_dataset | ||
|
||
__all__ = ['GainSelector', | ||
'ThresholdGainSelector', | ||
'SimpleGainSelector', | ||
'pick_gain_channel'] | ||
__all__ = [ | ||
'GainChannel', | ||
'GainSelector', | ||
'ManualGainSelector', | ||
'ThresholdGainSelector', | ||
] | ||
|
||
|
||
def pick_gain_channel(waveforms, threshold, select_by_sample=False): | ||
class GainChannel(IntEnum): | ||
""" | ||
the PMTs on some cameras have 2 gain channels. select one | ||
according to a threshold. | ||
|
||
Parameters: | ||
----------- | ||
waveforms: np.ndarray | ||
Array of shape (N_gain, N_pix, N_samp) | ||
threshold: float | ||
threshold (in PE/sample) of when to switch to low-gain channel | ||
select_by_sample: bool | ||
if true, select only low-gain *samples* when high-gain is over | ||
threshold | ||
|
||
Returns | ||
------- | ||
tuple: | ||
gain-selected intensity, boolean array of which channel was chosen | ||
Possible gain channels | ||
""" | ||
|
||
# if we have 2 channels: | ||
if waveforms.shape[0] == 2: | ||
waveforms = np.squeeze(waveforms) | ||
new_waveforms = waveforms[0].copy() | ||
|
||
if select_by_sample: | ||
# replace any samples that are above threshold with low-gain ones: | ||
gain_mask = waveforms[0] > threshold | ||
new_waveforms[gain_mask] = waveforms[1][gain_mask] | ||
else: | ||
# use entire low-gain waveform if any sample of high-gain | ||
# waveform is above threshold | ||
gain_mask = (waveforms[0] > threshold).any(axis=1) | ||
new_waveforms[gain_mask] = waveforms[1][gain_mask] | ||
|
||
elif waveforms.shape[0] == 1: | ||
new_waveforms = np.squeeze(waveforms) | ||
gain_mask = np.zeros_like(new_waveforms).astype(bool) | ||
|
||
else: | ||
raise ValueError("input waveforms has shape %s. not sure what to do " | ||
"with that.", waveforms.shape) | ||
|
||
return new_waveforms, gain_mask | ||
HIGH = 0 | ||
LOW = 1 | ||
|
||
|
||
class GainSelector(Component, metaclass=ABCMeta): | ||
class GainSelector(Component): | ||
""" | ||
Base class for algorithms that reduce a 2-gain-channel waveform to a | ||
single waveform. | ||
""" | ||
@abstractclassmethod | ||
def select_gains(self, cam_id, multi_gain_waveform): | ||
|
||
def __call__(self, waveforms): | ||
""" | ||
Takes an input waveform and cam_id and performs gain selection | ||
Reduce the waveform to a single gain channel | ||
|
||
Parameters | ||
---------- | ||
waveforms : ndarray | ||
Waveforms stored in a numpy array of shape | ||
(n_chan, n_pix, n_samples). | ||
|
||
Returns | ||
------- | ||
tuple(ndarray, ndarray): | ||
(waveform, gain_mask), where the gain_mask is a boolean array of | ||
which gain channel was used. | ||
reduced_waveforms : ndarray | ||
Waveform with a single channel | ||
Shape: (n_pix, n_samples) | ||
""" | ||
pass | ||
if waveforms.ndim == 2: # Return if already gain selected | ||
return waveforms | ||
elif waveforms.ndim == 3: | ||
n_channels, n_pixels, _ = waveforms.shape | ||
if n_channels == 1: # Reduce if already single channel | ||
return waveforms[0] | ||
else: | ||
pixel_channel = self.select_channel(waveforms) | ||
return waveforms[pixel_channel, np.arange(n_pixels)] | ||
else: | ||
raise ValueError( | ||
f"Cannot handle waveform array of shape: {waveforms.ndim}" | ||
) | ||
|
||
@abstractmethod | ||
def select_channel(self, waveforms): | ||
""" | ||
Abstract method to be defined by a GainSelector subclass. | ||
|
||
class NullGainSelector(GainSelector): | ||
""" | ||
do no gain selection, leaving possibly 2 gain channels at the DL1 level. | ||
this may break further steps in the chain if they do not expect 2 gains. | ||
""" | ||
Call the relevant functions to decide on the gain channel used for | ||
each pixel. | ||
|
||
def select_gains(self, cam_id, multi_gain_waveform): | ||
return multi_gain_waveform, np.ones(multi_gain_waveform.shape[1]) | ||
Parameters | ||
---------- | ||
waveforms : ndarray | ||
Waveforms stored in a numpy array of shape | ||
(n_chan, n_pix, n_samples). | ||
|
||
Returns | ||
------- | ||
pixel_channel : ndarray | ||
Gain channel to use for each pixel | ||
Shape: n_pix | ||
Dtype: int | ||
""" | ||
|
||
class SimpleGainSelector(GainSelector): | ||
|
||
class ManualGainSelector(GainSelector): | ||
""" | ||
Simply choose a single gain channel always. | ||
Manually choose a gain channel. | ||
""" | ||
channel = traits.CaselessStrEnum( | ||
["HIGH", "LOW"], | ||
default_value="HIGH", | ||
help="Which gain channel to retain" | ||
).tag(config=True) | ||
|
||
channel = traits.Int(default_value=0, help="which gain channel to " | ||
"retain").tag(config=True) | ||
|
||
def select_gains(self, cam_id, multi_gain_waveform): | ||
return ( | ||
multi_gain_waveform[self.channel], | ||
(np.ones(multi_gain_waveform.shape[1]) * self.channel).astype( | ||
np.bool) | ||
) | ||
def select_channel(self, waveforms): | ||
return GainChannel[self.channel] | ||
|
||
|
||
class ThresholdGainSelector(GainSelector): | ||
""" | ||
Select gain channel using fixed-threshold for any sample in the waveform. | ||
The thresholds are loaded from an `astropy.table.Table` that must contain | ||
two columns: `cam_id` (the name of the camera) and `gain_threshold_pe`, | ||
the threshold in photo-electrons per sample at which the switch should | ||
occur. | ||
|
||
Parameters | ||
---------- | ||
threshold_table_name: str | ||
Name of gain channel table to load | ||
select_by_sample: bool | ||
If True, replaces only the waveform samples that are above | ||
the threshold with low-gain versions, otherwise the full | ||
low-gain waveform is used. | ||
|
||
Attributes | ||
---------- | ||
thresholds: dict | ||
mapping of cam_id to threshold value | ||
Select gain channel according to a maximum threshold value. | ||
""" | ||
|
||
threshold_table_name = traits.Unicode( | ||
default_value='gain_channel_thresholds', | ||
help='Name of gain channel table to load' | ||
).tag(config=True) | ||
|
||
select_by_sample = traits.Bool( | ||
default_value=False, | ||
help='If True, replaces only the waveform samples that are above ' | ||
'the threshold with low-gain versions, otherwise the full ' | ||
'low-gain waveform is used.' | ||
threshold = traits.Float( | ||
default_value=1000, | ||
help="Threshold value in waveform sample units. If a waveform " | ||
"contains a sample above this threshold, use the low gain " | ||
"channel for that pixel." | ||
).tag(config=True) | ||
|
||
def __init__(self, config=None, parent=None, **kwargs): | ||
super().__init__(config=config, parent=parent, **kwargs) | ||
|
||
tab = get_table_dataset( | ||
self.threshold_table_name, | ||
role='dl0.tel.svc.gain_thresholds' | ||
) | ||
self.thresholds = dict(zip(tab['cam_id'], tab['gain_threshold_pe'])) | ||
self.log.debug("Loaded threshold table: \n %s", tab) | ||
|
||
def __str__(self): | ||
return f"{self.__class__.__name__}({self.thresholds})" | ||
|
||
def select_gains(self, cam_id, multi_gain_waveform): | ||
|
||
try: | ||
threshold = self.thresholds[cam_id] | ||
except KeyError: | ||
raise KeyError( | ||
"Camera ID '{}' not found in the gain-threshold " | ||
"table '{}'".format(cam_id, self.threshold_table_name) | ||
) | ||
|
||
waveform, gain_mask = pick_gain_channel( | ||
waveforms=multi_gain_waveform, | ||
threshold=threshold, | ||
select_by_sample=self.select_by_sample | ||
) | ||
|
||
return waveform, gain_mask | ||
def select_channel(self, waveforms): | ||
return (waveforms[0] > self.threshold).any(axis=1).astype(int) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a question, why not just 'return waveforms[pixel_channel, :]?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately is isn't as simple to index the array like that.
waveforms[pixel_channel, :]
returns an array of shape (n_pix, n_pix, n_samples).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(When pixel_channel is an array of n_pix itself)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah ok