-
Notifications
You must be signed in to change notification settings - Fork 88
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
199 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
#' Dummify discrete features to binary columns
#'
#' Data dummification is also known as one hot encoding or feature
#' binarization. It turns each category into a distinct column with binary
#' (numeric) values.
#' @param data input data, in either \link{data.frame} or \link{data.table}
#'   format.
#' @param maxcat maximum categories allowed for each discrete feature. The
#'   default is 50. Discrete features with more categories are not dummified
#'   and are kept as they are.
#' @keywords dummify
#' @details A message lists the features that were skipped because of
#'   \code{maxcat}. A warning is raised when every discrete feature is skipped
#'   or when the input has no discrete feature at all; in both cases nothing
#'   is dummified.
#' @note This is different from \link{model.matrix}, where the latter aims to
#'   create a full rank matrix for regression-like use cases. If your
#'   intention is to create a design matrix, use \link{model.matrix} instead.
#'
#'   A helper column named \code{discrete_id} is added to the discrete
#'   features while pivoting and dropped afterwards, so a discrete input
#'   feature with that name would be overwritten.
#' @return dummified dataset (discrete features only) preserving original
#'   features. However, column order might be different. The returned object
#'   carries the class of the input.
#' @import data.table
#' @import reshape2
#' @export
#' @examples
#' ## Dummify iris dataset
#' str(dummify(iris))
#'
#' ## Dummify diamonds dataset ignoring features with more than 5 categories
#' data("diamonds", package = "ggplot2")
#' str(dummify(diamonds, maxcat = 5))

