Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise arithmetics #34

Merged
merged 3 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 52 additions & 13 deletions src/pandas_indexing/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pandas as pd
from deprecated.sphinx import deprecated
from pandas import DataFrame, Index, MultiIndex, Series
from pandas.api.extensions import no_default

from . import arithmetics
from .core import (
Expand Down Expand Up @@ -131,20 +132,17 @@ def semijoin(
level: Union[str, int, None] = None,
sort: bool = False,
axis: Axis = 0,
fill_value: Any = no_default,
) -> Union[DataFrame, Series]:
return semijoin(self._obj, other, how=how, level=level, sort=sort, axis=axis)

def multiply(self, other, **align_kwds):
    """Forward to ``arithmetics.multiply`` with the wrapped object as left operand.

    Parameters
    ----------
    other
        Right operand
    **align_kwds
        Passed through unchanged to ``arithmetics.multiply``
    """
    impl = arithmetics.multiply
    return impl(self._obj, other, **align_kwds)

def divide(self, other, **align_kwds):
    """Forward to ``arithmetics.divide`` with the wrapped object as left operand.

    Parameters
    ----------
    other
        Right operand
    **align_kwds
        Passed through unchanged to ``arithmetics.divide``
    """
    impl = arithmetics.divide
    return impl(self._obj, other, **align_kwds)

def add(self, other, **align_kwds):
    """Forward to ``arithmetics.add`` with the wrapped object as left operand.

    Parameters
    ----------
    other
        Right operand
    **align_kwds
        Passed through unchanged to ``arithmetics.add``
    """
    impl = arithmetics.add
    return impl(self._obj, other, **align_kwds)

def subtract(self, other, **align_kwds):
    """Forward to ``arithmetics.subtract`` with the wrapped object as left operand.

    Parameters
    ----------
    other
        Right operand
    **align_kwds
        Passed through unchanged to ``arithmetics.subtract``
    """
    impl = arithmetics.subtract
    return impl(self._obj, other, **align_kwds)
return semijoin(
self._obj,
other,
how=how,
level=level,
sort=sort,
axis=axis,
fill_value=fill_value,
)

@doc(quantify, data="", example_call="s.pix.quantify()")
def quantify(
Expand Down Expand Up @@ -195,6 +193,47 @@ def aggregate(
)


def _create_forward_binop(op):
    """Build an accessor method that forwards to ``arithmetics.<op>``.

    Parameters
    ----------
    op
        Attribute name looked up on the ``arithmetics`` module each time the
        returned method is called

    Returns
    -------
    function
        Method taking ``(self, other, assign=None, axis=None, **align_kwargs)``
        which passes ``self._obj`` as the left operand
    """

    def forward_binop(
        self,
        other: Data,
        assign: Optional[Dict[str, Any]] = None,
        axis: Optional[Axis] = None,
        **align_kwargs: Any,
    ):
        # Resolve the target lazily, at call time rather than at creation time
        impl = getattr(arithmetics, op)
        forwarded = dict(assign=assign, axis=axis, **align_kwargs)
        return impl(self._obj, other, **forwarded)

    return forward_binop


def _create_forward_unitbinop(op):
    """Build an accessor method that forwards to ``arithmetics.unit<op>``.

    Parameters
    ----------
    op
        Operation name; the forwarded function is the attribute ``unit<op>``
        of the ``arithmetics`` module, looked up on every call

    Returns
    -------
    function
        Method taking ``(self, other, level="unit", assign=None, axis=None,
        **align_kwargs)`` which passes ``self._obj`` as the left operand
    """
    target_name = f"unit{op}"

    def forward_unitbinop(
        self,
        other: Data,
        level: str = "unit",
        assign: Optional[Dict[str, Any]] = None,
        axis: Optional[Axis] = None,
        **align_kwargs: Any,
    ):
        # Resolve the target lazily, at call time rather than at creation time
        impl = getattr(arithmetics, target_name)
        forwarded = dict(level=level, assign=assign, axis=axis, **align_kwargs)
        return impl(self._obj, other, **forwarded)

    return forward_unitbinop


# Attach one pair of forwarding methods per arithmetic operation to
# ``_DataPixAccessor``: ``<op>`` forwards to ``arithmetics.<op>`` and
# ``unit<op>`` forwards to ``arithmetics.unit<op>``.
for op in arithmetics.ARITHMETIC_BINOPS:
    forward_binop = _create_forward_binop(op)
    forward_unitbinop = _create_forward_unitbinop(op)
    setattr(_DataPixAccessor, op, forward_binop)
    setattr(_DataPixAccessor, f"unit{op}", forward_unitbinop)
    # Alternative spellings from ``arithmetics.ALTERNATIVE_NAMES`` are bound to
    # the very same function objects as the canonical name.
    for alt in arithmetics.ALTERNATIVE_NAMES.get(op, []):
        setattr(_DataPixAccessor, alt, forward_binop)
        setattr(_DataPixAccessor, f"unit{alt}", forward_unitbinop)


@pd.api.extensions.register_dataframe_accessor("pix")
class DataFramePixAccessor(_DataPixAccessor):
    """Accessor registered on DataFrames under the name ``pix``.

    Adds nothing of its own here; all behaviour is inherited from
    ``_DataPixAccessor``.
    """

    pass
Expand Down
103 changes: 73 additions & 30 deletions src/pandas_indexing/arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,21 @@
pandas.DataFrame.align
"""

from typing import Any, Mapping, Tuple
import operator
from typing import Any, Dict, Optional

from pandas import DataFrame, Series
from pandas.core.ops import ARITHMETIC_BINOPS

from .types import Data
from .core import assignlevel, uniquelevel
from .types import Axis, Data


# Additional names under which an operation is made available, keyed by the
# operation name as it appears in ARITHMETIC_BINOPS. Operations without an
# entry have no alternative names (consumers use ``.get(op, [])``).
ALTERNATIVE_NAMES = {
    "truediv": ["div", "divide"],
    "mul": ["multiply"],
    "sub": ["subtract"],
}


def _needs_axis(df: Data, other: Data) -> bool:
Expand All @@ -27,31 +37,64 @@ def _needs_axis(df: Data, other: Data) -> bool:
)


def _prepare_op(
    df: Data, other: Data, kwargs: Mapping[str, Any]
) -> Tuple[Data, Data, Mapping[str, Any]]:
    """Align ``df`` and ``other`` ahead of an arithmetic operation.

    Parameters
    ----------
    df, other
        Operands; ``df.align(other, ...)`` is used to align them
    kwargs
        Keyword arguments for ``align``. The mapping itself is left untouched.

    Returns
    -------
    (df, other, kwargs)
        The aligned operands and the keyword arguments that were actually
        passed to ``align``, i.e. including the defaults filled in here
    """
    # Fill in defaults on a private copy: ``Mapping`` does not promise
    # ``setdefault`` and the caller's mapping must not change as a side effect.
    kwargs = dict(kwargs)
    kwargs.setdefault("copy", True)
    if _needs_axis(df, other):
        kwargs.setdefault("axis", 0)
    df, other = df.align(other, **kwargs)
    return df, other, kwargs


def add(df: Data, other: Data, **align_kwargs: Any) -> Data:
    """Align ``df`` and ``other`` via ``_prepare_op`` and add them.

    The operation uses the ``axis`` that was used for aligning, or 0 if none
    was set.
    """
    left, right, used_kwargs = _prepare_op(df, other, align_kwargs)
    op_axis = used_kwargs.get("axis", 0)
    return left.add(right, axis=op_axis)


def divide(df: Data, other: Data, **align_kwargs: Any) -> Data:
    """Align ``df`` and ``other`` via ``_prepare_op`` and divide ``df`` by ``other``.

    The operation uses the ``axis`` that was used for aligning, or 0 if none
    was set.
    """
    left, right, used_kwargs = _prepare_op(df, other, align_kwargs)
    op_axis = used_kwargs.get("axis", 0)
    return left.div(right, axis=op_axis)


def multiply(df: Data, other: Data, **align_kwargs: Any) -> Data:
    """Align ``df`` and ``other`` via ``_prepare_op`` and multiply them.

    The operation uses the ``axis`` that was used for aligning, or 0 if none
    was set.
    """
    left, right, used_kwargs = _prepare_op(df, other, align_kwargs)
    op_axis = used_kwargs.get("axis", 0)
    return left.mul(right, axis=op_axis)


def subtract(df: Data, other: Data, **align_kwargs: Any) -> Data:
    """Align ``df`` and ``other`` via ``_prepare_op`` and subtract ``other`` from ``df``.

    The operation uses the ``axis`` that was used for aligning, or 0 if none
    was set.
    """
    left, right, used_kwargs = _prepare_op(df, other, align_kwargs)
    op_axis = used_kwargs.get("axis", 0)
    return left.sub(right, axis=op_axis)
def _create_binop(op: str):
    """Build an aligning wrapper around the operand method named ``op``.

    Parameters
    ----------
    op
        Name of the method (e.g. ``"add"``) that is looked up on the aligned
        left operand with ``getattr`` and called with the right operand

    Returns
    -------
    function
        ``binop(df, other, assign=None, axis=None, **align_kwargs)``
    """

    def binop(
        df: Data,
        other: Data,
        assign: Optional[Dict[str, Any]] = None,
        axis: Optional[Axis] = None,
        **align_kwargs: Any,
    ):
        """Align ``df`` and ``other`` and apply the operation.

        Parameters
        ----------
        df, other
            Operands
        assign
            If given, passed as keyword arguments to ``assignlevel`` for both
            operands before aligning
        axis
            Axis used for aligning and for the operation. If omitted, 0 is
            used when ``_needs_axis(df, other)`` is true.
        **align_kwargs
            Further arguments for ``align``; ``copy`` defaults to False
        """
        if assign is not None:
            df = assignlevel(df, **assign)
            other = assignlevel(other, **assign)

        align_kwargs.setdefault("copy", False)
        # Only fall back to axis 0 if the caller did not choose an axis;
        # an explicit ``axis`` must not be overridden silently.
        if axis is None and _needs_axis(df, other):
            axis = 0
        if isinstance(df, Series) and isinstance(other, DataFrame):
            # ``align`` is called on the DataFrame here, i.e. with the operands
            # trading places, so a one-sided join has to be mirrored.
            join = align_kwargs.get("join")
            if join in ("left", "right"):
                align_kwargs["join"] = "right" if join == "left" else "left"
            other, df = other.align(df, axis=axis, **align_kwargs)
        else:
            df, other = df.align(other, axis=axis, **align_kwargs)

        return getattr(df, op)(other, axis=axis)

    return binop


def _create_unitbinop(op, binop):
    """Build a unit-aware variant of ``binop``.

    Parameters
    ----------
    op
        Operation name. It is resolved on the :mod:`operator` module to combine
        the units; reflected names (``radd``, ``rsub``, ...), which
        :mod:`operator` does not define, use the forward function with the
        operands swapped.
    binop
        Function applying the operation to the data; called as
        ``binop(df, other, assign=..., axis=..., **align_kwargs)``

    Returns
    -------
    function
        ``unitbinop(df, other, level="unit", assign=None, axis=None,
        **align_kwargs)``
    """

    def unitbinop(
        df: Data,
        other: Data,
        level: str = "unit",
        assign: Optional[Dict[str, Any]] = None,
        axis: Optional[Axis] = None,
        **align_kwargs: Any,
    ):
        """Apply the operation to the data and to the units of both operands.

        Parameters
        ----------
        df, other
            Operands; each one's unit is read from ``level`` through
            ``uniquelevel(...).item()``
        level
            Name of the level holding the unit
        assign
            Additional level assignments handed to ``binop``; an entry for
            ``level`` takes precedence over the computed unit
        axis
            Passed to ``uniquelevel`` and ``binop``
        **align_kwargs
            Passed to ``binop``

        Raises
        ------
        AttributeError
            If ``op`` has neither a forward nor a reflected counterpart in
            :mod:`operator` (e.g. ``divmod``)
        """
        df_unit = uniquelevel(df, level, axis=axis).item()
        other_unit = uniquelevel(other, level, axis=axis).item()

        import pint

        ur = pint.get_application_registry()

        # ``operator`` has no reflected functions, so ``getattr(operator, "radd")``
        # would fail; use the forward function on swapped operands instead.
        reflected = (
            not hasattr(operator, op)
            and op.startswith("r")
            and hasattr(operator, op[1:])
        )
        unit_op = getattr(operator, op[1:] if reflected else op)
        if reflected:
            quantity = unit_op(ur(other_unit), ur(df_unit))
        else:
            quantity = unit_op(ur(df_unit), ur(other_unit))
        quantity = quantity.to_reduced_units()

        # Computed unit first so that an explicit ``assign`` entry wins.
        # Unpacking instead of ``|`` keeps this working before Python 3.9.
        assign = {level: f"{quantity.units:~}", **(assign or {})}

        # The magnitude left over from reducing the units scales the data
        return binop(df, other, assign=assign, axis=axis, **align_kwargs) * quantity.m

    return unitbinop


# Publish the generated functions as attributes of this module: for every name
# in ARITHMETIC_BINOPS, ``<op>`` and ``unit<op>`` become module globals.
for op in ARITHMETIC_BINOPS:
    binop = _create_binop(op)
    unitbinop = _create_unitbinop(op, binop)
    globals().update({op: binop, f"unit{op}": unitbinop})
    # Alternative names are bound to the same function objects as the
    # canonical name.
    for alt in ALTERNATIVE_NAMES.get(op, []):
        globals().update({alt: binop, f"unit{alt}": unitbinop})
10 changes: 8 additions & 2 deletions src/pandas_indexing/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pandas as pd
from deprecated import deprecated
from pandas import DataFrame, Index, MultiIndex, Series
from pandas.api.extensions import no_default
from pandas.core.indexes.frozen import FrozenList

from .types import Axis, Data, S, T
Expand Down Expand Up @@ -461,6 +462,7 @@ def semijoin(
level: Union[str, int, None] = None,
sort: bool = False,
axis: Axis = 0,
fill_value: Any = no_default,
) -> S:
"""Semijoin ``data`` by index ``other``.

Expand All @@ -477,6 +479,8 @@ def semijoin(
Whether to sort the index
axis : {{0, 1, "index", "columns"}}
Axis on which to join
fill_value
Value for filling gaps introduced by right or outer joins

Returns
-------
Expand Down Expand Up @@ -516,11 +520,13 @@ def semijoin(
data = frame_or_series.iloc[:, left_idx]
index = data.columns
if any_missing:
data = data.where(pd.Series(left_idx != -1, index), axis=axis)
data = data.where(
pd.Series(left_idx != -1, index), other=fill_value, axis=axis
)
elif isinstance(frame_or_series, Series):
data = frame_or_series.iloc[left_idx]
if any_missing:
data = data.where(left_idx != -1)
data = data.where(left_idx != -1, other=fill_value)
else:
raise TypeError(
f"frame_or_series must derive from DataFrame or Series, but is {type(frame_or_series)}"
Expand Down
15 changes: 3 additions & 12 deletions src/pandas_indexing/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
--------
pint.set_application_registry
"""

from typing import Callable, Mapping, Optional, Union

from pandas import DataFrame, Series
Expand All @@ -73,12 +72,6 @@
except ImportError:
has_pint = False

try:
import openscm_units

has_openscm_units = True
except ImportError:
has_openscm_units = False

from .core import assignlevel, uniquelevel
from .types import Axis, Data
Expand Down Expand Up @@ -306,12 +299,12 @@ def _convert_unit(df, old_unit=None):
_openscm_registry = None


def get_openscm_registry(add_co2e: bool = True) -> "openscm_units.ScmUnitRegistry":
def get_openscm_registry(add_co2e: bool = True):
global _openscm_registry
if _openscm_registry is not None:
return _openscm_registry

assert has_openscm_units, INSTALL_PACKAGE_WARNING.format(package="openscm-units")
import openscm_units

if add_co2e:
_openscm_registry = openscm_units.ScmUnitRegistry()
Expand All @@ -324,9 +317,7 @@ def get_openscm_registry(add_co2e: bool = True) -> "openscm_units.ScmUnitRegistr
return _openscm_registry


def set_openscm_registry_as_default(
add_co2e: bool = True,
) -> "openscm_units.ScmUnitRegistry":
def set_openscm_registry_as_default(add_co2e: bool = True):
unit_registry = get_openscm_registry(add_co2e=add_co2e)

assert has_pint, INSTALL_PACKAGE_WARNING.format(package="pint")
Expand Down
6 changes: 5 additions & 1 deletion tests/test_units.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from importlib.util import find_spec

import pytest
from pandas import DataFrame, Series
from pandas.testing import assert_frame_equal, assert_series_equal

from pandas_indexing import assignlevel, convert_unit, set_openscm_registry_as_default
from pandas_indexing.units import has_openscm_units, has_pint, has_pint_pandas, is_unit
from pandas_indexing.units import has_pint, has_pint_pandas, is_unit


has_openscm_units = bool(find_spec("openscm_units"))

needs_pint = pytest.mark.skipif(not has_pint, reason="Needs pint package")
needs_openscm_units = pytest.mark.skipif(
Expand Down