Skip to content

Commit

Permalink
lint and reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
azmyrajab committed Mar 15, 2024
1 parent 81de0a0 commit 06eda5e
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 53 deletions.
51 changes: 3 additions & 48 deletions polars_ols/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from functools import reduce
from typing import Sequence

import polars as pl

from polars_ols.least_squares import pl_least_squares
from polars_ols.utils import build_expressions_from_patsy_formula
from polars_ols.least_squares import pl_least_squares, pl_least_squares_from_formula

__all__ = [
"pl_least_squares",
Expand All @@ -12,48 +9,6 @@
]


def _build_expressions_from_patsy_formula(
    formula: str, include_dependent_variable: bool = False
) -> "tuple[Sequence[pl.Expr], bool]":
    """Translate a patsy formula string into polars column expressions.

    Args:
        formula: A patsy formula, e.g. ``"y ~ x1 + x2:x3 - 1"``.
        include_dependent_variable: When True the formula must have exactly one
            left-hand-side term, and its expression is returned first. When False
            the formula must have no left-hand side.

    Returns:
        A ``(expressions, add_intercept)`` pair. ``expressions`` holds one
        ``pl.col`` per single-factor term and, for interaction terms, the product
        of the factor columns aliased as the factor codes joined by ``":"``.
        ``add_intercept`` is True when the parsed right-hand side contains an
        intercept term.

    Raises:
        NotImplementedError: If patsy is not installed, or if any factor uses
            patsy's ``C(...)`` categorical syntax.
        AssertionError: If the number of left-hand-side terms does not match
            ``include_dependent_variable``.
    """
    try:
        import patsy as pa
    except ImportError as e:
        raise NotImplementedError(
            "'patsy' needs to be installed in your python environment in order to use "
            "formula api"
        ) from e
    desc = pa.ModelDesc.from_formula(formula)

    # Raise explicitly instead of using ``assert`` so the validation still runs under
    # ``python -O``; AssertionError is kept so callers see the same exception type.
    if include_dependent_variable:
        if len(desc.lhs_termlist) != 1:
            raise AssertionError("must provide exactly one LHS variable")
        terms = desc.lhs_termlist + desc.rhs_termlist
    else:
        if len(desc.lhs_termlist) != 0:
            raise AssertionError("can not provide LHS variables in this context")
        terms = desc.rhs_termlist

    # patsy models the intercept as a term with no factors. Inspecting the parsed
    # terms (rather than searching the raw string for "-1") also handles spellings
    # such as "y ~ x - 1", "y ~ x + 0" and "y ~ 0 + x".
    add_intercept: bool = any(len(term.factors) == 0 for term in desc.rhs_termlist)

    expressions = []
    for term in terms:
        codes = [f.code for f in term.factors]
        if any("C(" in code for code in codes):
            raise NotImplementedError(
                "building patsy categories into polars expressions is not supported"
            )
        if len(codes) == 1:
            expressions.append(pl.col(codes[0]))
        elif len(codes) >= 2:
            # Interaction term: multiply the factor columns together.
            expr = reduce((lambda x, y: x * pl.col(y)), codes, pl.lit(1))
            expressions.append(expr.alias(":".join(codes)))
        # Zero-factor (intercept) terms produce no expression; they are reported
        # through ``add_intercept`` instead.
    return expressions, add_intercept


def pl_least_squares_from_formula(formula: str, **kwargs) -> pl.Expr:
    """Build a least-squares expression from a patsy formula string.

    The formula is parsed with the dependent variable included. The first parsed
    expression is passed to ``pl_least_squares`` as its first positional argument
    and the remaining expressions as the following positional arguments. Any extra
    keyword arguments are forwarded unchanged.
    """
    parsed = _build_expressions_from_patsy_formula(formula, include_dependent_variable=True)
    expressions, add_intercept = parsed
    dependent = expressions[0]
    regressors = expressions[1:]
    return pl_least_squares(dependent, *regressors, add_intercept=add_intercept, **kwargs)


@pl.api.register_expr_namespace("least_squares")
class LeastSquares:
def __init__(self, expr: pl.Expr):
Expand All @@ -72,7 +27,7 @@ def ridge(self, *features: pl.Expr, alpha: float) -> pl.Expr:
return self.least_squares(*features, ridge_alpha=alpha)