dummify <- function(data, maxcat = 50L) {
  ## Declare variable first to pass R CMD check
  discrete_id <- NULL
  ## Remember the input class so it can be restored before returning
  is_data_table <- is.data.table(data)
  data_class <- class(data)
  ## Set data to data.table
  if (!is_data_table) {
    data <- data.table(data)
  }
  ## Split data
  split_data <- split_columns(data)
  continuous <- split_data$continuous

  if (split_data$num_discrete > 0) {
    discrete <- split_data$discrete
    ## Features exceeding `maxcat` (named vector of category counts)
    ind <- .ignoreCat(discrete, maxcat)
    n_true_discrete <- split_data$num_discrete - length(ind)
    if (n_true_discrete <= 0) {
      ## Every discrete feature exceeds `maxcat`: nothing left to dummify
      warning(
        "Ignored all discrete features since `maxcat` set to ", maxcat,
        " categories!"
      )
      final_data <- data
    } else {
      if (length(ind) > 0) {
        message(
          length(ind), " features with more than ", maxcat,
          " categories ignored!\n",
          paste0(names(ind), ": ", as.numeric(ind), " categories\n")
        )
      }
      ## Row id: pivot key for every feature and merge key for the pieces
      discrete[, discrete_id := .I]
      row_id <- discrete[["discrete_id"]]
      dummy_features <- names(discrete)[
        !(names(discrete) %in% c("discrete_id", names(ind)))
      ]
      ## Build the "<feature>_<category>" label as a real column before
      ## casting. Evaluating it inside the formula would look names up among
      ## the data columns first, so a feature called e.g. `x` would shadow
      ## the closure argument.
      dummy_tables <- lapply(dummy_features, function(feature) {
        long <- data.table(
          discrete_id = row_id,
          dummy_name = make.names(paste0(feature, "_", discrete[[feature]]))
        )
        dcast.data.table(
          long,
          discrete_id ~ dummy_name,
          length,
          value.var = "discrete_id"
        )
      })
      ## Ignored features pass through untouched next to the dummy columns
      discrete_pivot <- Reduce(
        function(x, y) {
          merge(x, y, by = "discrete_id")
        },
        c(
          list(discrete[, c("discrete_id", names(ind)), with = FALSE]),
          dummy_tables
        )
      )
      drop_columns(discrete_pivot, "discrete_id")
      if (split_data$num_continuous == 0) {
        final_data <- discrete_pivot
      } else {
        final_data <- cbind(continuous, discrete_pivot)
      }
    }
  } else {
    warning("No discrete features found! Nothing is dummified!")
    final_data <- continuous
  }

  ## Set data class back to original
  if (!is_data_table) {
    class(final_data) <- data_class
  }
  ## Set return object
  return(final_data)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#' Truncate category
#'
#' Find the features that have more categories than allowed, so that callers
#' can ignore them.
#' @param dt input data object. Anything that is not a
#'   \link{data.table} is converted with \code{data.table()} first.
#' @param maxcat maximum categories allowed for each discrete feature.
#' @return a named vector with one element per feature to be ignored: the
#'   name is the feature name and the value is its number of unique values.
#'   The vector is empty when no feature exceeds \code{maxcat}.
#' @import data.table
.ignoreCat <- function(dt, maxcat) {
  if (!is.data.table(dt)) {
    dt <- data.table(dt)
  }
  ## vapply keeps the result an integer vector even when `dt` has no columns
  n_cat <- vapply(dt, function(x) length(unique(x)), integer(1L))
  n_cat[which(n_cat > maxcat)]
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
context("dummify")
data("diamonds", package = "ggplot2")

## The returned object must carry exactly the class vector of its input.
test_that("test return object class", {
  diamonds_out <- dummify(diamonds, maxcat = 5)
  expect_equal(class(diamonds), class(diamonds_out))

  iris_out <- dummify(iris)
  expect_equal(class(iris), class(iris_out))

  single_feature <- data.table("D" = letters[1:5])
  expect_is(dummify(single_feature), "data.table")
})

## A message is expected when only some features are skipped; a warning when
## nothing can be dummified.
test_that("test messages and warnings", {
  expect_message(dummify(diamonds, maxcat = 5))
  expect_warning(dummify(iris, maxcat = 2))
  expect_warning(dummify(airquality))
})

## Column counts after dummification.
test_that("test feature count", {
  n_all <- ncol(dummify(diamonds))
  expect_equal(n_all, 27L)

  n_capped <- ncol(dummify(diamonds, maxcat = 5))
  expect_equal(n_capped, 14L)

  one_feature <- data.table("A" = letters[1:5])
  expect_equal(ncol(dummify(one_feature)), 5L)

  two_features <- data.table("A" = letters[1:5], "B" = letters[6:10])
  expect_equal(ncol(dummify(two_features)), 10L)
})

## Dummy columns hold only 0 and 1.
test_that("test binary outcome", {
  dummies <- dummify(data.table("A" = letters[1:5]))
  expect_equal(max(dummies), 1L)
  expect_equal(min(dummies), 0L)
})

## The first seven continuous columns of the output are compared against the
## continuous columns of the input.
test_that("test continuous features", {
  input_continuous <- split_columns(diamonds)$continuous
  output_continuous <- split_columns(dummify(diamonds))$continuous
  expect_equivalent(input_continuous, output_continuous[, 1:7])
})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
context("helper")

test_that(".ignoreCat", {
  ## Columns are drawn in the order a, b, c, d so the seeded random stream is
  ## consumed in a fixed sequence.
  set.seed(1)
  col_a <- as.factor(rep(1, 10))
  col_b <- as.factor(sample.int(2, 10, replace = TRUE))
  col_c <- as.factor(sample.int(5, 10, replace = TRUE))
  col_d <- as.factor(sample.int(10))
  dt <- data.table("a" = col_a, "b" = col_b, "c" = col_c, "d" = col_d)

  ## One entry per `maxcat` threshold: the category counts and the feature
  ## names expected to be reported as ignored.
  thresholds <- c(0, 1, 2, 5)
  expected_counts <- list(
    c(1L, 2L, 5L, 10L),
    c(2L, 5L, 10L),
    c(5L, 10L),
    10L
  )
  expected_names <- list(
    letters[1L:4L],
    letters[2L:4L],
    letters[3L:4L],
    letters[4L]
  )

  for (i in seq_along(thresholds)) {
    expect_equal(
      as.numeric(.ignoreCat(dt, thresholds[i])),
      expected_counts[[i]]
    )
    expect_equal(
      names(.ignoreCat(dt, thresholds[i])),
      expected_names[[i]]
    )
  }
})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,13 @@ | ||
context("plot correlation heatmap")
## Loaded once at file scope; every test below reads `diamonds` from here.
data(diamonds, package = "ggplot2")

## Capping `maxcat` is expected to emit a message; the default is expected to
## run silently.
test_that("test maximum categories for discrete features", {
  expect_message(plot_correlation(diamonds, type = "d", maxcat = 5))
  expect_silent(plot_correlation(diamonds, type = "d"))
})

## Each call gets a fresh split of `diamonds` so no input is shared between
## expectations.
test_that("test error messages", {
  ## Discrete correlation requested on continuous-only data
  expect_error(
    plot_correlation(split_columns(diamonds)$continuous, type = "d")
  )
  ## Continuous correlation requested on discrete-only data
  expect_error(
    plot_correlation(split_columns(diamonds)$discrete, type = "c")
  )
  ## Discrete-only data with a low `maxcat`; warnings are silenced so only
  ## the error is asserted
  expect_error(
    suppressWarnings(
      plot_correlation(split_columns(diamonds)$discrete, maxcat = 2)
    )
  )
})