Skip to content

Commit

Permalink
Add typing to resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
coroa committed Oct 10, 2024
1 parent 084de7a commit 44c211b
Showing 1 changed file with 64 additions and 59 deletions.
123 changes: 64 additions & 59 deletions src/pandas_indexing/iamc/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import contextmanager
from functools import reduce
from itertools import product
from typing import Any, Callable, Sequence, TypeVar
from typing import Any, Callable, Iterator, Sequence, TypeVar

from attrs import define, evolve
from pandas import DataFrame, Index, MultiIndex, Series
Expand All @@ -31,7 +31,7 @@ def _summarize(index: MultiIndex, names: Sequence[str]) -> DataFrame:
"""
return (
index.to_frame()
.pix.project(names)[index.names.difference(names)]
.pix.project(names)[index.names.difference(names)] # type: ignore
.groupby(names)
.agg(lambda x: print_list(set(x), n=42))
)
Expand Down Expand Up @@ -59,9 +59,6 @@ class Context:
optional_combinations: bool = False


SelfVar = TypeVar("SelfVar", bound="Var")


@define
class Var:
"""Instance for a single variant.
Expand All @@ -82,6 +79,8 @@ class Var:
data: DataFrame
provenance: str

SV = TypeVar("SV", bound="Var")

def __repr__(self) -> str:
    """Return ``"Var <provenance>"`` followed, on a new line, by the string form of ``self.data.pix``."""
    return f"Var {self.provenance}\n{self.data.pix}"

Expand All @@ -92,35 +91,42 @@ def empty(self) -> bool:
def index(self, levels: Sequence[str]) -> MultiIndex:
if not set(levels).issubset(self.data.index.names):
return MultiIndex.from_tuples([], names=levels)
return self.data.pix.unique(levels)
return self.data.pix.unique(levels) # type: ignore

def _binop(
self, op, x: SelfVar, y: SelfVar, provenance_maker: Callable[[str, str], str]
) -> SelfVar:
self: SV,
op,
x: SV,
y: SV,
provenance_maker: Callable[[str, str], str],
) -> SV:
if not all(isinstance(v, Var) for v in (x, y)):
return NotImplemented
provenance = provenance_maker(x.provenance, y.provenance)
return self.__class__(op(x.data, y.data, join="inner"), provenance)

def __add__(self, other: SelfVar) -> SelfVar:
return self._binop(arithmetics.add, self, other, lambda x, y: f"{x} + {y}")
def __add__(self: SV, other: SV) -> SV:
return self._binop(arithmetics.add, self, other, lambda x, y: f"{x} + {y}") # type: ignore

def __sub__(self, other: SelfVar) -> SelfVar:
def __sub__(self: SV, other: SV) -> SV:
return self._binop(
arithmetics.sub, self, other, lambda x, y: f"{x} - {maybe_parens(y)}"
arithmetics.sub, # type: ignore
self,
other,
lambda x, y: f"{x} - {maybe_parens(y)}",
)

def __mul__(self, other: SelfVar) -> SelfVar:
def __mul__(self: SV, other: SV) -> SV:
return self._binop(
arithmetics.mul,
arithmetics.mul, # type: ignore
self,
other,
lambda x, y: f"{maybe_parens(x)} * {maybe_parens(y)}",
)

