Skip to content

Commit

Permalink
* missing mean.x in GLMs caused predict to error when `se.fit = T…
Browse files Browse the repository at this point in the history
…RUE` [Issue #50](#50)

* Prediction under the MPM failed with GLMs [Issue #51](#51)

close #50
close #51

issue #52 regarding SE still open!
  • Loading branch information
merliseclyde committed Nov 10, 2020
1 parent 9a1ee2f commit c47f2d8
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 39 deletions.
22 changes: 15 additions & 7 deletions R/bas_glm.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,26 @@ normalize.initprobs.glm <- function(initprobs, glm.obj) {
#' \item{priorprobs}{the prior probabilities of the models selected}
#' \item{logmarg}{values of the log of the marginal likelihood for the models}
#' \item{n.vars}{total number of independent variables in the full model,
#' including the intercept} \item{size}{the number of independent variables in
#' each of the models, includes the intercept} \item{which}{a list of lists
#' including the intercept}
#' \item{size}{the number of independent variables in
#' each of the models, includes the intercept}
#' \item{which}{a list of lists
#' with one list per model with variables that are included in the model}
#' \item{probne0}{the posterior probability that each variable is non-zero}
#' \item{coefficients}{list of lists with one list per model giving the GLM
#' estimate of each (nonzero) coefficient for each model.} \item{se}{list of
#' estimate of each (nonzero) coefficient for each model.}
#' \item{se}{list of
#' lists with one list per model giving the GLM standard error of each
#' coefficient for each model} \item{deviance}{the GLM deviance for each model}
#' coefficient for each model}
#' \item{deviance}{the GLM deviance for each model}
#' \item{modelprior}{the prior distribution on models that created the BMA
#' object} \item{Q}{the Q statistic for each model used in the marginal
#' likelihood approximation} \item{Y}{response} \item{X}{matrix of predictors}
#' \item{family}{family object from the original call} \item{betaprior}{family
#' object}
#' \item{Q}{the Q statistic for each model used in the marginal
#' likelihood approximation}
#' \item{Y}{response}
#' \item{X}{matrix of predictors}
#' \item{family}{family object from the original call}
#' \item{betaprior}{family
#' object for prior on coefficients, including hyperparameters}
#' \item{modelprior}{family object for prior on the models}
#' \item{include.always}{indices of variables that are forced into the model}
Expand Down
2 changes: 2 additions & 0 deletions R/bas_lm.R
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ is.solaris<-function() {
#' \item{X}{matrix of predictors}
#' \item{mean.x}{vector of means for each column of X (used in
#' \code{\link{predict.bas}})}
#' \link{weights} used in model fitting
#' \item{include.always}{indices of variables that are forced into the model}
#'
#' The function \code{\link{summary.bas}}, is used to print a summary of the
Expand Down Expand Up @@ -844,6 +845,7 @@ bas.lm <- function(formula,
result$xlevels <- .getXlevels(mt, mf)
result$terms <- mt
result$model <- mf
result$weights <- weights

class(result) <- c("bas")
if (prior == "EB-global") {
Expand Down
75 changes: 62 additions & 13 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,13 @@ predict.bas <- function(object,
estimator = "BMA",
na.action = na.pass,
...) {
if (!(estimator %in% c("BMA", "HPM", "MPM", "BPM"))) {
if (!(estimator %in% c("BMA", "HPM", "MPM", "BPM", "MPMold"))) {
stop("Estimator must be one of 'BMA', 'BPM', 'HPM', or 'MPM'.")
}

if (estimator == "MPM") {
object = extract_MPM(object)
}
tt <- terms(object)

if (missing(newdata) || is.null(newdata)) {
Expand Down Expand Up @@ -277,7 +280,7 @@ predict.bas <- function(object,
df <- object$df


if (estimator == "MPM") {
if (estimator == "MPMold") {
nvar <- object$n.vars - 1
bestmodel <- (0:nvar)[object$probne0 > .5]
newX <- cbind(1, newdata)
Expand Down Expand Up @@ -415,12 +418,11 @@ predict.bas <- function(object,

if (se.fit) {
if (estimator != "BMA") {
se <- .se.fit(fit, newdata, object, insample)
se <- .se.fit(fit, newdata, object, insample, type)
}
else {
se <- .se.bma(
Ybma, newdata, Ypred, best, object,
insample
Ybma, newdata, Ypred, best, object, insample, type
)
}
}
Expand Down Expand Up @@ -521,15 +523,20 @@ fitted.bas <- function(object,
if (is.null(top)) {
top <- nmodels
}

if (estimator == "MPM") { top = 1 }


if (estimator == "HPM") {
yhat <- predict(
yhat <- predict(
object,
newdata = NULL,
top = 1,
estimator = "HPM", type = type,
na.action = na.action
)$fit
}

if (estimator == "BMA") {
yhat <- predict(
object,
Expand All @@ -548,6 +555,15 @@ fitted.bas <- function(object,
na.action = na.action
)$fit
}
if (estimator == "MPMold") {
yhat <- predict(
object,
newdata = NULL,
top = top,
estimator = "MPMold", type = type,
na.action = na.action
)$fit
}
if (estimator == "BPM") {
yhat <- predict(
object,
Expand All @@ -558,18 +574,27 @@ fitted.bas <- function(object,
)$fit
}

yhat <- predict(
object,
newdata = NULL,
top = top,
estimator = estimator,
type = type,
na.action = na.action
)$fit

return(as.vector(yhat))
}

.se.fit <- function(yhat, X, object, insample) {
.se.fit <- function(yhat, X, object, insample, type) {
n <- object$n
model <- attr(yhat, "model")
best <- attr(yhat, "best")

df <- object$df[best]

mean.x = object$mean.x # glms don't have centered X for intercept so need t
# to center X and newX to get the right hat values with orthogonal X
# to center X and newX to get the right hat values with weights
if (is.null(mean.x)) {
mean.x =colMeans(object$X[,-1])
X = sweep(X, 2, mean.x)
Expand All @@ -578,9 +603,33 @@ fitted.bas <- function(object,


shrinkage <- object$shrinkage[best]

if (insample) {
xiXTXxiT <- hat(object$X[, model + 1]) - 1 / n
} else {

if (!is.null(object$family$family)) {
if (type == 'link') {
mu.eta <- object$family$mu.eta(as.vector(yhat))
weights <- mu.eta^2/object$family$variance(object$family$linkinv(yhat))
}
else {
mu.eta <- object$family$mu.eta(object$family$link(as.vector(yhat)))
weights <- mu.eta^2/object$family$variance(as.vector(yhat))
}
}
else {
if (!is.null(object$weights)) {
weights <- object$weights
}
else {
weights = rep(1, object$n)
}
}
# browser() FIX issue #52
xiXTXxiT <- hat(diag(sqrt(weights)) %*% object$X[, model + 1])/weights - 1 / sum(weights)
}
else {

#Fix below! FIX issue #52
X <- cbind(1, X[, model[-1], drop = FALSE])
oldX <- (sweep(object$X[, -1], 2, mean.x))[, model[-1]] #center
# browser()
Expand All @@ -589,9 +638,9 @@ fitted.bas <- function(object,
}
scale_fit <- 1 / n + object$shrinkage[best] * xiXTXxiT
if (is.null(object$family)) {
family <- gaussian()
object$family <- gaussian()
}
if (eval(family)$family == "gaussian") {
if (object$family$family == "gaussian") {
ssy <- var(object$Y) * (n - 1)
bayes_mse <- ssy * (1 - shrinkage * object$R2[best]) / df
}
Expand All @@ -607,7 +656,7 @@ fitted.bas <- function(object,
))
}

.se.bma <- function(fit, Xnew, Ypred, best, object, insample) {
.se.bma <- function(fit, Xnew, Ypred, best, object, insample, type) {
n <- object$n

df <- object$df[best]
Expand Down
22 changes: 15 additions & 7 deletions man/bas.glm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/bas.lm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/fitted.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/variable.names.pred.bas.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 56 additions & 12 deletions tests/testthat/test-predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,42 @@ test_that("predict.bas.glm", {
pima_pred <- predict(pima_gprior,
estimator = "HPM",
se.fit = FALSE)
pima_top <- predict(pima_gprior,
estimator = "BMA", top=1,
se.fit = TRUE)

expect_equal(pima_pred$fit, pima_top$fit, check.attributes = FALSE)

})



#Fixed Issue #51
test_that("MPM and predict glm", {
data("Pima.tr", package="MASS")
data("Pima.te", package="MASS")
pima_gprior <- bas.glm(type ~ ., data = Pima.tr,
betaprior = g.prior(g=as.numeric(nrow(Pima.tr))),
family=binomial())
pima_MPM = extract_MPM(pima_gprior)

expect_equal(predict(pima_gprior, estimator = "MPM", se.fit = FALSE)$fit,
predict(pima_MPM, se.fit = FALSE)$fit,
check.attributes = FALSE)


pima_pred <- predict(pima_gprior,
estimator = "MPM", type = "link",
se.fit = FALSE)
pima_fit <- fitted(pima_gprior,
estimator = "MPM")

expect_equal(pima_pred$fit, pima_fit, check.attributes = FALSE)

# should not error
expect_error(predict(pima_gprior,
estimator = "HPM",
se.fit = TRUE))
# expect_null(plot(confint(pima_pred, parm = "mean")))
# should not error
expect_error( predict(hald_gprior, newdata=Pima.te, estimator = "HPM",
se.fit = TRUE))
#expect_null(plot(confint(pima_pred)))
})


# Issue #52 SE's are incorrect for glms and weighted regression
test_that("se.fit.glm", {
data("Pima.tr", package="MASS")
data("Pima.te", package="MASS")
Expand All @@ -57,20 +81,40 @@ pima.bic = bas.glm(type ~ ., data=Pima.tr, n.models= 2^7,
betaprior=bic.prior(n=200), family=binomial(),
modelprior=beta.binomial(1,1))

fit.bic = predict(pima.bic, se.fit = TRUE, top=1, type="link", estimator="HPM")
pred.bic = predict(pima.bic, newdata=Pima.te, se.fit = TRUE, top=1, type="link")

form = paste("type ~ ",
paste0((pred.bic$best.vars[pred.bic$bestmodel[[1]] + 1])[- 1],
collapse = "+"))

pima.glm = glm(form, data=Pima.tr, family=binomial())
fit.glm = predict(pima.glm, se.fit=TRUE, type='link')
pred.glm = predict(pima.glm, newdata=Pima.te, se.fit=TRUE, type='link')

expect_true(all.equal(pred.glm$fit, pred.bic$fit, check.attributes = FALSE))
expect_true(all.equal(fit.glm$fit, fit.bic$fit, check.attributes = FALSE))


# issue #50 in github regarding se.fit failing; debugging indicates se.fit is
# incorrect
# Should be expect_true
expect_false(all.equal(pred.glm$se.fit, pred.bic$se.fit, check.attributes = FALSE))
# Should be expect_equal

expect_equal(fit.glm$se.fit, fit.bic$se.fit, check.attributes = FALSE)
expect_equal(pred.glm$se.fit, pred.bic$se.fit, check.attributes = FALSE)


})

# Added feature issue #53
test_that("MPM and predict in lm", {
data(Hald, package="BAS")
hald_bic = bas.lm(Y ~ ., data=Hald, alpha=13, prior="BIC",
modelprior = uniform())

hald_MPM = extract_MPM(hald_bic)
expect_equal(predict(hald_bic, estimator = "MPM")$fit,
predict(hald_MPM)$fit, check.attributes = FALSE)
expect_equal(predict(hald_bic, estimator = "MPM")$fit,
predict(hald_bic, estimator = "MPMold")$fit,
check.attributes = FALSE)
})

0 comments on commit c47f2d8

Please sign in to comment.