Skip to content

Commit

Permalink
test fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Nov 16, 2024
1 parent 26182d0 commit ae23a4f
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 87 deletions.
81 changes: 34 additions & 47 deletions R/broom.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,86 +9,72 @@ generics::glance


#' tidy helper
#'
#'
#' @noRd
#' @export
tidy.comparisons <- function(x, ...) {
insight::check_if_installed("tibble")
out <- tibble::as_tibble(x)
if (!"term" %in% names(out)) {
lab <- seq_len(nrow(out))
if ("group" %in% colnames(out) || is.character(attr(x, "by"))) {
tmp <- c("group", attr(x, "by"))
tmp <- Filter(function(j) j %in% colnames(x), tmp)
if (length(tmp) > 0) {
tmp <- do.call(paste, out[, tmp])
if (anyDuplicated(tmp)) {
tmp <- paste(seq_len(nrow(out)), tmp)
}
lab <- tmp
}
}
out[["term"]] <- lab
}
return(out)
insight::check_if_installed("tibble")
out <- tibble::as_tibble(x)
return(out)
}


#' tidy helper
#'
#'
#' @noRd
#' @export
tidy.slopes <- tidy.comparisons


#' tidy helper
#'
#'
#' @noRd
#' @export
tidy.predictions <- tidy.comparisons


#' tidy helper
#'
#'
#' @noRd
#' @export
tidy.hypotheses <- tidy.comparisons


#' tidy helper
#'
#'
#' @noRd
#' @export
tidy.marginalmeans <- function(x, ...) {
insight::check_if_installed("tibble")
tibble::as_tibble(x)
insight::check_if_installed("tibble")
tibble::as_tibble(x)
}


#' @noRd
#' @export
glance.slopes <- function(x, ...) {
insight::check_if_installed("insight")
insight::check_if_installed("modelsummary")
model <- tryCatch(attr(x, "model"), error = function(e) NULL)
if (is.null(model) || isTRUE(checkmate::check_string(model))) {
model <- tryCatch(attr(x, "call")[["model"]], error = function(e) NULL)
}
gl <- suppressMessages(suppressWarnings(try(
modelsummary::get_gof(model, ...), silent = TRUE)))
if (inherits(gl, "data.frame")) {
out <- data.frame(gl)
} else {
out <- NULL
}
vcov.type <- attr(x, "vcov.type")
if (is.null(out) && !is.null(vcov.type)) {
out <- data.frame("vcov.type" = vcov.type)
} else if (!is.null(out)) {
out[["vcov.type"]] <- vcov.type
}
out <- tibble::as_tibble(out)
return(out)
insight::check_if_installed("insight")
insight::check_if_installed("modelsummary")
model <- tryCatch(attr(x, "model"), error = function(e) NULL)
if (is.null(model) || isTRUE(checkmate::check_string(model))) {
model <- tryCatch(attr(x, "call")[["model"]], error = function(e) NULL)
}
gl <- suppressMessages(suppressWarnings(try(
modelsummary::get_gof(model, ...),
silent = TRUE)))
if (inherits(gl, "data.frame")) {
out <- data.frame(gl)
} else {
out <- NULL
}
vcov.type <- attr(x, "vcov.type")
if (is.null(out) && !is.null(vcov.type)) {
out <- data.frame("vcov.type" = vcov.type)
} else if (!is.null(out)) {
out[["vcov.type"]] <- vcov.type
}
out <- tibble::as_tibble(out)
return(out)
}


Expand All @@ -109,4 +95,5 @@ glance.hypotheses <- glance.slopes

#' @noRd
#' @export
glance.marginalmeans <- glance.slopes
glance.marginalmeans <- glance.slopes

81 changes: 41 additions & 40 deletions inst/tinytest/test-by.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ mod <- lm(mpg ~ factor(cyl) * hp + wt, data = dat)
mar <- margins(mod, at = list(cyl = unique(dat$cyl)))
mar <- data.frame(summary(mar))
mfx <- slopes(
mod,
by = "cyl",
newdata = datagrid(cyl = c(4, 6, 8), grid_type = "counterfactual"))
mod,
by = "cyl",
newdata = datagrid(cyl = c(4, 6, 8), grid_type = "counterfactual"))
mfx <- mfx[order(mfx$term, mfx$contrast), ]
expect_equivalent(mfx$estimate, mar$AME)
expect_equivalent(mfx$std.error, mar$SE, tolerance = 1e6)

Expand All @@ -77,9 +78,9 @@ expect_equivalent(mfx$std.error, mar$SE, tolerance = 1e6)
# issue #434 by with character precitors
dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/AER/Affairs.csv")
mod <- glm(
affairs ~ children + gender + yearsmarried,
family = poisson,
data = dat)
affairs ~ children + gender + yearsmarried,
family = poisson,
data = dat)
p <- predictions(mod, by = "children")
expect_equivalent(nrow(p), 2)
expect_false(anyNA(p$estimate))
Expand All @@ -97,8 +98,8 @@ cmp <- comparisons(mod, type = "probs", by = "group")
expect_equivalent(nrow(cmp), 9)

