Skip to content

Commit

Permalink
isort
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Dec 22, 2023
1 parent 47eb2c0 commit 3e343f7
Show file tree
Hide file tree
Showing 36 changed files with 110 additions and 77 deletions.
6 changes: 3 additions & 3 deletions marginaleffects/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .comparisons import avg_comparisons, comparisons
from .datagrid import datagrid, datagridcf
from .hypotheses import hypotheses
from .plot_comparisons import plot_comparisons
from .plot_predictions import plot_predictions
from .plot_slopes import plot_slopes
from .predictions import avg_predictions, predictions
from .slopes import avg_slopes, slopes
from .plot_predictions import plot_predictions
from .plot_comparisons import plot_comparisons
from .plot_slopes import plot_slopes
1 change: 1 addition & 0 deletions marginaleffects/classes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import polars as pl


class MarginaleffectsDataFrame(pl.DataFrame):
def __init__(self, data=None, by=None, conf_level=0.95, newdata=None):
if isinstance(data, pl.DataFrame):
Expand Down
9 changes: 5 additions & 4 deletions marginaleffects/comparisons.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import re
from functools import reduce

import numpy as np
import patsy
import polars as pl
import re

from .classes import MarginaleffectsDataFrame
from .equivalence import get_equivalence
from .estimands import estimands
from .getters import get_coef, get_modeldata, get_predict
from .hypothesis import get_hypothesis
from .sanity import sanitize_newdata, sanitize_variables, sanitize_vcov, sanitize_by, sanitize_hypothesis_null
from .sanity import (sanitize_by, sanitize_hypothesis_null, sanitize_newdata,
sanitize_variables, sanitize_vcov)
from .transform import get_transform
from .uncertainty import get_jacobian, get_se, get_z_p_ci
from .utils import get_pad, sort_columns, upcast
from .getters import get_modeldata, get_predict, get_coef
from .classes import MarginaleffectsDataFrame


