Skip to content

Commit

Permalink
Add PlottableProtocol
Browse files Browse the repository at this point in the history
  • Loading branch information
APN-Pucky committed Nov 16, 2024
1 parent 5065eec commit d1e36c8
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 17 deletions.
40 changes: 40 additions & 0 deletions src/babyyoda/axis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from collections.abc import Iterator

from uhi.typing.plottable import PlottableAxisGeneric, PlottableTraits


class UHITraits(PlottableTraits):
@property
def circular(self) -> bool:
return False

@property
def discrete(self) -> bool:
return False


class UHIAxis(PlottableAxisGeneric[tuple[float, float]]):
@property
def traits(self) -> UHITraits:
return UHITraits()

def __init__(self, values: list[tuple[float, float]]):
self.values = values

# access axis[i]
def __getitem__(self, i: int) -> tuple[float, float]:
return self.values[i]

def __len__(self) -> int:
return len(self.values)

def __eq__(self, other: object) -> bool:
if isinstance(other, UHIAxis):
return self.values == other.values # noqa: PD011
return False

def __iter__(self) -> Iterator[tuple[float, float]]:
return iter(self.values)

def index(self, value: tuple[float, float]) -> int:
return self.values.index(value)
4 changes: 2 additions & 2 deletions src/babyyoda/grogu/histo2d_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ def yMin(self) -> float:
assert min(self.yEdges()) == self.yEdges()[0], "yMin is not the first edge"
return self.yEdges()[0]

def bins(self, includeOverflows: bool = False) -> np.ndarray:
def bins(self, includeOverflows: bool = False) -> np.typing.NDArray[Any]:
if includeOverflows:
return self.d_bins
return np.array(self.d_bins)
# TODO consider represent data always as numpy
return (
np.array(self.d_bins)
Expand Down
21 changes: 14 additions & 7 deletions src/babyyoda/histo1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

import mplhep as hep
import numpy as np
from uhi.typing.plottable import (
PlottableHistogram,
)

import babyyoda
from babyyoda.analysisobject import UHIAnalysisObject
from babyyoda.axis import UHIAxis
from babyyoda.util import loc, overflow, project, rebin, rebinBy_to_rebinTo, underflow


Expand Down Expand Up @@ -40,7 +44,10 @@ def Histo1D(*args: Any, **kwargs: Any) -> "UHIHisto1D":


# TODO make this implementation independent (no V2 or V3...)
class UHIHisto1D(UHIAnalysisObject):
class UHIHisto1D(
UHIAnalysisObject,
PlottableHistogram,
):
######
# Minimum required functions
######
Expand Down Expand Up @@ -170,7 +177,7 @@ def overflow(self) -> Any:
def underflow(self) -> Any:
return self.bins(includeOverflows=True)[0]

def errWs(self) -> np.ndarray:
def errWs(self) -> Any:
return np.sqrt(np.array([b.sumW2() for b in self.bins()]))

def xMins(self) -> list[float]:
Expand Down Expand Up @@ -218,21 +225,21 @@ def dVols(self) -> list[float]:
########################################################

@property
def axes(self) -> list[list[tuple[float, float]]]:
return [list(zip(self.xMins(), self.xMaxs()))]
def axes(self) -> list[UHIAxis]:
return [UHIAxis(list(zip(self.xMins(), self.xMaxs())))]

@property
def kind(self) -> str:
# TODO reevaluate this
return "COUNT"

def counts(self) -> np.ndarray:
def counts(self) -> np.typing.NDArray[Any]:
return np.array([b.numEntries() for b in self.bins()])

def values(self) -> np.ndarray:
def values(self) -> np.typing.NDArray[Any]:
return np.array([b.sumW() for b in self.bins()])

def variances(self) -> np.ndarray:
def variances(self) -> np.typing.NDArray[Any]:
return np.array([(b.sumW2()) for b in self.bins()])

def __getitem__(
Expand Down
22 changes: 14 additions & 8 deletions src/babyyoda/histo2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

import mplhep as hep
import numpy as np
from uhi.typing.plottable import (
PlottableHistogram,
)

from babyyoda.analysisobject import UHIAnalysisObject
from babyyoda.axis import UHIAxis
from babyyoda.util import (
loc,
overflow,
Expand Down Expand Up @@ -48,7 +52,7 @@ def Histo2D(*args: list[Any], **kwargs: list[Any]) -> "UHIHisto2D":
return grogu.Histo2D(*args, **kwargs)


class UHIHisto2D(UHIAnalysisObject):
class UHIHisto2D(UHIAnalysisObject, PlottableHistogram):
######
# Minimum required functions
######
Expand Down Expand Up @@ -235,28 +239,28 @@ def dVols(self) -> list[float]:
########################################################

@property
def axes(self) -> list[list[tuple[float, float]]]:
def axes(self) -> list[UHIAxis]:
return [
list(zip(self.xMins(), self.xMaxs())),
list(zip(self.yMins(), self.yMaxs())),
UHIAxis(list(zip(self.xMins(), self.xMaxs()))),
UHIAxis(list(zip(self.yMins(), self.yMaxs()))),
]

@property
def kind(self) -> str:
# TODO reeavaluate this
return "COUNT"

def values(self) -> np.ndarray:
def values(self) -> np.typing.NDArray[Any]:
return np.array(self.sumWs()).reshape((len(self.axes[1]), len(self.axes[0]))).T

def variances(self) -> np.ndarray:
def variances(self) -> np.typing.NDArray[Any]:
return (
np.array([b.sumW2() for b in self.bins()])
.reshape((len(self.axes[1]), len(self.axes[0])))
.T
)

def counts(self) -> np.ndarray:
def counts(self) -> np.typing.NDArray[Any]:
return (
np.array([b.numEntries() for b in self.bins()])
.reshape((len(self.axes[1]), len(self.axes[0])))
Expand All @@ -272,7 +276,9 @@ def __get_by_indices(self, ix: int, iy: int) -> Any:
self.__single_index(ix, iy)
] # THIS is the fault with/without overflows!

def __get_index_by_loc(self, oloc: loc, bins: list[tuple[float, float]]) -> int:
def __get_index_by_loc(
self, oloc: loc, bins: Union[list[tuple[float, float]], UHIAxis]
) -> int:
# find the index in bin where loc is
for a, b in bins:
if a <= oloc.value and oloc.value < b:
Expand Down

0 comments on commit d1e36c8

Please sign in to comment.