diff --git a/marginaleffects/datagrid.py b/marginaleffects/datagrid.py index 0b29759..6fea645 100644 --- a/marginaleffects/datagrid.py +++ b/marginaleffects/datagrid.py @@ -132,7 +132,8 @@ def datagridcf(model=None, newdata=None, **kwargs): result = newdata.join(df_cross, how = "cross") # Create rowid and rowidcf - result = result.with_columns(pl.Series(range(result.shape[0])).alias("rowidcf")) + rowidcf = [i for i in range(newdata.shape[0]) for _ in range(result.select(pl.count()).item() // newdata.select(pl.count()).item())] + result = result.with_columns(pl.Series(rowidcf).alias("rowidcf")) result.datagrid_explicit = list(kwargs.keys()) diff --git a/tests/test_datagrid.py b/tests/test_datagrid.py index 9559900..c25b31e 100644 --- a/tests/test_datagrid.py +++ b/tests/test_datagrid.py @@ -19,3 +19,4 @@ def test_simple_grid(): def test_cf(): assert datagrid(newdata = mtcars, mpg = 32).shape[0] == 1 assert datagridcf(newdata = mtcars, mpg = [30, 32]).shape[0] == 64 + assert datagridcf(newdata = mtcars, mpg = [30, 32]).unique("rowidcf").shape[0] == 32 \ No newline at end of file