def comparisons(
Expand Down
3 changes: 2 additions & 1 deletion marginaleffects/datagrid.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from functools import reduce
from .getters import get_modeldata

import polars as pl

from .getters import get_modeldata


def datagrid(
model=None,
Expand Down
1 change: 1 addition & 0 deletions marginaleffects/getters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re

import numpy as np
import polars as pl

Expand Down
8 changes: 4 additions & 4 deletions marginaleffects/hypotheses.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import numpy as np
import polars as pl

from .classes import MarginaleffectsDataFrame
from .equivalence import get_equivalence
from .getters import get_coef
from .hypothesis import get_hypothesis
from .sanity import sanitize_vcov, sanitize_hypothesis_null
from .sanity import sanitize_hypothesis_null, sanitize_vcov
from .uncertainty import get_jacobian, get_se, get_z_p_ci
from .equivalence import get_equivalence
from .utils import sort_columns
from .getters import get_coef
from .classes import MarginaleffectsDataFrame


def hypotheses(model, hypothesis=None, conf_level=0.95, vcov=True, equivalence=None):
Expand Down
8 changes: 4 additions & 4 deletions marginaleffects/plot_common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import polars as pl
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from matplotlib.lines import Line2D

from .utils import get_variable_type
from .getters import get_modeldata
from .datagrid import datagrid
from .getters import get_modeldata
from .utils import get_variable_type


def dt_on_condition(model, condition):
Expand Down
6 changes: 3 additions & 3 deletions marginaleffects/plot_comparisons.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import polars as pl
import numpy as np
import polars as pl

from .utils import get_variable_type
from .comparisons import comparisons
from .getters import get_modeldata
from .plot_common import dt_on_condition, plot_common
from .comparisons import comparisons
from .utils import get_variable_type


def plot_comparisons(
Expand Down
2 changes: 1 addition & 1 deletion marginaleffects/plot_predictions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import polars as pl
import numpy as np
import polars as pl

from .getters import find_response
from .plot_common import dt_on_condition, plot_common
Expand Down
4 changes: 2 additions & 2 deletions marginaleffects/plot_slopes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import polars as pl
import numpy as np
import polars as pl

from .utils import get_variable_type
from .getters import get_modeldata
from .plot_common import dt_on_condition, plot_common
from .slopes import slopes
from .utils import get_variable_type


def plot_slopes(
Expand Down
11 changes: 5 additions & 6 deletions marginaleffects/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
import polars as pl

from .by import get_by
from .classes import MarginaleffectsDataFrame
from .equivalence import get_equivalence
from .getters import get_coef, get_modeldata, get_predict, get_variables_names
from .hypothesis import get_hypothesis
from .sanity import sanitize_newdata, sanitize_vcov, sanitize_by, sanitize_hypothesis_null
from .sanity import (sanitize_by, sanitize_hypothesis_null, sanitize_newdata,
sanitize_vcov)
from .transform import get_transform
from .uncertainty import get_jacobian, get_se, get_z_p_ci
from .utils import sort_columns, get_pad, upcast
from .getters import get_modeldata, get_variables_names, get_predict, get_coef
from .classes import MarginaleffectsDataFrame


from .utils import get_pad, sort_columns, upcast


def predictions(
Expand Down
4 changes: 2 additions & 2 deletions marginaleffects/sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from warnings import warn

import numpy as np
import polars as pl
import pandas as pd
import polars as pl

from .datagrid import datagrid
from .estimands import estimands
from .utils import get_variable_type
from .getters import get_modeldata, get_variables_names, get_vcov
from .utils import get_variable_type


def sanitize_vcov(vcov, model):
Expand Down
1 change: 1 addition & 0 deletions marginaleffects/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools

import numpy as np
import polars as pl

Expand Down
5 changes: 3 additions & 2 deletions tests/test_bugfix.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
import polars as pl
import pandas as pd
from marginaleffects import predictions
import polars as pl
import statsmodels.formula.api as smf

from marginaleffects import predictions


def test_issue_25():
d = pd.DataFrame(np.random.randint(0,100,size=(100, 4)), columns=list('ABCD'))
Expand Down
7 changes: 4 additions & 3 deletions tests/test_by.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
import polars as pl
import pytest
import statsmodels.formula.api as smf
from pytest import approx
import polars as pl

from marginaleffects import *

from .utilities import *
import statsmodels.formula.api as smf

Guerry = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/HistData/Guerry.csv", null_values = "NA").drop_nulls()
mod_py = smf.ols("Literacy ~ Pop1831 * Desertion", Guerry).fit()
Expand Down
21 changes: 12 additions & 9 deletions tests/test_comparisons.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
import re

import numpy as np
import polars as pl
import statsmodels.api as sm
import statsmodels.formula.api as smf
import polars as pl
import numpy as np
from polars.testing import assert_series_equal
from pytest import approx
from marginaleffects import *

import marginaleffects
from marginaleffects import *
from marginaleffects.comparisons import estimands
from polars.testing import assert_series_equal


dat = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/HistData/Guerry.csv", null_values = "NA") \
.drop_nulls() \
.with_columns(
(pl.col("Area") > pl.col("Area").median()).alias("Bool"),
(pl.col("Area") > pl.col("Area").median()).alias("Boolea"),
(pl.col("Distance") > pl.col("Distance").median()).alias("Bin"))
dat = dat \
.with_columns(
pl.col("Bin").apply(lambda x: int(x), return_dtype=pl.Int32).alias('Bin'),
pl.Series(np.random.choice(["a", "b", "c"], dat.shape[0])).alias("Char"))
pl.col("Bin").cast(pl.Int32),
pl.Series(np.random.choice(["a", "b", "c"], dat.shape[0])).alias("Char")) \
.to_pandas()

mod = smf.ols("Literacy ~ Pop1831 * Desertion", dat).fit()


Expand Down Expand Up @@ -70,7 +73,7 @@ def test_difference_wts():


def test_bare_minimum():
fit = smf.ols("Literacy ~ Pop1831 * Desertion + Bool + Bin + Char", data = dat.to_pandas).fit()
fit = smf.ols("Literacy ~ Pop1831 * Desertion + Boolea + Bin + Char", data = dat).fit()
assert type(comparisons(fit)) == marginaleffects.classes.MarginaleffectsDataFrame
assert type(comparisons(fit, variables = "Pop1831", comparison = "differenceavg")) == marginaleffects.classes.MarginaleffectsDataFrame
assert type(comparisons(fit, variables = "Pop1831", comparison = "difference").head()) == marginaleffects.classes.MarginaleffectsDataFrame
Expand Down
6 changes: 4 additions & 2 deletions tests/test_datagrid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from marginaleffects import *
import statsmodels.formula.api as smf
import polars as pl
import statsmodels.formula.api as smf

from marginaleffects import *

mtcars = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/datasets/mtcars.csv")

def test_FUN_numeric():
Expand Down
3 changes: 2 additions & 1 deletion tests/test_dt_on_condition.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import polars as pl
import statsmodels.formula.api as smf

from marginaleffects import *
from marginaleffects.plot_common import dt_on_condition
from .utilities import *

from .utilities import *

df = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/HistData/Guerry.csv", null_values = "NA") \
.drop_nulls()
Expand Down
3 changes: 2 additions & 1 deletion tests/test_equivalence.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import polars as pl
import statsmodels.formula.api as smf
from polars.testing import assert_series_equal

from marginaleffects import *
import statsmodels.formula.api as smf

Guerry = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/HistData/Guerry.csv", null_values = "NA").drop_nulls()
mod_py = smf.ols("Literacy ~ Pop1831 * Desertion", Guerry.to_pandas()).fit()
Expand Down
5 changes: 3 additions & 2 deletions tests/test_hypotheses.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import statsmodels.formula.api as smf
import numpy as np
from marginaleffects import *
import polars as pl
import statsmodels.formula.api as smf
from polars.testing import assert_series_equal

from marginaleffects import *

dat = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/HistData/Guerry.csv")

mod = smf.ols("Literacy ~ Pop1831 * Desertion", dat).fit()
Expand Down
5 changes: 3 additions & 2 deletions tests/test_newdata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from marginaleffects import *
import statsmodels.formula.api as smf
import polars as pl
import statsmodels.formula.api as smf

from marginaleffects import *

mtcars = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/datasets/mtcars.csv")
mod = smf.probit("am ~ hp + wt", data = mtcars).fit()
Expand Down
6 changes: 4 additions & 2 deletions tests/test_plot_comparisons.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os

import polars as pl
import statsmodels.formula.api as smf
import pytest
import statsmodels.formula.api as smf
from matplotlib.testing.compare import compare_images

from marginaleffects import *
from marginaleffects.plot_comparisons import *
from .utilities import *

from .utilities import *

df = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv", null_values = "NA") \
.drop_nulls()
Expand Down
6 changes: 4 additions & 2 deletions tests/test_plot_predictions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import pytest

import polars as pl
import pytest
import statsmodels.formula.api as smf
from matplotlib.testing.compare import compare_images

from marginaleffects import *
from marginaleffects.plot_predictions import *
from .utilities import *

from .utilities import *

df = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv", null_values = "NA") \
.drop_nulls()
Expand Down
6 changes: 4 additions & 2 deletions tests/test_plot_slopes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import pytest
import os

import polars as pl
import pytest
import statsmodels.formula.api as smf
from matplotlib.testing.compare import compare_images

from marginaleffects import *
from marginaleffects.plot_slopes import *
from .utilities import *

from .utilities import *

df = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv", null_values = "NA") \
.drop_nulls()
Expand Down
5 changes: 3 additions & 2 deletions tests/test_predictions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import polars as pl
import marginaleffects
import statsmodels.formula.api as smf

import marginaleffects
from marginaleffects import *
from .utilities import *

from .utilities import *

df = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/HistData/Guerry.csv", null_values = "NA").drop_nulls()
df = df \
Expand Down
8 changes: 5 additions & 3 deletions tests/test_slopes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import polars as pl
import statsmodels.formula.api as smf
from polars.testing import assert_series_equal

from marginaleffects import *
from .utilities import *
from marginaleffects.comparisons import estimands
from polars.testing import assert_series_equal
import polars as pl

from .utilities import *

mtcars = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/datasets/mtcars.csv")
mod_py = smf.ols("mpg ~ wt * hp", mtcars).fit()
Expand Down
3 changes: 2 additions & 1 deletion tests/test_statsmodels_logit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import polars as pl
import statsmodels.formula.api as smf
from marginaleffects import *
from polars.testing import assert_series_equal

from marginaleffects import *

dat = pl.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/datasets/iris.csv")
dat = dat.rename({"Sepal.Length": "Sepal_Length", "Sepal.Width": "Sepal_Width", "Petal.Length": "Petal_Length", "Petal.Width": "Petal_Width"})
dat = dat.with_columns((pl.col("Sepal_Width") < pl.col("Sepal_Width").median()).cast(pl.Int16).alias("bin"))
Expand Down
Loading

0 comments on commit 3e343f7

Please sign in to comment.