by <- data.frame(
group = c("3", "4", "5"),
by = c("(3,4)", "(3,4)", "(5)"))
group = c("3", "4", "5"),
by = c("(3,4)", "(3,4)", "(5)"))
p1 <- predictions(mod, type = "probs")
p2 <- predictions(mod, type = "probs", by = by)
p3 <- predictions(mod, type = "probs", by = by, hypothesis = "sequential")
Expand All @@ -113,26 +114,26 @@ cmp <- comparisons(mod, type = "probs", by = "am")
expect_equivalent(nrow(cmp), 18)

cmp <- comparisons(
mod,
variables = "am",
by = by,
type = "probs")
mod,
variables = "am",
by = by,
type = "probs")
expect_equivalent(nrow(cmp), 2)

cmp <- comparisons(
mod,
variables = "am",
by = by,
hypothesis = "sequential",
type = "probs")
mod,
variables = "am",
by = by,
hypothesis = "sequential",
type = "probs")
expect_equivalent(nrow(cmp), 1)


# Issue #481: warning on missing by categories
mod <- nnet::multinom(factor(gear) ~ mpg + am * vs, data = mtcars, trace = FALSE)
by <- data.frame(
by = c("4", "5"),
group = 4:5)
by = c("4", "5"),
group = 4:5)
expect_warning(comparisons(mod, variables = "mpg", newdata = "mean", by = by))
expect_warning(predictions(mod, newdata = "mean", by = by))

Expand All @@ -159,13 +160,13 @@ expect_equivalent(nrow(pre2), 96)
dat <- mtcars
mod <- glm(gear ~ cyl + am, family = poisson, data = dat)
mfx <- avg_slopes(
mod,
by = c("cyl", "am"),
newdata = datagrid(
cyl = unique,
am = unique,
grid_type = "counterfactual")) |>
dplyr::arrange(term, cyl, am)
mod,
by = c("cyl", "am"),
newdata = datagrid(
cyl = unique,
am = unique,
grid_type = "counterfactual")) |>
dplyr::arrange(term, cyl, am)
mar <- margins(mod, at = list(cyl = unique(dat$cyl), am = unique(dat$am)))
mar <- summary(mar)
# margins doesn't treat the binary am as binary automatically
Expand All @@ -179,16 +180,16 @@ dat$cyl <- factor(dat$cyl)
dat$am <- as.logical(dat$am)
mod <- glm(gear ~ cyl + am, family = poisson, data = dat)
mfx <- comparisons(
mod,
by = c("cyl", "am"),
newdata = datagrid(
cyl = unique,
am = unique,
grid_type = "counterfactual"))
mod,
by = c("cyl", "am"),
newdata = datagrid(
cyl = unique,
am = unique,
grid_type = "counterfactual"))

mfx <- tidy(mfx)

mfx <- mfx[order(mfx$term, mfx$contrast, mfx$cyl, mfx$am),]
mfx <- mfx[order(mfx$term, mfx$contrast, mfx$cyl, mfx$am), ]
mar <- margins(mod, at = list(cyl = unique(dat$cyl), am = unique(dat$am)))
mar <- summary(mar)
expect_equivalent(mfx$estimate, mar$AME, tolerance = tol)
Expand All @@ -200,9 +201,9 @@ dat <- transform(mtcars, vs = vs, am = as.factor(am), cyl = as.factor(cyl))
mod <- lm(mpg ~ qsec + am + cyl, dat)
fun <- \(hi, lo) mean(hi) / mean(lo)
cmp1 <- comparisons(mod, variables = "cyl", comparison = fun, by = "am") |>
dplyr::arrange(am, contrast)
dplyr::arrange(am, contrast)
cmp2 <- comparisons(mod, variables = "cyl", comparison = "ratioavg", by = "am") |>
dplyr::arrange(am, contrast)
dplyr::arrange(am, contrast)
expect_equivalent(cmp1$estimate, cmp2$estimate)
expect_equivalent(cmp1$std.error, cmp2$std.error)
expect_equal(nrow(cmp1), 4)
Expand All @@ -218,18 +219,18 @@ cmp2 <- comparisons(mod, variables = "am") %>%
dplyr::group_by(cyl) %>%
dplyr::summarize(estimate = mean(estimate), .groups = "keep") |>
dplyr::ungroup()
cmp3 <- predictions(mod) |>
aggregate(estimate ~ am + cyl, FUN = mean) |>
aggregate(estimate ~ cyl, FUN = diff)
cmp3 <- predictions(mod) |>
aggregate(estimate ~ am + cyl, FUN = mean) |>
aggregate(estimate ~ cyl, FUN = diff)
expect_equivalent(cmp1$estimate, cmp2$estimate)
expect_equivalent(cmp1$estimate, cmp3$estimate)


# Issue #1058
tmp <- mtcars
tmp <- tmp[c('mpg', 'cyl', 'hp')]
tmp <- tmp[c("mpg", "cyl", "hp")]
tmp$cyl <- as.factor(tmp$cyl) # 3 levels
tmp$hp <- as.factor(tmp$hp)
tmp$hp <- as.factor(tmp$hp)
bygrid <- datagrid(newdata = tmp, by = "cyl", hp = unique)
expect_equivalent(nrow(bygrid), 23)

Expand Down

0 comments on commit ae23a4f

Please sign in to comment.