From 3e33a05fb0e2ed6ee806299b79cf7c807bffd059 Mon Sep 17 00:00:00 2001 From: Adrien Lamarche Date: Mon, 20 Nov 2023 16:13:23 -0500 Subject: [PATCH] adds grid_type argument to datagrid --- marginaleffects/datagrid.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/marginaleffects/datagrid.py b/marginaleffects/datagrid.py index 0b29759..7333c2f 100644 --- a/marginaleffects/datagrid.py +++ b/marginaleffects/datagrid.py @@ -9,6 +9,7 @@ def datagrid( newdata=None, FUN_numeric=lambda x: x.mean(), FUN_other=lambda x: x.mode()[0], # mode can return multiple values + grid_type = "typical", **kwargs ): """ @@ -25,6 +26,9 @@ def datagrid( - `newdata`: DataFrame (one and only one of the `model` and `newdata` arguments can be used.) - `FUN_numeric`: The function to be applied to numeric variables. - `FUN_other`: The function to be applied to other variable types. + - `grid_type`: + * "typical": variables whose values are not explicitly specified by the user in `...` are set to their mean or mode, or to the output of the functions supplied to `FUN_type` arguments. + * "counterfactual": the entire dataset is duplicated for each combination of the variable values specified in `...`. Variables not explicitly supplied to `datagrid()` are set to their observed values in the original dataset. Returns: @@ -57,6 +61,9 @@ def datagrid( if model is None and newdata is None: raise ValueError("One of model or newdata must be specified") + if grid_type == "counterfactual": + return datagridcf(**kwargs, model=model, newdata=newdata) + if newdata is None: newdata = get_modeldata(model)