def __truediv__(self, other: SelfVar) -> SelfVar:
def __truediv__(self: SV, other: SV) -> SV:
return self._binop(
arithmetics.div,
arithmetics.div, # type: ignore
self,
other,
lambda x, y: f"{maybe_parens(x)} / {maybe_parens(y)}",
Expand All @@ -130,10 +136,6 @@ def as_df(self, **assign: str) -> DataFrame:
return self.data.pix.assign(**(assign | dict(provenance=self.provenance)))


SelfVars = TypeVar("SelfVars", bound="Vars")
T = TypeVar("T", bound=Index | DataFrame | Series)


@define
class Vars:
"""`Vars` holds several derivations of a variable from data in a `Resolver`
Expand All @@ -160,18 +162,21 @@ class Vars:
data: list[Var]
context: Context

SV = TypeVar("SV", bound="Vars")
T = TypeVar("T", bound=Index | DataFrame | Series)

@classmethod
def from_data(
cls,
cls: type[SV],
data: DataFrame,
value: Any,
value: str,
*,
context: Context,
provenance: str | None = None,
) -> SelfVars:
) -> SV:
if provenance is None:
provenance = value
data = data.loc[isin(**{context.level: value})].droplevel(context.level)
data = data.loc[isin(**{context.level: value})].droplevel(context.level) # type: ignore
return cls([Var(data, provenance)] if not data.empty else [], context)

def __repr__(self) -> str:
Expand Down Expand Up @@ -203,20 +208,20 @@ def __bool__(self) -> bool:
def __len__(self) -> int:
    """Return the number of `Var` entries held in ``self.data``."""
    return len(self.data)

def __iter__(self):
def __iter__(self) -> Iterator[Var]:
return iter(self.data)

def __getitem__(self, i: int) -> SelfVars:
def __getitem__(self: SV, i: int) -> SV:
ret = self.data[i]
if isinstance(ret, Var):
ret = [ret]
return self.__class__(ret, self.context)

class _LocIndexer:
def __init__(self, obj: SelfVars):
def __init__(self, obj):
self._obj = obj

def __getitem__(self, x: Any) -> SelfVars:
def __getitem__(self, x: Any) -> Vars:
obj = self._obj
return obj.__class__(
[
Expand All @@ -228,24 +233,24 @@ def __getitem__(self, x: Any) -> SelfVars:
)

@property
def loc(self) -> SelfVars:
def loc(self) -> _LocIndexer:
return self._LocIndexer(self)

@staticmethod
def _antijoin(data: T, vars: SelfVars) -> T:
return reduce(lambda d, v: d.pix.antijoin(v.data.index), vars.data, data)
def _antijoin(data: T, vars: Vars) -> T:
return reduce(lambda d, v: d.pix.antijoin(v.data.index), vars.data, data) # type: ignore

def antijoin(self, other: SelfVars) -> SelfVars:
def antijoin(self: SV, other: SV) -> SV:
"""Remove everything from self that is already in `other`
Parameters
----------
other : SelfVars
other : Vars
Another set of derivations for the same variable
Returns
-------
SelfVars
Vars
Subset of self that is not already provided by `other`
"""
return self.__class__(
Expand Down Expand Up @@ -278,9 +283,10 @@ def missing(

def existing(self, summarize: bool = True) -> DataFrame | MultiIndex:
    """Return the index entries covered by the variants in ``self.data``.

    The indexes of all ``var.data`` frames are combined with `concat`.

    Parameters
    ----------
    summarize : bool, default True
        If True, pass the combined index to `_summarize` together with
        ``self.context.index``; otherwise return the combined index itself.

    Returns
    -------
    DataFrame | MultiIndex
        The result of `_summarize` when `summarize` is True, else the
        combined index.
    """
    index = concat(var.data.index for var in self.data)
    # Narrows the type for static checkers before calling `_summarize`.
    # NOTE(review): assert statements are stripped when Python runs with -O.
    assert isinstance(index, MultiIndex)
    return _summarize(index, self.context.index) if summarize else index

def _binop(self, op, x: SelfVars, y: SelfVars) -> SelfVars:
def _binop(self: SV, op: Callable[[Var, Var], Var], x: SV, y: SV) -> SV:
if not all(isinstance(v, Vars) for v in (x, y)):
return NotImplemented

Expand All @@ -291,37 +297,37 @@ def _binop(self, op, x: SelfVars, y: SelfVars) -> SelfVars:
res = res | x | y
return res

def __add__(self, other: SelfVars) -> SelfVars:
def __add__(self: SV, other: SV) -> SV:
if other == 0:
return self
return self._binop(operator.add, self, other)

def __radd__(self, other: SelfVars) -> SelfVars:
def __radd__(self: SV, other: SV) -> SV:
if other == 0:
return self
return self._binop(operator.add, other, self)

def __sub__(self, other: SelfVars) -> SelfVars:
def __sub__(self: SV, other: SV) -> SV:
if other == 0:
return self
return self._binop(operator.sub, self, other)

def __rsub__(self, other: SelfVars) -> SelfVars:
def __rsub__(self: SV, other: SV) -> SV:
if other == 0:
return self
return self._binop(operator.sub, other, self)

def __mul__(self, other: SelfVars) -> SelfVars:
def __mul__(self: SV, other: SV) -> SV:
if other == 1:
return self
return self._binop(operator.mul, self, other)

def __rmul__(self, other: SelfVars) -> SelfVars:
def __rmul__(self: SV, other: SV) -> SV:
if other == 1:
return self
return self._binop(operator.mul, other, self)

def __or__(self, other: SelfVars | float | int) -> SelfVars:
def __or__(self: SV, other: SV | float | int) -> SV:
if isinstance(other, (float, int)):
provenance = str(other)
other_index = self._missing()
Expand All @@ -336,10 +342,10 @@ def __or__(self, other: SelfVars | float | int) -> SelfVars:

return self ^ other.antijoin(self)

def __ror__(self, other: SelfVars) -> SelfVars:
def __ror__(self: SV, other: SV) -> SV:
return other ^ self.antijoin(other)

def __xor__(self, other: SelfVars) -> SelfVars:
def __xor__(self: SV, other: SV) -> SV:
if not isinstance(other, Vars):
return NotImplemented
return self.__class__(self.data + other.data, self.context)
Expand All @@ -348,15 +354,12 @@ def as_df(self, **assign: str) -> DataFrame:
return concat(var.as_df(**assign) for var in self.data)


SelfResolver = TypeVar("SelfResolver", bound="Resolver")


@define
class Resolver:
"""Resolver allows consolidating variables by composing variants.
Examples
--------
Usage
-----
>>> co2emissions = ar6.loc[isin(gas="CO2")]
>>> r = Resolver.from_data(co2emissions, "sector", ["AFOLU", "Energy"])
>>> r["Energy"] |= r["Energy|Demand"] + r["Energy|Supply"]
Expand All @@ -369,15 +372,17 @@ class Resolver:
context: Context # context is shared with all Vars created
external_data: dict[str, DataFrame]

SR = TypeVar("SR", bound="Resolver")

@classmethod
def from_data(
cls,
cls: type[SR],
data: DataFrame,
level: str,
values: Sequence[str] | None = None,
index: Sequence[str] = ("model", "scenario"),
**external_data: DataFrame,
) -> SelfResolver:
) -> SR:
context = Context(
level,
full_index=data.index.droplevel(level).unique(),
Expand All @@ -395,10 +400,10 @@ def add(self, value: str) -> Vars:
return vars

class _LocIndexer:
def __init__(self, obj: SelfResolver):
def __init__(self, obj: Resolver):
self._obj = obj

def __getitem__(self, x: Any) -> SelfResolver:
def __getitem__(self, x: Any) -> Resolver:
obj = self._obj
vars = {name: z for name, vars in obj.vars.items() if (z := vars.loc[x])}
data = obj.data.loc[x]
Expand All @@ -408,7 +413,7 @@ def __getitem__(self, x: Any) -> SelfResolver:
return obj.__class__(vars, data, context, obj.external_data)

@property
def loc(self) -> SelfResolver:
def loc(self) -> _LocIndexer:
return self._LocIndexer(self)

def __len__(self) -> int:
Expand Down Expand Up @@ -439,9 +444,9 @@ def __getitem__(self, value: str) -> Vars:

def _ipython_key_completions_(self) -> list[str]:
comps = list(self.vars)
comps.extend(self.data.pix.unique(self.context.level).difference(comps))
comps.extend(self.data.pix.unique(self.context.level).difference(comps)) # type: ignore
for n, v in self.external_data.items():
comps.extend(f"{n}|" + v.pix.unique(self.context.level))
comps.extend(f"{n}|" + v.pix.unique(self.context.level)) # type: ignore
return comps

@property
Expand All @@ -454,7 +459,7 @@ def index(self) -> MultiIndex:
)

def __repr__(self) -> str:
num_scenarios = len(self.data.pix.unique(self.context.index))
num_scenarios = len(self.data.pix.unique(self.context.index)) # type: ignore
level = self.context.level

s = (
Expand All @@ -465,7 +470,7 @@ def __repr__(self) -> str:
s += (
f"* {name} ({len(vars)}): "
+ ", ".join(
str(len(var.data.pix.unique(self.context.index)))
str(len(var.data.pix.unique(self.context.index))) # type: ignore
for var in vars.data
)
+ "\n"
Expand All @@ -476,7 +481,7 @@ def __repr__(self) -> str:
]

unused_values = (
self.data.pix.unique([*self.context.index, level])
self.data.pix.unique([*self.context.index, level]) # type: ignore
.pix.antijoin(Index(self.vars, name=level).union(existing_provenances))
.pix.project(level)
.value_counts()
Expand Down

0 comments on commit 44c211b

Please sign in to comment.