diff --git a/src/bas.h b/src/bas.h index dfa8207e..920efec8 100644 --- a/src/bas.h +++ b/src/bas.h @@ -55,7 +55,8 @@ double trunc_beta_binomial(int modeldim, int p, double *hyper); double trunc_poisson(int modeldim, int p, double *hyper); double trunc_power_prior(int modeldim, int p, double *hyper); double Bernoulli(int *model, int p, double *hyper); -double compute_prior_probs(int *model, int modeldim, int p, SEXP modelprior); +int no_prior_inclusion_is_1(int p, double *probs); +double compute_prior_probs(int *model, int modeldim, int p, SEXP modelprior, int noInclusionIs1); void compute_margprobs_old(Bit **models, SEXP Rmodelprobs, double *margprobs, int k, int p); void compute_modelprobs(SEXP modelprobs, SEXP logmarg, SEXP priorprobs, int k); void set_bits(char *bits, int subset, int *pattern, int *position, int n); diff --git a/src/glm_deterministic.c b/src/glm_deterministic.c index 41708ab0..95c86686 100644 --- a/src/glm_deterministic.c +++ b/src/glm_deterministic.c @@ -61,6 +61,7 @@ SEXP glm_deterministic(SEXP Y, SEXP X, SEXP Roffset, SEXP Rweights, int *model = (int *) R_alloc(p, sizeof(int)); memset(model, 0, p*sizeof(int)); + int noInclusionIs1 = no_prior_inclusion_is_1(p, probs); k = topk(models, probs, k, vars, n, p); /* now fit all top k models */ @@ -82,7 +83,7 @@ SEXP glm_deterministic(SEXP Y, SEXP X, SEXP Roffset, SEXP Rweights, SEXP glm_fit = PROTECT(glm_FitModel(X, Y, Rmodel_m, Roffset, Rweights, glmfamily, Rcontrol, Rlaplace, betapriorfamily)); - double prior_m = compute_prior_probs(model,pmodel,p, modelprior); + double prior_m = compute_prior_probs(model,pmodel,p, modelprior, noInclusionIs1); logmargy = REAL(getListElement(getListElement(glm_fit, "lpy"),"lpY"))[0]; shrinkage_m = REAL(getListElement(getListElement(glm_fit, "lpy"), "shrinkage"))[0]; diff --git a/src/glm_mcmc.c b/src/glm_mcmc.c index 61453d62..7764cf36 100644 --- a/src/glm_mcmc.c +++ b/src/glm_mcmc.c @@ -52,6 +52,7 @@ SEXP glm_mcmc(SEXP Y, SEXP X, SEXP Roffset, SEXP Rweights, n = sortvars(vars, probs, p); for (i =n; i 1) R2_m = 1.0 - (mse_m * (double) ( nobs - pmodel))/SSY; - prior_m = compute_prior_probs(model,pmodel,p, modelprior); + prior_m = compute_prior_probs(model,pmodel,p, modelprior, noInclusionIs1); gexpectations(p, pmodel, nobs, R2_m, alpha, INTEGER(method)[0], RSquareFull, SSY, &logmargy, &shrinkage_m); postnew = logmargy + log(prior_m); } @@ -515,7 +516,7 @@ SEXP mcmcbas(SEXP Y, SEXP X, SEXP Rweights, SEXP Rprobinit, SEXP Rmodeldim, SEXP gexpectations(p, pmodel, nobs, R2_m, alpha, INTEGER(method)[0], RSquareFull, SSY, &logmargy, &shrinkage_m); REAL(logmarg)[m] = logmargy; REAL(shrinkage)[m] = shrinkage_m; - REAL(priorprobs)[m] = compute_prior_probs(model,pmodel,p, modelprior); + REAL(priorprobs)[m] = compute_prior_probs(model,pmodel,p, modelprior, noInclusionIs1); if (m > 1) { diff --git a/src/lm_sampleworep.c b/src/lm_sampleworep.c index da236994..2fc5ce95 100644 --- a/src/lm_sampleworep.c +++ b/src/lm_sampleworep.c @@ -241,6 +241,7 @@ extern SEXP sampleworep_new(SEXP Y, SEXP X, SEXP Rweights, SEXP Rprobinit, struct Var *vars = (struct Var *) R_alloc(p, sizeof(struct Var)); // Info about the model variables. probs = REAL(Rprobs); int n = sortvars(vars, probs, p); + int noInclusionIs1 = no_prior_inclusion_is_1(p, probs); SEXP Rse_m = NULL, Rcoef_m = NULL, Rmodel_m = NULL; RSquareFull = CalculateRSquareFull(XtY, XtX, XtXwork, XtYwork, Rcoef_m, Rse_m, p, nobs, yty, SSY); @@ -293,7 +294,7 @@ extern SEXP sampleworep_new(SEXP Y, SEXP X, SEXP Rweights, SEXP Rprobinit, // gexpectations(p, pmodel, nobs, R2_m, alpha, INTEGER(method)[0], RSquareFull, SSY, &logmargy, &shrinkage_m); // check should this depend on rank or pmodel? - double prior_m = compute_prior_probs(model,pmodel,p, modelprior); + double prior_m = compute_prior_probs(model,pmodel,p, modelprior, noInclusionIs1); @@ -349,7 +350,7 @@ extern SEXP sampleworep_new(SEXP Y, SEXP X, SEXP Rweights, SEXP Rprobinit, // Rprintf("rank %d dim %d\n", rank_m, pmodel); // gexpectations(p, pmodel, nobs, R2_m, alpha, INTEGER(method)[0], RSquareFull, SSY, &logmargy, &shrinkage_m); - prior_m = compute_prior_probs(model,pmodel,p, modelprior); + prior_m = compute_prior_probs(model,pmodel,p, modelprior, noInclusionIs1); SetModel2(logmargy, shrinkage_m, prior_m, sampleprobs, logmarg, shrinkage, priorprobs, m); SetModel_lm(Rcoef_m, Rse_m, Rmodel_m, mse_m, R2_m, beta, se, modelspace, mse, R2,m); UNPROTECT(3); diff --git a/src/model_probabilities.c b/src/model_probabilities.c index 46cb8e5c..b9d32328 100644 --- a/src/model_probabilities.c +++ b/src/model_probabilities.c @@ -60,7 +60,19 @@ void compute_margprobs_old(Bit **models, SEXP Rmodelprobs, double *margprobs, in } } -double compute_prior_probs(int *model, int modeldim, int p, SEXP modelprior) { +int no_prior_inclusion_is_1(int p, double *probs) { + + int noInclusionIs1 = 0; + // loop starts from 1 since the intercept is corrected for in the model prior functions + for (int i = 1; i < p; i++) { + if (probs[i] > (1.0 - DBL_EPSILON)) { + noInclusionIs1++; + } + } + return noInclusionIs1; +} + +double compute_prior_probs(int *model, int modeldim, int p, SEXP modelprior, int noInclusionIs1) { const char *family; double *hyper_parameters, priorprob = 1.0; @@ -68,6 +80,9 @@ double compute_prior_probs(int *model, int modeldim, int p, SEXP modelprior) { family = CHAR(STRING_ELT(getListElement(modelprior, "family"),0)); hyper_parameters = REAL(getListElement(modelprior,"hyper.parameters")); + // reduce the model space by the number of predictors that are always included + p -= noInclusionIs1; + modeldim -= noInclusionIs1; if (strcmp(family, "Beta-Binomial") == 0) priorprob = beta_binomial(modeldim, p, hyper_parameters); diff --git a/tests/testthat/test-model-priors.R b/tests/testthat/test-model-priors.R index 2ab9184a..8226f825 100644 --- a/tests/testthat/test-model-priors.R +++ b/tests/testthat/test-model-priors.R @@ -37,3 +37,14 @@ test_that("Bernoulli hereditary prior", { modelprior = Bernoulli.heredity(.5, NULL)) ) }) + +test_that("Always include changes model-prior", { + + res <- bas.lm(Y ~ ., data = Hald, modelprior = beta.binomial(1, 1), + include.always = ~ 1 + X1 + X2) + ord <- order(lengths(res$which)) + expect_equal( + object = res$priorprobs[ord], + expected = c(0.333333333333333, 0.166666666666667, 0.166666666666667, 0.333333333333333) + ) +})