Skip to content

Commit

Permalink
minor changes to datagridcf (#47)
Browse files Browse the repository at this point in the history
* minor changes to datagridcf

* datagridcf size and column names testing
  • Loading branch information
LamAdr authored Nov 23, 2023
1 parent ad449a6 commit 71bea74
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
8 changes: 4 additions & 4 deletions marginaleffects/datagrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,18 @@ def datagridcf(model=None, newdata=None, **kwargs):

if "rowid" not in newdata.columns:
newdata = newdata.with_columns(pl.Series(range(newdata.shape[0])).alias("rowid"))
newdata = newdata.rename({"rowid" : "rowidcf"})

# Create dataframe from kwargs
dfs = [pl.DataFrame({k: v}) for k, v in kwargs.items()]

# Perform cross join
df_cross = reduce(lambda df1, df2: df1.join(df2, how='cross'), dfs)

result = newdata.join(df_cross, how = "cross")
# Drop would-be duplicates
newdata = newdata.drop(df_cross.columns)

# Create rowid and rowidcf
rowidcf = [i for i in range(newdata.shape[0]) for _ in range(result.select(pl.count()).item() // newdata.select(pl.count()).item())]
result = result.with_columns(pl.Series(rowidcf).alias("rowidcf"))
result = newdata.join(df_cross, how = "cross")

result.datagrid_explicit = list(kwargs.keys())

Expand Down
6 changes: 5 additions & 1 deletion tests/test_datagrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,8 @@ def test_simple_grid():
def test_cf():
    """Check sizes, counterfactual ids and column names of the data grids.

    `datagrid` with a single explicit value is expected to yield one row.
    `datagridcf` is expected to repeat every row of `mtcars` once per
    combination of the supplied counterfactual values, keep one distinct
    `rowidcf` per original row, and expose the original columns plus
    `rowidcf` without duplicates.
    """
    assert datagrid(newdata=mtcars, mpg=32).shape[0] == 1

    # Build the two-value counterfactual grid once and reuse it for every
    # check made against it.
    cf_mpg = datagridcf(newdata=mtcars, mpg=[30, 32])

    # Two counterfactual values of mpg -> every original row appears twice.
    assert cf_mpg.shape[0] == 64
    # rowidcf identifies the original row, so it has one value per source row.
    assert cf_mpg.unique("rowidcf").shape[0] == 32

    # Scalar counterfactuals give a single combination: the row count is unchanged.
    assert datagridcf(newdata=mtcars, mpg=32, am=0, hp=100).shape[0] == 32
    # 2 values of am x 3 values of hp = 6 combinations per original row.
    assert datagridcf(newdata=mtcars, am=[0, 1], hp=[100, 110, 120]).shape[0] == 192

    # The overridden column must not be duplicated, and the row id column must
    # appear only under the name rowidcf.
    assert set(cf_mpg.columns) == {
        'gear', 'qsec', 'mpg', 'cyl', 'am', 'wt', 'vs', 'drat',
        'rowidcf', 'disp', 'rownames', 'hp', 'carb',
    }

0 comments on commit 71bea74

Please sign in to comment.