Skip to content

Commit

Permalink
expand predict method + distributed srr internal documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
pachadotdev committed Nov 20, 2024
1 parent 2a8d78e commit dac82ff
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 31 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: capybara
Type: Package
Title: Fast and Memory Efficient Fitting of Linear Models With High-Dimensional
Fixed Effects
Version: 0.6.0
Version: 0.7.0
Authors@R: c(
person(
given = "Mauricio",
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# capybara 0.7.0

* The predict method now allows to pass new data to predict the outcome.
* Fully documented code and tests according to rOpenSci standards.

# capybara 0.6.0

* Moves all the heavy computation to C++ using Armadillo and it exports the
Expand Down
4 changes: 1 addition & 3 deletions R/feglm_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ partial_mu_eta_ <- function(eta, family, order) {
temp_var_ <- function(data) {
repeat {
tmp_var <- paste0("capybara_internal_variable_",
sample(letters, 5L, replace = TRUE),
collapse = ""
)
paste0(sample(letters, 5L, replace = TRUE), collapse = ""))
if (!(tmp_var %in% colnames(data))) {
break
}
Expand Down
77 changes: 68 additions & 9 deletions R/generics_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,83 @@ NULL
#' @description Similar to the 'predict' method for 'glm' objects
#' @export
#' @noRd
predict.feglm <- function(object, type = c("link", "response"), ...) {
# Check validity of 'type'
predict.feglm <- function(object, newdata = NULL, type = c("link", "response"), ...) {
type <- match.arg(type)

# Compute requested type of prediction
x <- object[["eta"]]
if (!is.null(newdata)) {
check_data_(newdata)

data <- NA # just to avoid global variable warning
lhs <- NA
nobs_na <- NA
nobs_full <- NA
model_frame_(newdata, object$formula, NULL)
check_response_(data, lhs, object$family)
k_vars <- attr(terms(object$formula, rhs = 2L), "term.labels")
data <- transform_fe_(data, object$formula, k_vars)

x <- NA
nms_sp <- NA
p <- NA
model_response_(data, object$formula)

fes <- fixed_effects(object)
fes2 <- list()

for (name in names(fes)) {
# # match the FE rownames and replace each level in the data with the FE
fe <- fes[[name]]
fes2[[name]] <- fe[match(data[[name]], rownames(fe)), ]
}

eta <- x %*% object$coefficients + Reduce("+", fes2)
} else {
eta <- object[["eta"]]
}

if (type == "response") {
x <- object[["family"]][["linkinv"]](x)
eta <- object[["family"]][["linkinv"]](eta)
}

# Return prediction
x
as.numeric(eta)
}

#' @title Predict method for 'felm' objects
#' @description Similar to the 'predict' method for 'lm' objects
#' @export
#' @noRd
predict.felm <- function(object, ...) {
object[["fitted.values"]]
predict.felm <- function(object, newdata = NULL, type = c("response", "terms"), ...) {
type <- match.arg(type)

if (!is.null(newdata)) {
check_data_(newdata)

data <- NA # just to avoid global variable warning
lhs <- NA
nobs_na <- NA
nobs_full <- NA
model_frame_(newdata, object$formula, NULL)
k_vars <- attr(terms(object$formula, rhs = 2L), "term.labels")
data <- transform_fe_(data, object$formula, k_vars)

x <- NA
nms_sp <- NA
p <- NA
model_response_(data, object$formula)

fes <- fixed_effects(object)
fes2 <- list()

for (name in names(fes)) {
# # match the FE rownames and replace each level in the data with the FE
fe <- fes[[name]]
fes2[[name]] <- fe[match(data[[name]], rownames(fe)), ]
}

yhat <- x %*% object$coefficients + Reduce("+", fes2)
} else {
yhat <- object[["fitted.values"]]
}

as.numeric(yhat)
}
1 change: 0 additions & 1 deletion tests/testthat/test-feglm.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#' @srrstats {G5.8b} *Data of unsupported types (e.g., character or complex numbers in for functions designed only for numeric data)*
#' @srrstats {RE7.2} Demonstrate that output objects retain aspects of input data such as row or case names (see **RE1.3**).
#' @srrstats {RE7.3} Demonstrate and test expected behaviour when objects returned from regression software are submitted to the accessor methods of **RE4.2**--**RE4.7**.
#' @srrstats {RE7.4} Extending directly from **RE4.15**, where appropriate, tests should demonstrate and confirm that forecast errors, confidence intervals, or equivalent values increase with forecast horizons.
#'
#' @noRd
NULL
Expand Down
34 changes: 17 additions & 17 deletions tests/testthat/test-fepoisson.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ test_that("fepoisson is similar to fixest", {
#' @noRd
NULL

test_that("fepoisson time is the same adding noise to the data", {
test_that("fepoisson estimation is the same adding noise to the data", {
set.seed(123)
d <- data.frame(
x = rnorm(1000),
Expand All @@ -76,20 +76,20 @@ test_that("fepoisson time is the same adding noise to the data", {
expect_equal(coef(m1), coef(m2))
expect_equal(fixed_effects(m1), fixed_effects(m2))

t1 <- rep(NA, 10)
t2 <- rep(NA, 10)
for (i in 1:10) {
a <- Sys.time()
m1 <- fepoisson(y ~ x | f, d)
b <- Sys.time()
t1[i] <- b - a

a <- Sys.time()
m2 <- fepoisson(y2 ~ x | f, d)
b <- Sys.time()
t2[i] <- b - a
}
expect_gt(abs(median(t1) / median(t2)), 0.9)
expect_lt(abs(median(t1) / median(t2)), 1)
expect_lt(median(t1), median(t2))
# t1 <- rep(NA, 10)
# t2 <- rep(NA, 10)
# for (i in 1:10) {
# a <- Sys.time()
# m1 <- fepoisson(y ~ x | f, d)
# b <- Sys.time()
# t1[i] <- b - a

# a <- Sys.time()
# m2 <- fepoisson(y2 ~ x | f, d)
# b <- Sys.time()
# t2[i] <- b - a
# }
# expect_gt(abs(median(t1) / median(t2)), 0.9)
# expect_lt(abs(median(t1) / median(t2)), 1)
# expect_lt(median(t1), median(t2))
})
68 changes: 68 additions & 0 deletions tests/testthat/test-predict.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#' srr_stats (tests)
#'
#' @srrstats {G5.0} The tests use the widely known mtcars data set. It has few
#' observations, and it is easy to compare the results with the base R
#' functions.
#' @srrstats {G5.4b} We determine correctess for GLMs by comparison, checking
#' the estimates versus base R and hardcoded values obtained with Alpaca
#' (Stammann, 2018).
#' @srrstats {RE7.4} Extending directly from **RE4.15**, where appropriate, tests should demonstrate and confirm that forecast errors, confidence intervals, or equivalent values increase with forecast horizons.
#'
#' @noRd
NULL

test_that("predicted values increase the error outside the inter-quartile range for GLMs", {
m1 <- fepoisson(mpg ~ wt + disp | cyl, mtcars)

d1 <- mtcars[mtcars$mpg >= quantile(mtcars$mpg, 0.25) & mtcars$mpg <= quantile(mtcars$mpg, 0.75), ]
d2 <- mtcars[mtcars$mpg < quantile(mtcars$mpg, 0.25) | mtcars$mpg > quantile(mtcars$mpg, 0.75), ]

pred1 <- predict(m1, newdata = d1, type = "response")
pred2 <- predict(m1, newdata = d2, type = "response")

mape <- function(y, yhat) {
mean(abs(y - yhat) / y)
}

mape1 <- mape(d1$mpg, pred1)
mape2 <- mape(d2$mpg, pred2)

expect_lt(mape1, mape2)

# verify prediction compared to base R
m2 <- glm(mpg ~ wt + disp + as.factor(cyl), mtcars, family = quasipoisson())

pred1_base <- predict(m2, newdata = d1, type = "response")
pred2_base <- predict(m2, newdata = d2, type = "response")

expect_equal(round(pred1, 3), round(unname(pred1_base), 3))
expect_equal(round(pred2, 3), round(unname(pred2_base), 3))
})

test_that("predicted values increase the error outside the inter-quartile range for LMs", {
m1 <- felm(mpg ~ wt + disp | cyl, mtcars)

d1 <- mtcars[mtcars$mpg >= quantile(mtcars$mpg, 0.25) & mtcars$mpg <= quantile(mtcars$mpg, 0.75), ]
d2 <- mtcars[mtcars$mpg < quantile(mtcars$mpg, 0.25) | mtcars$mpg > quantile(mtcars$mpg, 0.75), ]

pred1 <- predict(m1, newdata = d1)
pred2 <- predict(m1, newdata = d2)

mape <- function(y, yhat) {
mean(abs(y - yhat) / y)
}

mape1 <- mape(d1$mpg, pred1)
mape2 <- mape(d2$mpg, pred2)

expect_lt(mape1, mape2)

# verify prediction compared to base R
m2 <- lm(mpg ~ wt + disp + as.factor(cyl), mtcars)

pred1_base <- predict(m2, newdata = d1)
pred2_base <- predict(m2, newdata = d2)

expect_equal(round(pred1, 3), round(unname(pred1_base), 3))
expect_equal(round(pred2, 3), round(unname(pred2_base), 3))
})

0 comments on commit dac82ff

Please sign in to comment.