From 71bea74333e342a16937bf0a631b9d7e1057a4f6 Mon Sep 17 00:00:00 2001 From: LamAdr <102169875+LamAdr@users.noreply.github.com> Date: Thu, 23 Nov 2023 15:38:00 -0500 Subject: [PATCH] minor changes to datagridcf (#47) * minor changes to datagridcf * datagridcf size and column names testing --- marginaleffects/datagrid.py | 8 ++++---- tests/test_datagrid.py | 6 +++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/marginaleffects/datagrid.py b/marginaleffects/datagrid.py index 6fea645..9e84694 100644 --- a/marginaleffects/datagrid.py +++ b/marginaleffects/datagrid.py @@ -122,6 +122,7 @@ def datagridcf(model=None, newdata=None, **kwargs): if "rowid" not in newdata.columns: newdata = newdata.with_columns(pl.Series(range(newdata.shape[0])).alias("rowid")) + newdata = newdata.rename({"rowid" : "rowidcf"}) # Create dataframe from kwargs dfs = [pl.DataFrame({k: v}) for k, v in kwargs.items()] @@ -129,11 +130,10 @@ def datagridcf(model=None, newdata=None, **kwargs): # Perform cross join df_cross = reduce(lambda df1, df2: df1.join(df2, how='cross'), dfs) - result = newdata.join(df_cross, how = "cross") + # Drop would-be duplicates + newdata = newdata.drop(df_cross.columns) - # Create rowid and rowidcf - rowidcf = [i for i in range(newdata.shape[0]) for _ in range(result.select(pl.count()).item() // newdata.select(pl.count()).item())] - result = result.with_columns(pl.Series(rowidcf).alias("rowidcf")) + result = newdata.join(df_cross, how = "cross") result.datagrid_explicit = list(kwargs.keys()) diff --git a/tests/test_datagrid.py b/tests/test_datagrid.py index c25b31e..18c45ab 100644 --- a/tests/test_datagrid.py +++ b/tests/test_datagrid.py @@ -19,4 +19,8 @@ def test_simple_grid(): def test_cf(): assert datagrid(newdata = mtcars, mpg = 32).shape[0] == 1 assert datagridcf(newdata = mtcars, mpg = [30, 32]).shape[0] == 64 - assert datagridcf(newdata = mtcars, mpg = [30, 32]).unique("rowidcf").shape[0] == 32 \ No newline at end of file + assert datagridcf(newdata = mtcars, mpg = 32, am = 0, hp = 100).shape[0] == 32 + assert datagridcf(newdata = mtcars, am = [0, 1], hp = [100, 110, 120]).shape[0] == 192 + assert datagridcf(newdata = mtcars, mpg = [30, 32]).unique("rowidcf").shape[0] == 32 + assert set(datagridcf(newdata = mtcars, mpg = [30, 32]).columns) \ + == {'gear', 'qsec', 'mpg', 'cyl', 'am', 'wt', 'vs', 'drat', 'rowidcf', 'disp', 'rownames', 'hp', 'carb'}