Skip to content

Commit

Permalink
Merge pull request #155 from wright-group/DicreteTune-call-array
Browse files Browse the repository at this point in the history
Discrete tune can be called with array arguments
  • Loading branch information
ddkohler authored Nov 9, 2022
2 parents 6e5c6c9 + 30a2d86 commit 868d221
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

## [Unreleased]

### Fixed
- Fixed bug where `DiscreteTune`s returned only one value when called (regardless of input shape)

### Added
- more documentation

Expand Down
37 changes: 32 additions & 5 deletions attune/_discrete_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Dict, Tuple, Optional

import WrightTools as wt
import numpy as np


class DiscreteTune:
Expand Down Expand Up @@ -34,13 +35,39 @@ def __init__(
def __repr__(self):
return f"DiscreteTune({repr(self.ranges)}, {repr(self.default)})"

def __call__(self, ind_value, *, ind_units=None, dep_units=None):
def __call__(self, ind_value, *, ind_units=None):
"""Evaluate the DiscreteTune at specific independent value(s).
Paramters
---------
ind_val: float-like or ndarray
The value or values at which to evaluate the DiscreteTune.
ind_units: Optional[str]
Units of the independent variable. Default is "nm".
Returns
-------
key: str or ndarray
The string identifier for the independent value.
For an array of ind_val, an array of identifiers is given.
"""
if ind_units is not None and self._ind_units is not None:
ind_value = wt.units.convert(ind_value, ind_units, self._ind_units)
for key, (min, max) in self.ranges.items():
if min <= ind_value <= max:
return key
return self.default
if isinstance(ind_value, np.ndarray):
out = np.full(
ind_value.shape,
self.default,
dtype=f"U{max([len(s) for s in self.ranges.keys()])}",
)
for key, (imin, imax) in self.ranges.items():
out[(ind_value >= imin) & (ind_value <= imax)] = key
return out
else:
for key, (imin, imax) in self.ranges.items():
if imin <= ind_value <= imax:
return key
return self.default

def __eq__(self, other):
return self.ranges == other.ranges and self.default == other.default
Expand Down

0 comments on commit 868d221

Please sign in to comment.