def from_formula(self, formula: str, **kwargs) -> pl.Expr:
features, add_intercept = _build_expressions_from_patsy_formula(
features, add_intercept = build_expressions_from_patsy_formula(
formula, include_dependent_variable=False
)
return self.least_squares(*features, add_intercept=add_intercept, **kwargs)
11 changes: 9 additions & 2 deletions polars_ols/least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from polars.type_aliases import IntoExpr
from polars.utils.udfs import _get_shared_lib_location

from polars_ols.utils import parse_into_expr
from polars_ols.utils import parse_into_expr, build_expressions_from_patsy_formula

lib = _get_shared_lib_location(__file__)

__all__ = ["pl_least_squares"]
__all__ = ["pl_least_squares", "pl_least_squares_from_formula"]


def pl_least_squares(
Expand Down Expand Up @@ -41,3 +41,10 @@ def pl_least_squares(
)
/ sqrt_w
) # undo the sqrt(w) scaling implicit in predictions (:= scaled_features @ coef)


def pl_least_squares_from_formula(formula: str, **kwargs) -> pl.Expr:
    """Build a least-squares expression from a patsy formula string.

    The formula is parsed with the dependent variable included. The first parsed
    expression is passed to ``pl_least_squares`` as its first positional argument
    and the remaining expressions as the following positional arguments. Any extra
    keyword arguments are forwarded unchanged.
    """
    parsed = build_expressions_from_patsy_formula(formula, include_dependent_variable=True)
    expressions, add_intercept = parsed
    dependent = expressions[0]
    regressors = expressions[1:]
    return pl_least_squares(dependent, *regressors, add_intercept=add_intercept, **kwargs)
38 changes: 37 additions & 1 deletion polars_ols/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from functools import reduce
from typing import TYPE_CHECKING, Sequence

import polars as pl

Expand Down Expand Up @@ -46,3 +47,38 @@ def parse_into_expr(
expr = pl.lit(expr, dtype=dtype)

return expr


def build_expressions_from_patsy_formula(
    formula: str, include_dependent_variable: bool = False
) -> "tuple[Sequence[pl.Expr], bool]":
    """Translate a patsy formula string into polars column expressions.

    Args:
        formula: A patsy formula, e.g. ``"y ~ x1 + x2:x3 - 1"``.
        include_dependent_variable: When True the formula must have exactly one
            left-hand-side term, and its expression is returned first. When False
            the formula must have no left-hand side.

    Returns:
        A ``(expressions, add_intercept)`` pair. ``expressions`` holds one
        ``pl.col`` per single-factor term and, for interaction terms, the product
        of the factor columns aliased as the factor codes joined by ``":"``.
        ``add_intercept`` is True when the parsed right-hand side contains an
        intercept term.

    Raises:
        NotImplementedError: If patsy is not installed, or if any factor uses
            patsy's ``C(...)`` categorical syntax.
        AssertionError: If the number of left-hand-side terms does not match
            ``include_dependent_variable``.
    """
    try:
        import patsy as pa
    except ImportError as e:
        raise NotImplementedError(
            "'patsy' needs to be installed in your python environment in order to use "
            "formula api"
        ) from e
    desc = pa.ModelDesc.from_formula(formula)

    # Raise explicitly instead of using ``assert`` so the validation still runs under
    # ``python -O``; AssertionError is kept so callers see the same exception type.
    if include_dependent_variable:
        if len(desc.lhs_termlist) != 1:
            raise AssertionError("must provide exactly one LHS variable")
        terms = desc.lhs_termlist + desc.rhs_termlist
    else:
        if len(desc.lhs_termlist) != 0:
            raise AssertionError("can not provide LHS variables in this context")
        terms = desc.rhs_termlist

    # patsy models the intercept as a term with no factors. Inspecting the parsed
    # terms (rather than searching the raw string for "-1") also handles spellings
    # such as "y ~ x - 1", "y ~ x + 0" and "y ~ 0 + x".
    add_intercept: bool = any(len(term.factors) == 0 for term in desc.rhs_termlist)

    expressions = []
    for term in terms:
        codes = [f.code for f in term.factors]
        if any("C(" in code for code in codes):
            raise NotImplementedError(
                "building patsy categories into polars expressions is not supported"
            )
        if len(codes) == 1:
            expressions.append(pl.col(codes[0]))
        elif len(codes) >= 2:
            # Interaction term: multiply the factor columns together.
            expr = reduce((lambda x, y: x * pl.col(y)), codes, pl.lit(1))
            expressions.append(expr.alias(":".join(codes)))
        # Zero-factor (intercept) terms produce no expression; they are reported
        # through ``add_intercept`` instead.
    return expressions, add_intercept
3 changes: 1 addition & 2 deletions tests/test_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import polars as pl
import numpy as np

from polars_ols import pl_least_squares_from_formula
from polars_ols.least_squares import pl_least_squares
from polars_ols import pl_least_squares_from_formula, pl_least_squares
import statsmodels.formula.api as smf


Expand Down

0 comments on commit 06eda5e

Please sign in to comment.