From d1e36c82bddbe3222622eedce1f9c530c49bd536 Mon Sep 17 00:00:00 2001 From: Alexander Puck Neuwirth Date: Sat, 16 Nov 2024 13:09:02 +0100 Subject: [PATCH] Add PlottableProtocol --- src/babyyoda/axis.py | 40 ++++++++++++++++++++++++++++++++ src/babyyoda/grogu/histo2d_v3.py | 4 ++-- src/babyyoda/histo1d.py | 21 +++++++++++------ src/babyyoda/histo2d.py | 22 +++++++++++------- 4 files changed, 70 insertions(+), 17 deletions(-) create mode 100644 src/babyyoda/axis.py diff --git a/src/babyyoda/axis.py b/src/babyyoda/axis.py new file mode 100644 index 0000000..57e458e --- /dev/null +++ b/src/babyyoda/axis.py @@ -0,0 +1,40 @@ +from collections.abc import Iterator + +from uhi.typing.plottable import PlottableAxisGeneric, PlottableTraits + + +class UHITraits(PlottableTraits): + @property + def circular(self) -> bool: + return False + + @property + def discrete(self) -> bool: + return False + + +class UHIAxis(PlottableAxisGeneric[tuple[float, float]]): + @property + def traits(self) -> UHITraits: + return UHITraits() + + def __init__(self, values: list[tuple[float, float]]): + self.values = values + + # access axis[i] + def __getitem__(self, i: int) -> tuple[float, float]: + return self.values[i] + + def __len__(self) -> int: + return len(self.values) + + def __eq__(self, other: object) -> bool: + if isinstance(other, UHIAxis): + return self.values == other.values # noqa: PD011 + return False + + def __iter__(self) -> Iterator[tuple[float, float]]: + return iter(self.values) + + def index(self, value: tuple[float, float]) -> int: + return self.values.index(value) diff --git a/src/babyyoda/grogu/histo2d_v3.py b/src/babyyoda/grogu/histo2d_v3.py index ab1000a..6192d7c 100644 --- a/src/babyyoda/grogu/histo2d_v3.py +++ b/src/babyyoda/grogu/histo2d_v3.py @@ -245,9 +245,9 @@ def yMin(self) -> float: assert min(self.yEdges()) == self.yEdges()[0], "yMin is not the first edge" return self.yEdges()[0] - def bins(self, includeOverflows: bool = False) -> np.ndarray: + def bins(self, includeOverflows: bool = False) -> np.typing.NDArray[Any]: if includeOverflows: - return self.d_bins + return np.array(self.d_bins) # TODO consider represent data always as numpy return ( np.array(self.d_bins) diff --git a/src/babyyoda/histo1d.py b/src/babyyoda/histo1d.py index 2efdb05..ffdce34 100644 --- a/src/babyyoda/histo1d.py +++ b/src/babyyoda/histo1d.py @@ -4,9 +4,13 @@ import mplhep as hep import numpy as np +from uhi.typing.plottable import ( + PlottableHistogram, +) import babyyoda from babyyoda.analysisobject import UHIAnalysisObject +from babyyoda.axis import UHIAxis from babyyoda.util import loc, overflow, project, rebin, rebinBy_to_rebinTo, underflow @@ -40,7 +44,10 @@ def Histo1D(*args: Any, **kwargs: Any) -> "UHIHisto1D": # TODO make this implementation independent (no V2 or V3...) -class UHIHisto1D(UHIAnalysisObject): +class UHIHisto1D( + UHIAnalysisObject, + PlottableHistogram, +): ###### # Minimum required functions ###### @@ -170,7 +177,7 @@ def overflow(self) -> Any: def underflow(self) -> Any: return self.bins(includeOverflows=True)[0] - def errWs(self) -> np.ndarray: + def errWs(self) -> Any: return np.sqrt(np.array([b.sumW2() for b in self.bins()])) def xMins(self) -> list[float]: @@ -218,21 +225,21 @@ def dVols(self) -> list[float]: ######################################################## @property - def axes(self) -> list[list[tuple[float, float]]]: - return [list(zip(self.xMins(), self.xMaxs()))] + def axes(self) -> list[UHIAxis]: + return [UHIAxis(list(zip(self.xMins(), self.xMaxs())))] @property def kind(self) -> str: # TODO reevaluate this return "COUNT" - def counts(self) -> np.ndarray: + def counts(self) -> np.typing.NDArray[Any]: return np.array([b.numEntries() for b in self.bins()]) - def values(self) -> np.ndarray: + def values(self) -> np.typing.NDArray[Any]: return np.array([b.sumW() for b in self.bins()]) - def variances(self) -> np.ndarray: + def variances(self) -> np.typing.NDArray[Any]: return np.array([(b.sumW2()) for b in self.bins()]) def __getitem__( diff --git a/src/babyyoda/histo2d.py b/src/babyyoda/histo2d.py index ee59886..ead8086 100644 --- a/src/babyyoda/histo2d.py +++ b/src/babyyoda/histo2d.py @@ -4,8 +4,12 @@ import mplhep as hep import numpy as np +from uhi.typing.plottable import ( + PlottableHistogram, +) from babyyoda.analysisobject import UHIAnalysisObject +from babyyoda.axis import UHIAxis from babyyoda.util import ( loc, overflow, @@ -48,7 +52,7 @@ def Histo2D(*args: list[Any], **kwargs: list[Any]) -> "UHIHisto2D": return grogu.Histo2D(*args, **kwargs) -class UHIHisto2D(UHIAnalysisObject): +class UHIHisto2D(UHIAnalysisObject, PlottableHistogram): ###### # Minimum required functions ###### @@ -235,10 +239,10 @@ def dVols(self) -> list[float]: ######################################################## @property - def axes(self) -> list[list[tuple[float, float]]]: + def axes(self) -> list[UHIAxis]: return [ - list(zip(self.xMins(), self.xMaxs())), - list(zip(self.yMins(), self.yMaxs())), + UHIAxis(list(zip(self.xMins(), self.xMaxs()))), + UHIAxis(list(zip(self.yMins(), self.yMaxs()))), ] @property @@ -246,17 +250,17 @@ def kind(self) -> str: # TODO reeavaluate this return "COUNT" - def values(self) -> np.ndarray: + def values(self) -> np.typing.NDArray[Any]: return np.array(self.sumWs()).reshape((len(self.axes[1]), len(self.axes[0]))).T - def variances(self) -> np.ndarray: + def variances(self) -> np.typing.NDArray[Any]: return ( np.array([b.sumW2() for b in self.bins()]) .reshape((len(self.axes[1]), len(self.axes[0]))) .T ) - def counts(self) -> np.ndarray: + def counts(self) -> np.typing.NDArray[Any]: return ( np.array([b.numEntries() for b in self.bins()]) .reshape((len(self.axes[1]), len(self.axes[0]))) @@ -272,7 +276,9 @@ def __get_by_indices(self, ix: int, iy: int) -> Any: self.__single_index(ix, iy) ] # THIS is the fault with/without overflows! - def __get_index_by_loc(self, oloc: loc, bins: list[tuple[float, float]]) -> int: + def __get_index_by_loc( + self, oloc: loc, bins: Union[list[tuple[float, float]], UHIAxis] + ) -> int: # find the index in bin where loc is for a, b in bins: if a <= oloc.value and oloc.value < b: