Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring of GainSelector #1093

Merged
merged 2 commits into from
Jun 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 75 additions & 138 deletions ctapipe/calib/camera/gainselection.py
Original file line number Diff line number Diff line change
@@ -1,173 +1,110 @@
"""
Algorithms to select correct gain channel
"""
from abc import ABCMeta, abstractclassmethod

from abc import abstractmethod
from enum import IntEnum
import numpy as np
from ctapipe.core import Component, traits

from ...core import Component, traits
from ...utils import get_table_dataset

__all__ = ['GainSelector',
'ThresholdGainSelector',
'SimpleGainSelector',
'pick_gain_channel']
__all__ = [
'GainChannel',
'GainSelector',
'ManualGainSelector',
'ThresholdGainSelector',
]


def pick_gain_channel(waveforms, threshold, select_by_sample=False):
class GainChannel(IntEnum):
"""
the PMTs on some cameras have 2 gain channels. select one
according to a threshold.

Parameters:
-----------
waveforms: np.ndarray
Array of shape (N_gain, N_pix, N_samp)
threshold: float
threshold (in PE/sample) of when to switch to low-gain channel
select_by_sample: bool
if true, select only low-gain *samples* when high-gain is over
threshold

Returns
-------
tuple:
gain-selected intensity, boolean array of which channel was chosen
Possible gain channels
"""

# if we have 2 channels:
if waveforms.shape[0] == 2:
waveforms = np.squeeze(waveforms)
new_waveforms = waveforms[0].copy()

if select_by_sample:
# replace any samples that are above threshold with low-gain ones:
gain_mask = waveforms[0] > threshold
new_waveforms[gain_mask] = waveforms[1][gain_mask]
else:
# use entire low-gain waveform if any sample of high-gain
# waveform is above threshold
gain_mask = (waveforms[0] > threshold).any(axis=1)
new_waveforms[gain_mask] = waveforms[1][gain_mask]

elif waveforms.shape[0] == 1:
new_waveforms = np.squeeze(waveforms)
gain_mask = np.zeros_like(new_waveforms).astype(bool)

else:
raise ValueError("input waveforms has shape %s. not sure what to do "
"with that.", waveforms.shape)

return new_waveforms, gain_mask
HIGH = 0
LOW = 1


class GainSelector(Component, metaclass=ABCMeta):
class GainSelector(Component):
"""
Base class for algorithms that reduce a 2-gain-channel waveform to a
single waveform.
"""
@abstractclassmethod
def select_gains(self, cam_id, multi_gain_waveform):

def __call__(self, waveforms):
"""
Takes an input waveform and cam_id and performs gain selection
Reduce the waveform to a single gain channel

Parameters
----------
waveforms : ndarray
Waveforms stored in a numpy array of shape
(n_chan, n_pix, n_samples).

Returns
-------
tuple(ndarray, ndarray):
(waveform, gain_mask), where the gain_mask is a boolean array of
which gain channel was used.
reduced_waveforms : ndarray
Waveform with a single channel
Shape: (n_pix, n_samples)
"""
pass
if waveforms.ndim == 2: # Return if already gain selected
return waveforms
elif waveforms.ndim == 3:
n_channels, n_pixels, _ = waveforms.shape
if n_channels == 1: # Reduce if already single channel
return waveforms[0]
else:
pixel_channel = self.select_channel(waveforms)
return waveforms[pixel_channel, np.arange(n_pixels)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a question, why not just 'return waveforms[pixel_channel, :]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately is isn't as simple to index the array like that. waveforms[pixel_channel, :] returns an array of shape (n_pix, n_pix, n_samples).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(When pixel_channel is an array of n_pix itself)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok

else:
raise ValueError(
f"Cannot handle waveform array of shape: {waveforms.ndim}"
)

@abstractmethod
def select_channel(self, waveforms):
"""
Abstract method to be defined by a GainSelector subclass.

class NullGainSelector(GainSelector):
"""
do no gain selection, leaving possibly 2 gain channels at the DL1 level.
this may break further steps in the chain if they do not expect 2 gains.
"""
Call the relevant functions to decide on the gain channel used for
each pixel.

def select_gains(self, cam_id, multi_gain_waveform):
return multi_gain_waveform, np.ones(multi_gain_waveform.shape[1])
Parameters
----------
waveforms : ndarray
Waveforms stored in a numpy array of shape
(n_chan, n_pix, n_samples).

Returns
-------
pixel_channel : ndarray
Gain channel to use for each pixel
Shape: n_pix
Dtype: int
"""

class SimpleGainSelector(GainSelector):

class ManualGainSelector(GainSelector):
"""
Simply choose a single gain channel always.
Manually choose a gain channel.
"""
channel = traits.CaselessStrEnum(
["HIGH", "LOW"],
default_value="HIGH",
help="Which gain channel to retain"
).tag(config=True)

channel = traits.Int(default_value=0, help="which gain channel to "
"retain").tag(config=True)

def select_gains(self, cam_id, multi_gain_waveform):
return (
multi_gain_waveform[self.channel],
(np.ones(multi_gain_waveform.shape[1]) * self.channel).astype(
np.bool)
)
def select_channel(self, waveforms):
return GainChannel[self.channel]


class ThresholdGainSelector(GainSelector):
"""
Select gain channel using fixed-threshold for any sample in the waveform.
The thresholds are loaded from an `astropy.table.Table` that must contain
two columns: `cam_id` (the name of the camera) and `gain_threshold_pe`,
the threshold in photo-electrons per sample at which the switch should
occur.

Parameters
----------
threshold_table_name: str
Name of gain channel table to load
select_by_sample: bool
If True, replaces only the waveform samples that are above
the threshold with low-gain versions, otherwise the full
low-gain waveform is used.

Attributes
----------
thresholds: dict
mapping of cam_id to threshold value
Select gain channel according to a maximum threshold value.
"""

threshold_table_name = traits.Unicode(
default_value='gain_channel_thresholds',
help='Name of gain channel table to load'
).tag(config=True)

select_by_sample = traits.Bool(
default_value=False,
help='If True, replaces only the waveform samples that are above '
'the threshold with low-gain versions, otherwise the full '
'low-gain waveform is used.'
threshold = traits.Float(
default_value=1000,
help="Threshold value in waveform sample units. If a waveform "
"contains a sample above this threshold, use the low gain "
"channel for that pixel."
).tag(config=True)

def __init__(self, config=None, parent=None, **kwargs):
super().__init__(config=config, parent=parent, **kwargs)

tab = get_table_dataset(
self.threshold_table_name,
role='dl0.tel.svc.gain_thresholds'
)
self.thresholds = dict(zip(tab['cam_id'], tab['gain_threshold_pe']))
self.log.debug("Loaded threshold table: \n %s", tab)

def __str__(self):
return f"{self.__class__.__name__}({self.thresholds})"

def select_gains(self, cam_id, multi_gain_waveform):

try:
threshold = self.thresholds[cam_id]
except KeyError:
raise KeyError(
"Camera ID '{}' not found in the gain-threshold "
"table '{}'".format(cam_id, self.threshold_table_name)
)

waveform, gain_mask = pick_gain_channel(
waveforms=multi_gain_waveform,
threshold=threshold,
select_by_sample=self.select_by_sample
)

return waveform, gain_mask
def select_channel(self, waveforms):
return (waveforms[0] > self.threshold).any(axis=1).astype(int)
Loading