Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update _discrete_tune.py #155

Merged
merged 10 commits into from
Nov 9, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

## [Unreleased]

### Fixed
- Fixed bug where `DiscreteTune`s returned only one value when called (regardless of input shape)

### Added
- more documentation

Expand Down
37 changes: 32 additions & 5 deletions attune/_discrete_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Dict, Tuple, Optional

import WrightTools as wt
import numpy as np


class DiscreteTune:
Expand Down Expand Up @@ -34,13 +35,39 @@ def __init__(
def __repr__(self):
return f"DiscreteTune({repr(self.ranges)}, {repr(self.default)})"

def __call__(self, ind_value, *, ind_units=None, dep_units=None):
def __call__(self, ind_value, *, ind_units=None):
"""Evaluate the DiscreteTune at specific independent value(s).

Paramters
---------
ind_val: float-like or ndarray
The value or values at which to evaluate the DiscreteTune.
ind_units: Optional[str]
Units of the independent variable. Default is "nm".

Returns
-------
key: str or ndarray
The string identifier for the independent value.
For an array of ind_val, an array of identifiers is given.

"""
if ind_units is not None and self._ind_units is not None:
ind_value = wt.units.convert(ind_value, ind_units, self._ind_units)
for key, (min, max) in self.ranges.items():
if min <= ind_value <= max:
return key
return self.default
if isinstance(ind_value, np.ndarray):
out = np.full(
ind_value.shape,
self.default,
dtype=f"U{max([len(s) for s in self.ranges.keys()])}",
)
for key, (imin, imax) in self.ranges.items():
out[(ind_value >= imin) & (ind_value <= imax)] = key
return out
else:
for key, (imin, imax) in self.ranges.items():
if imin <= ind_value <= imax:
return key
return self.default

def __eq__(self, other):
return self.ranges == other.ranges and self.default == other.default
Expand Down