Many modern problems leverage machine learning methods to predict outcomes based on observable covariates. The subsequent statistical modeling, e.g. to understand population trends in outcomes, often treats predicted outcomes as interchangeable with observed outcomes. This biases the downstream inference results, because machine learning models typically capture the mean function well but not the variance. `postpi` is an R package we developed to correct such biases in downstream inferential analyses that use outcomes predicted with an arbitrary machine learning method.
For example, if we want to study associations between a phenotype of interest and expression for gene k, we can fit a simple regression model between the outcome vector y and the covariate of interest x_k, and we then estimate the coefficient beta_k:

    y = beta_0 + beta_k * x_k + error
However, since we do not always observe outcomes for all samples, we often predict outcomes from observable covariates with a machine learning tool. The predicted outcomes y_p from these potentially complicated prediction models then become inputs to the subsequent statistical analysis:

    y_p = beta_0 + beta_k * x_k + error
We then estimate beta_k and conduct hypothesis tests as if y_p had the same properties as y -- but this is not true! Treating predicted outcomes as observed outcomes in the subsequent inference model causes problems.
Let's examine the problem further by plotting observed outcomes (in grey) and predicted outcomes (in blue) against a covariate of interest x from a simulated example:
From the plot, we observe the problem -- the standard errors of the estimates using predicted outcomes are much smaller than the standard errors of the estimates using observed outcomes. Reduced standard errors lead to larger test statistics when conducting hypothesis tests on the estimates, and therefore to more false positives in the analyses.
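The following minimal simulation sketch (not part of the package; it uses `mgcv::gam` as a stand-in for an arbitrary machine learning model) reproduces this behavior:

```r
# A minimal simulation sketch (not from the package) illustrating the problem:
# the prediction model captures the mean function but not the outcome variance,
# so naive inference on predicted outcomes reports standard errors that are too small.
library(mgcv)

set.seed(1)
n <- 2000
x <- rnorm(n)
y <- 1 + 2 * x + rnorm(n, sd = 3)
train <- data.frame(x = x[1:1000],    y = y[1:1000])
valid <- data.frame(x = x[1001:2000], y = y[1001:2000])

fit <- gam(y ~ s(x), data = train)                         # stand-in for an arbitrary ML model
valid$y_pred <- as.numeric(predict(fit, newdata = valid))  # predicted outcomes on the validation set

summary(lm(y ~ x,      data = valid))$coefficients         # inference with observed outcomes
summary(lm(y_pred ~ x, data = valid))$coefficients         # inference with predicted outcomes: smaller std. errors
```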
Now our package `postpi` comes in handy to correct such biases in the estimates, standard errors, and test statistics, for both continuous and categorical data. `postpi` contains three functions:
- `postpi_relate()`
  - required inputs: a data set (i.e. a testing set) containing only observed and predicted outcomes, and the column name for the observed outcomes.
  - optional inputs: a method from the `caret` package, chosen by the user, to relate categorical observed outcomes to the probabilities of the predicted categories. The default method is k-nearest neighbors.
  - purpose: the function models the relationship between observed outcomes and predicted outcomes/probabilities, and returns the relationship model.
- `postpi()`
  - required inputs: a data set (i.e. a validation set) containing predicted outcomes and covariates, the relationship model estimated from `postpi_relate()`, and an inference formula.
  - optional inputs: the number of bootstrap iterations and a seed number. The defaults are 100 bootstrap iterations and seed 1234.
  - purpose: the function provides the corrected inference result table using a bootstrap approach, for continuous/categorical outcomes. The output is a tidy table with 5 columns: term, estimate, std.error, statistic, p.value.
- `postpi_der()` (a rough usage sketch follows this list)
  - required inputs: a testing set containing observed and predicted continuous outcomes, the column names for the observed and predicted outcomes, a validation set containing predicted outcomes and covariates, and an inference formula.
  - optional inputs: none.
  - purpose: the function provides the corrected inference result table using a derivation approach, for continuous outcomes only. The output is a tidy table with 5 columns: term, estimate, std.error, statistic, p.value.
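As a rough illustration of how `postpi_der()` could be called (the argument order below is an assumption based on the input description above, so check the function documentation and vignettes for the exact signature), using the `RINdata` data set introduced later in this README:

```r
# A rough sketch of the derivation approach. The argument order of postpi_der()
# here is an assumption based on its input description -- see ?postpi_der.
library(postpi)

data("RINdata")

set.seed(2019)
data_split <- rsample::initial_split(RINdata, prop = 1/2)
testing    <- rsample::training(data_split)
validation <- rsample::testing(data_split)

results_der <- postpi_der(testing, actual, predictions,
                          validation, predictions ~ region_10)
```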
1. Prepare a data set with observed outcomes, predicted outcomes/probabilities for each predicted category, and covariates of interest for the subsequent inferential analyses.
2. Split the data set into testing and validation sets. On the testing set, use `postpi_relate()` to estimate a relationship model.
3. On the validation set, use `postpi()`/`postpi_der()` to conduct the inferential analyses.
Note: if users have a subset of observed outcomes but no predicted outcomes, they should split the data set into three sets -- training, testing, and validation sets. On the training set, use a machine learning method to train a prediction model, and apply it to the testing and validation sets to obtain predicted outcomes. Then repeat steps 2-3 above to obtain inference results. A sketch of this three-set workflow is shown below.
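A minimal sketch of that three-set workflow, assuming a data frame `full_data` with an observed outcome column `actual` (both names are hypothetical here) and using `caret` to train the prediction model:

```r
# A minimal sketch (hypothetical data/column names) of the three-set workflow:
# train a prediction model on the training set, then add predicted outcomes to
# the testing and validation sets before running postpi_relate()/postpi().
library(rsample)
library(caret)

split1    <- initial_split(full_data, prop = 1/3)  # 1/3 of the samples for training the prediction model
train_set <- training(split1)
rest      <- testing(split1)

split2    <- initial_split(rest, prop = 1/2)       # split the remainder into testing/validation
test_set  <- training(split2)
valid_set <- testing(split2)

fit <- train(actual ~ ., data = train_set, method = "knn")  # any caret method works here

test_set$predictions  <- predict(fit, newdata = test_set)
valid_set$predictions <- predict(fit, newdata = valid_set)
```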
The package can be installed from GitHub as follows:

    devtools::install_github("SiruoWang/postpi")

The package can then be loaded into R:

    library('postpi')
In this section, we include a simple example of using the package on a data set with continuous outcomes. A detailed tutorial, with multiple examples on both continuous and categorical data, is provided in the vignettes.
In this example, we use the data set `RINdata`, which is available in the package. `RINdata` contains a column of observed RIN values named `actual`, a column of predicted RIN values named `predictions` (obtained from a prediction model trained on a previous data set), and 200 columns of gene expression regions. We want to study associations between RINs and gene expression levels. A detailed description of the `RINdata` data set is available in our paper [Post-prediction inference](a link to preprint).
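A quick look at the data set before the analysis (assuming the column names described above):

```r
# Quick inspection of the RINdata data set shipped with the package
library(postpi)

data("RINdata")
dim(RINdata)                                  # number of samples and columns
head(RINdata[, c("actual", "predictions")])   # observed vs. predicted RIN values
```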
- We load `RINdata` and split the data into testing and validation sets using the `rsample` package.

      data("RINdata")
      data <- RINdata

      set.seed(2019)
      data_split <- rsample::initial_split(data, prop = 1/2)

      testing    <- rsample::training(data_split)
      validation <- rsample::testing(data_split)
- We select the columns of the observed and predicted outcomes from `RINdata`, and pass them to `postpi_relate()` to estimate a relationship model named `rel_model`. We pipe our code using `%>%` from the `dplyr` package.

      library(dplyr)

      rel_model <- testing %>%
        select(actual, predictions) %>%
        postpi_relate(actual)
- We define an inference formula `predictions ~ region_10`, since we want to relate gene expression levels in region 10 to predicted RINs. Then we pass the validation set, the defined inference formula, and the relationship model `rel_model` estimated above to the inference function `postpi()`. In `postpi()` we estimate inference results using a bootstrap approach and obtain the results in a tidy table named `results_postpi`.

      inf_formula <- predictions ~ region_10

      results_postpi <- validation %>%
        postpi(rel_model, inf_formula)
In the vignettes, we compare the inference results from our package to the gold standard and the no-correction approach, and include more examples.
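As a rough sketch (not taken from the vignettes) of what such a comparison looks like for the example above, using the `broom` package and the validation set, where both observed and predicted RINs are available:

```r
# A rough comparison sketch: the "gold standard" fits the inference model on observed
# outcomes, the "no correction" approach fits it naively on predicted outcomes, and
# results_postpi holds the corrected results from postpi().
library(broom)

gold_standard <- tidy(lm(actual ~ region_10,      data = validation))
no_correction <- tidy(lm(predictions ~ region_10, data = validation))

gold_standard
no_correction
results_postpi
```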
Siruo Wang: [email protected]