Skip to content

Commit

Permalink
Merge pull request #23 from tidymodels/pca-steps
Browse files Browse the repository at this point in the history
add all pca steps
  • Loading branch information
EmilHvitfeldt authored Jul 21, 2024
2 parents 025dc83 + b9dd9aa commit c65357f
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 19 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ S3method(orbital,step_novel)
S3method(orbital,step_nzv)
S3method(orbital,step_other)
S3method(orbital,step_pca)
S3method(orbital,step_pca_sparse)
S3method(orbital,step_pca_sparse_bayes)
S3method(orbital,step_pca_truncated)
S3method(orbital,step_range)
S3method(orbital,step_ratio)
S3method(orbital,step_rename)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@

* Support for `step_bin2factor()`, `step_discretize()`, `step_lencode_mixed()`, `step_lencode_glm()`, `step_lencode_bayes()` has been added. (#22)

* Support for `step_pca_sparse()`, `step_pca_sparse_bayes()` and `step_pca_truncated()` has been added. (#23)

* `orbital()` now works on `tune::last_fit()` objects. (#13)

* `orbital_predict()` has been removed and replaced with the more idiomatic `predict()` method. (#10)
Expand Down
23 changes: 23 additions & 0 deletions R/recipes-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,27 @@ lencode_helper <- function(x) {
out <- c(out, eq)
}
out
}

# Build one linear-combination equation string per requested PCA component.
#
# rot      Loading matrix. Its row names are used as the input variable names
#          in the equations; its columns are renamed here with
#          `recipes::names0(ncol(rot), prefix)`.
# prefix   Component name prefix (e.g. "PC").
# all_vars Character vector of variable names that are needed downstream.
#
# Returns a named character vector. Each element has the form
# "var1 * w1 + var2 * w2 + ..." and is named after the entry of `all_vars`
# that it corresponds to. Entries of `all_vars` that do not correspond to a
# column of `rot` produce no equation.
pca_helper <- function(rot, prefix, all_vars) {
  colnames(rot) <- recipes::names0(ncol(rot), prefix)

  # Compare names after removing zero padding, so that e.g. "PC1" in
  # `all_vars` lines up with a column named "PC01".
  col_keys <- pca_naming(colnames(rot), prefix)
  var_keys <- pca_naming(all_vars, prefix)

  # Look components up by name instead of by position: `all_vars` may hold
  # names that are not components, or list components in another order.
  col_idx <- match(var_keys, col_keys)
  is_comp <- !is.na(col_idx)

  # `drop = FALSE` keeps `rot` a matrix (and keeps its row names) when only a
  # single component is requested.
  rot <- rot[, col_idx[is_comp], drop = FALSE]
  row_nms <- rownames(rot)

  out <- vapply(
    seq_len(ncol(rot)),
    function(i) paste(row_nms, "*", rot[, i], collapse = " + "),
    character(1)
  )

  names(out) <- all_vars[is_comp]
  out
}

# Remove one "0" that directly follows `prefix` in each element of `x`
# (e.g. "PC01" -> "PC1"), so component names can be compared regardless of
# that zero padding.
#
# x      Character vector of names.
# prefix Component name prefix. It is matched literally (`fixed = TRUE`), so
#        a prefix containing regular-expression metacharacters such as "."
#        or "(" is handled as plain text.
#
# Returns a character vector the same length as `x`.
pca_naming <- function(x, prefix) {
  gsub(paste0(prefix, "0"), prefix, x, fixed = TRUE)
}
20 changes: 1 addition & 19 deletions R/step_pca.R
Original file line number Diff line number Diff line change
@@ -1,24 +1,6 @@
#' @export
orbital.step_pca <- function(x, all_vars, ...) {
  # The equation building lives in `pca_helper()`. It receives the untouched
  # rotation matrix because it names the columns and selects the requested
  # components itself.
  pca_helper(x$res$rotation, x$prefix, all_vars)
}

# Remove one "0" that directly follows `prefix` in each element of `x`
# (e.g. "PC01" -> "PC1"), so component names can be compared regardless of
# that zero padding.
#
# x      Character vector of names.
# prefix Component name prefix. It is matched literally (`fixed = TRUE`), so
#        a prefix containing regular-expression metacharacters such as "."
#        or "(" is handled as plain text.
#
# Returns a character vector the same length as `x`.
pca_naming <- function(x, prefix) {
  gsub(paste0(prefix, "0"), prefix, x, fixed = TRUE)
}
6 changes: 6 additions & 0 deletions R/step_pca_sparse.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#' @export
orbital.step_pca_sparse <- function(x, all_vars, ...) {
  # For this step the loading matrix is read directly from `x$res`.
  pca_helper(x$res, x$prefix, all_vars)
}
6 changes: 6 additions & 0 deletions R/step_pca_sparse_bayes.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#' @export
orbital.step_pca_sparse_bayes <- function(x, all_vars, ...) {
  # For this step the loading matrix is read directly from `x$res`.
  pca_helper(x$res, x$prefix, all_vars)
}
6 changes: 6 additions & 0 deletions R/step_pca_truncated.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#' @export
orbital.step_pca_truncated <- function(x, all_vars, ...) {
  # For this step the loading matrix is read from `x$res$rotation`.
  pca_helper(x$res$rotation, x$prefix, all_vars)
}
66 changes: 66 additions & 0 deletions tests/testthat/test-step_pca_sparse.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
test_that("step_pca_sparse works", {
  skip_if_not_installed("recipes")
  skip_if_not_installed("embed")

  # mtcars as a tibble, without the `hp` column.
  dat <- dplyr::as_tibble(mtcars)
  dat[["hp"]] <- NULL

  # Warnings raised while building and prepping the recipe are muffled.
  rec <- suppressWarnings({
    spec <- recipes::recipe(mpg ~ ., data = dat)
    spec <- embed::step_pca_sparse(spec, recipes::all_predictors())
    recipes::prep(spec)
  })

  # Reference result: what the prepped recipe itself produces.
  expected <- recipes::bake(rec, new_data = dat)

  # Result under test: the orbital equations applied with plain dplyr, then
  # reduced to the baked columns in the baked order.
  eqs <- orbital_inline(orbital(rec))
  actual <- dplyr::mutate(dat, !!!eqs)
  actual <- actual[names(expected)]

  expect_equal(actual, expected)
})

test_that("step_pca_sparse works with more than 9 PCs", {
  skip_if_not_installed("recipes")
  skip_if_not_installed("embed")

  # All mtcars columns are kept for this test.
  dat <- dplyr::as_tibble(mtcars)

  # NOTE(review): `num_comp` is left at the step's default here - confirm that
  # this really yields more than 9 components, as the test name promises.
  # Warnings raised while building and prepping the recipe are muffled.
  rec <- suppressWarnings({
    spec <- recipes::recipe(mpg ~ ., data = dat)
    spec <- embed::step_pca_sparse(spec, recipes::all_predictors())
    recipes::prep(spec)
  })

  # Reference result: what the prepped recipe itself produces.
  expected <- recipes::bake(rec, new_data = dat)

  # Result under test: the orbital equations applied with plain dplyr, then
  # reduced to the baked columns in the baked order.
  eqs <- orbital_inline(orbital(rec))
  actual <- dplyr::mutate(dat, !!!eqs)
  actual <- actual[names(expected)]

  expect_equal(actual, expected)
})

# Checks that the orbital equations for step_pca_sparse give the same result
# when evaluated on a Spark table as when evaluated on a local tibble.
test_that("spark - step_pca_sparse works", {
skip_if_not_installed("recipes")
skip_if_not_installed("embed")
skip_if_not_installed("sparklyr")
# Skipped when the helper reports no Spark version (NA).
skip_if(is.na(testthat_spark_env_version()))

# mtcars as a tibble, without the `hp` column.
mtcars0 <- dplyr::as_tibble(mtcars)
mtcars0$hp <- NULL

# Warnings raised while building and prepping the recipe are muffled.
suppressWarnings(
rec <- recipes::recipe(mpg ~ ., data = mtcars0) %>%
embed::step_pca_sparse(recipes::all_predictors()) %>%
recipes::prep()
)

# Expected value: the orbital equations applied locally with dplyr (not the
# output of `recipes::bake()`).
exp <- dplyr::mutate(mtcars0, !!!orbital_inline(orbital(rec)))

# `sc` is not referenced again in this test.
# NOTE(review): presumably the call is kept to make sure a connection exists - confirm.
sc <- testthat_spark_connection()
# NOTE(review): the helper receives the name "mtcars0" as a string; presumably
# it resolves the local `mtcars0` by that name - confirm before renaming it.
mtcars_tbl <- testthat_tbl("mtcars0")

# Same equations, evaluated through the Spark table and collected locally.
res_spark <- dplyr::mutate(mtcars_tbl, !!!orbital_inline(orbital(rec))) %>%
dplyr::collect()

expect_equal(res_spark, exp)
})
69 changes: 69 additions & 0 deletions tests/testthat/test-step_pca_sparse_bayes.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
test_that("step_pca_sparse_bayes works", {
  skip_if_not_installed("recipes")
  skip_if_not_installed("embed")
  skip_if_not_installed("VBsparsePCA")

  # mtcars as a tibble, without the `hp` column.
  dat <- dplyr::as_tibble(mtcars)
  dat[["hp"]] <- NULL

  # Warnings raised while building and prepping the recipe are muffled.
  rec <- suppressWarnings({
    spec <- recipes::recipe(mpg ~ ., data = dat)
    spec <- embed::step_pca_sparse_bayes(spec, recipes::all_predictors())
    recipes::prep(spec)
  })

  # Reference result: what the prepped recipe itself produces.
  expected <- recipes::bake(rec, new_data = dat)

  # Result under test: the orbital equations applied with plain dplyr, then
  # reduced to the baked columns in the baked order.
  eqs <- orbital_inline(orbital(rec))
  actual <- dplyr::mutate(dat, !!!eqs)
  actual <- actual[names(expected)]

  expect_equal(actual, expected)
})

test_that("step_pca_sparse_bayes works with more than 9 PCs", {
  skip_if_not_installed("recipes")
  skip_if_not_installed("embed")
  skip_if_not_installed("VBsparsePCA")

  # All mtcars columns are kept for this test.
  dat <- dplyr::as_tibble(mtcars)

  # NOTE(review): `num_comp` is left at the step's default here - confirm that
  # this really yields more than 9 components, as the test name promises.
  # Warnings raised while building and prepping the recipe are muffled.
  rec <- suppressWarnings({
    spec <- recipes::recipe(mpg ~ ., data = dat)
    spec <- embed::step_pca_sparse_bayes(spec, recipes::all_predictors())
    recipes::prep(spec)
  })

  # Reference result: what the prepped recipe itself produces.
  expected <- recipes::bake(rec, new_data = dat)

  # Result under test: the orbital equations applied with plain dplyr, then
  # reduced to the baked columns in the baked order.
  eqs <- orbital_inline(orbital(rec))
  actual <- dplyr::mutate(dat, !!!eqs)
  actual <- actual[names(expected)]

  expect_equal(actual, expected)
})

# Checks that the orbital equations for step_pca_sparse_bayes give the same
# result when evaluated on a Spark table as when evaluated on a local tibble.
test_that("spark - step_pca_sparse_bayes works", {
skip_if_not_installed("recipes")
skip_if_not_installed("sparklyr")
skip_if_not_installed("embed")
skip_if_not_installed("VBsparsePCA")
# Skipped when the helper reports no Spark version (NA).
skip_if(is.na(testthat_spark_env_version()))

# mtcars as a tibble, without the `hp` column.
mtcars0 <- dplyr::as_tibble(mtcars)
mtcars0$hp <- NULL

# Warnings raised while building and prepping the recipe are muffled.
suppressWarnings(
rec <- recipes::recipe(mpg ~ ., data = mtcars0) %>%
embed::step_pca_sparse_bayes(recipes::all_predictors()) %>%
recipes::prep()
)

# Expected value: the orbital equations applied locally with dplyr (not the
# output of `recipes::bake()`).
exp <- dplyr::mutate(mtcars0, !!!orbital_inline(orbital(rec)))

# `sc` is not referenced again in this test.
# NOTE(review): presumably the call is kept to make sure a connection exists - confirm.
sc <- testthat_spark_connection()
# NOTE(review): the helper receives the name "mtcars0" as a string; presumably
# it resolves the local `mtcars0` by that name - confirm before renaming it.
mtcars_tbl <- testthat_tbl("mtcars0")

# Same equations, evaluated through the Spark table and collected locally.
res_spark <- dplyr::mutate(mtcars_tbl, !!!orbital_inline(orbital(rec))) %>%
dplyr::collect()

expect_equal(res_spark, exp)
})
66 changes: 66 additions & 0 deletions tests/testthat/test-step_pca_truncated.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
test_that("step_pca_truncated works", {
  skip_if_not_installed("recipes")
  skip_if_not_installed("embed")

  # mtcars as a tibble, without the `hp` column.
  dat <- dplyr::as_tibble(mtcars)
  dat[["hp"]] <- NULL

  # Warnings raised while building and prepping the recipe are muffled.
  rec <- suppressWarnings({
    spec <- recipes::recipe(mpg ~ ., data = dat)
    spec <- embed::step_pca_truncated(spec, recipes::all_predictors())
    recipes::prep(spec)
  })

  # Reference result: what the prepped recipe itself produces.
  expected <- recipes::bake(rec, new_data = dat)

  # Result under test: the orbital equations applied with plain dplyr, then
  # reduced to the baked columns in the baked order.
  eqs <- orbital_inline(orbital(rec))
  actual <- dplyr::mutate(dat, !!!eqs)
  actual <- actual[names(expected)]

  expect_equal(actual, expected)
})

test_that("step_pca_truncated works with more than 9 PCs", {
  skip_if_not_installed("recipes")
  skip_if_not_installed("embed")

  # All mtcars columns are kept for this test.
  dat <- dplyr::as_tibble(mtcars)

  # NOTE(review): `num_comp` is left at the step's default here - confirm that
  # this really yields more than 9 components, as the test name promises.
  # Warnings raised while building and prepping the recipe are muffled.
  rec <- suppressWarnings({
    spec <- recipes::recipe(mpg ~ ., data = dat)
    spec <- embed::step_pca_truncated(spec, recipes::all_predictors())
    recipes::prep(spec)
  })

  # Reference result: what the prepped recipe itself produces.
  expected <- recipes::bake(rec, new_data = dat)

  # Result under test: the orbital equations applied with plain dplyr, then
  # reduced to the baked columns in the baked order.
  eqs <- orbital_inline(orbital(rec))
  actual <- dplyr::mutate(dat, !!!eqs)
  actual <- actual[names(expected)]

  expect_equal(actual, expected)
})

# Checks that the orbital equations for step_pca_truncated give the same
# result when evaluated on a Spark table as when evaluated on a local tibble.
test_that("spark - step_pca_truncated works", {
skip_if_not_installed("recipes")
skip_if_not_installed("embed")
skip_if_not_installed("sparklyr")
# Skipped when the helper reports no Spark version (NA).
skip_if(is.na(testthat_spark_env_version()))

# mtcars as a tibble, without the `hp` column.
mtcars0 <- dplyr::as_tibble(mtcars)
mtcars0$hp <- NULL

# Warnings raised while building and prepping the recipe are muffled.
suppressWarnings(
rec <- recipes::recipe(mpg ~ ., data = mtcars0) %>%
embed::step_pca_truncated(recipes::all_predictors()) %>%
recipes::prep()
)

# Expected value: the orbital equations applied locally with dplyr (not the
# output of `recipes::bake()`).
exp <- dplyr::mutate(mtcars0, !!!orbital_inline(orbital(rec)))

# `sc` is not referenced again in this test.
# NOTE(review): presumably the call is kept to make sure a connection exists - confirm.
sc <- testthat_spark_connection()
# NOTE(review): the helper receives the name "mtcars0" as a string; presumably
# it resolves the local `mtcars0` by that name - confirm before renaming it.
mtcars_tbl <- testthat_tbl("mtcars0")

# Same equations, evaluated through the Spark table and collected locally.
res_spark <- dplyr::mutate(mtcars_tbl, !!!orbital_inline(orbital(rec))) %>%
dplyr::collect()

expect_equal(res_spark, exp)
})

0 comments on commit c65357f

Please sign in to comment.