Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically supply inverses of known functions #78

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 85 additions & 4 deletions R/default.R
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,85 @@ dim.dist_default <- function(x){

invert_fail <- function(...) stop("Inverting transformations for distributions is not yet supported.")

#' Attempt to get the inverse of f(x) by name. Returns invert_fail
#' (a function that raises an error if called) if there is no known inverse.
#' @param f string. Name of a function.
#' @noRd
get_unary_inverse <- function(f) {
switch(f,
sqrt = function(x) x^2,
exp = log,
log = function(x, base = exp(1)) base ^ x,
log2 = function(x) 2^x,
log10 = function(x) 10^x,
expm1 = log1p,
log1p = expm1,
cos = acos,
sin = asin,
tan = atan,
acos = cos,
asin = sin,
atan = tan,
cosh = acosh,
sinh = asinh,
tanh = atanh,
acosh = cosh,
asinh = sinh,
atanh = tanh,

invert_fail
)
}

#' Attempt to get the inverse of f(x, constant) by name. Returns invert_fail
#' (a function that raises an error if called) if there is no known inverse.
#' @param f string. Name of a function.
#' @param constant a constant value
#' @noRd
get_binary_inverse_1 <- function(f, constant) {
force(constant)

switch(f,
`+` = function(x) x - constant,
`-` = function(x) x + constant,
`*` = function(x) x / constant,
`/` = function(x) x * constant,
`^` = function(x) x ^ (1/constant),

invert_fail
)
}

#' Attempt to get the inverse of f(constant, x) by name. Returns invert_fail
#' (a function that raises an error if called) if there is no known inverse.
#' @param f string. Name of a function.
#' @param constant a constant value
#' @noRd
get_binary_inverse_2 <- function(f, constant) {
force(constant)

switch(f,
`+` = function(x) x - constant,
`-` = function(x) constant - x,
`*` = function(x) x / constant,
`/` = function(x) constant / x,
`^` = function(x) log(x, base = constant),

invert_fail
)
}

#' @method Math dist_default
#' @export
Math.dist_default <- function(x, ...) {
if(dim(x) > 1) stop("Transformations of multivariate distributions are not yet supported.")

trans <- new_function(exprs(x = ), body = expr((!!sym(.Generic))(x, !!!dots_list(...))))
vec_data(dist_transformed(wrap_dist(list(x)), trans, invert_fail))[[1]]

inverse_fun <- get_unary_inverse(.Generic)
inverse <- new_function(exprs(x = ), body = expr((!!inverse_fun)(x, !!!dots_list(...))))

vec_data(dist_transformed(wrap_dist(list(x)), trans, inverse))[[1]]
}

#' @method Ops dist_default
Expand All @@ -182,10 +255,18 @@ Ops.dist_default <- function(e1, e2) {
stop(sprintf("The %s operation is not supported for <%s> and <%s>", .Generic, class(e1)[1], class(e2)[1]))
}
} else if(is_dist[1]){
new_function(exprs(x = ), body = expr((!!sym(.Generic))((!!e1$transform)(x), !!e2)))
new_function(exprs(x = ), body = expr((!!sym(.Generic))(x, !!e2)))
} else {
new_function(exprs(x = ), body = expr((!!sym(.Generic))(!!e1, x)))
}

inverse <- if(all(is_dist)) {
invert_fail
} else if(is_dist[1]){
get_binary_inverse_1(.Generic, e2)
} else {
new_function(exprs(x = ), body = expr((!!sym(.Generic))(!!e1, (!!e2$transform)(x))))
get_binary_inverse_2(.Generic, e1)
}

vec_data(dist_transformed(wrap_dist(list(e1,e2)[which(is_dist)]), trans, invert_fail))[[1]]
vec_data(dist_transformed(wrap_dist(list(e1,e2)[which(is_dist)]), trans, inverse))[[1]]
}
17 changes: 14 additions & 3 deletions R/dist_lognormal.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,21 @@ kurtosis.dist_lognormal <- function(x, ...) {
exp(4*s2) + 2*exp(3*s2) + 3*exp(2*s2) - 6
}

# make a normal distribution from a lognormal distribution using the
# specified base
normal_dist_with_base <- function(x, base = exp(1)) {
vec_data(dist_normal(x[["mu"]], x[["sigma"]]) / log(base))[[1]]
}

#' @method Math dist_lognormal
#' @export
Math.dist_lognormal <- function(x, ...) {
# Shortcut to get Normal distribution from log-normal.
if(.Generic == "log") return(vec_data(dist_normal(x[["mu"]], x[["sigma"]]))[[1]])
NextMethod()
switch(.Generic,
# Shortcuts to get Normal distribution from log-normal.
log = normal_dist_with_base(x, ...),
log2 = normal_dist_with_base(x, 2),
log10 = normal_dist_with_base(x, 10),

NextMethod()
)
}
1 change: 0 additions & 1 deletion R/hilo.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ vec_math.hilo <- function(.fn, .x, ...){
vec_restore(out, .x)
}

#' @rdname vctrs-compat
#' @method vec_arith hilo
#' @export
vec_arith.hilo <- function(op, x, y, ...){
Expand Down
2 changes: 1 addition & 1 deletion R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#' @param x The distribution(s) to plot.
#' @param ... Unused.
#'
#' @keyword internal
#' @keywords internal
#'
#' @export
autoplot.distribution <- function(x, ...){
Expand Down
21 changes: 18 additions & 3 deletions R/transformed.R
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ format.dist_transformed <- function(x, ...){

#' @export
density.dist_transformed <- function(x, at, ...){
density(x[["dist"]], x[["inverse"]](at))*vapply(at, numDeriv::jacobian, numeric(1L), func = x[["inverse"]])
density(x[["dist"]], x[["inverse"]](at))*abs(vapply(at, numDeriv::jacobian, numeric(1L), func = x[["inverse"]]))
}

#' @export
Expand Down Expand Up @@ -83,7 +83,11 @@ covariance.dist_transformed <- function(x, ...){
#' @export
Math.dist_transformed <- function(x, ...) {
trans <- new_function(exprs(x = ), body = expr((!!sym(.Generic))((!!x$transform)(x), !!!dots_list(...))))
vec_data(dist_transformed(wrap_dist(list(x[["dist"]])), trans, invert_fail))[[1]]

inverse_fun <- get_unary_inverse(.Generic)
inverse <- new_function(exprs(x = ), body = expr((!!x$inverse)((!!inverse_fun)(x, !!!dots_list(...)))))

vec_data(dist_transformed(wrap_dist(list(x[["dist"]])), trans, inverse))[[1]]
}

#' @method Ops dist_transformed
Expand All @@ -101,5 +105,16 @@ Ops.dist_transformed <- function(e1, e2) {
} else {
new_function(exprs(x = ), body = expr((!!sym(.Generic))(!!e1, (!!e2$transform)(x))))
}
vec_data(dist_transformed(wrap_dist(list(list(e1,e2)[[which(is_dist)[1]]][["dist"]])), trans, invert_fail))[[1]]

inverse <- if(all(is_dist)) {
invert_fail
} else if(is_dist[1]){
inverse_fun <- get_binary_inverse_1(.Generic, e2)
new_function(exprs(x = ), body = expr((!!e1$inverse)((!!inverse_fun)(x))))
} else {
inverse_fun <- get_binary_inverse_2(.Generic, e1)
new_function(exprs(x = ), body = expr((!!e2$inverse)((!!inverse_fun)(x))))
}

vec_data(dist_transformed(wrap_dist(list(list(e1,e2)[[which(is_dist)[1]]][["dist"]])), trans, inverse))[[1]]
}
1 change: 1 addition & 0 deletions man/autoplot.distribution.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 47 additions & 0 deletions tests/testthat/test-transformations.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ test_that("LogNormal distributions", {
dist_normal(0, 0.5)
)

# Test log() shortcut with different bases
expect_equal(log(dist_lognormal(0, log(3)), base = 3), dist_normal(0, 1))
expect_equal(log2(dist_lognormal(0, log(2))), dist_normal(0, 1))
expect_equal(log10(dist_lognormal(0, log(10))), dist_normal(0, 1))

# format
expect_equal(format(dist), sprintf("t(%s)", format(dist_normal(0, 0.5))))

Expand Down Expand Up @@ -71,3 +76,45 @@ test_that("LogNormal distributions", {
expect_equal(mean(dist), exp(0.25/2), tolerance = 0.01)
expect_equal(variance(dist), (exp(0.25) - 1)*exp(0.25), tolerance = 0.1)
})

test_that("inverses are applied automatically", {
dist <- dist_gamma(1,1)
log2dist <- log(dist, base = 2)
log2dist_t <- dist_transformed(dist, log2, function(x) 2 ^ x)

expect_equal(density(log2dist, 0.5), density(log2dist_t, 0.5))
expect_equal(cdf(log2dist, 0.5), cdf(log2dist_t, 0.5))
expect_equal(quantile(log2dist, 0.5), quantile(log2dist_t, 0.5))

# test multiple transformations that get stacked together by dist_transformed
explogdist <- exp(log(dist))
expect_equal(density(dist, 0.5), density(explogdist, 0.5))
expect_equal(cdf(dist, 0.5), cdf(explogdist, 0.5))
expect_equal(quantile(dist, 0.5), quantile(explogdist, 0.5))

# test multiple transformations created by operators (via Ops)
explog2dist <- 2 ^ log2dist
expect_equal(density(dist, 0.5), density(explog2dist, 0.5))
expect_equal(cdf(dist, 0.5), cdf(explog2dist, 0.5))
expect_equal(quantile(dist, 0.5), quantile(explog2dist, 0.5))

# basic set of inverses
expect_equal(density(sqrt(dist^2), 0.5), density(dist, 0.5))
expect_equal(density(exp(log(dist)), 0.5), density(dist, 0.5))
expect_equal(density(10^(log10(dist)), 0.5), density(dist, 0.5))
expect_equal(density(expm1(log1p(dist)), 0.5), density(dist, 0.5))
expect_equal(density(cos(acos(dist)), 0.5), density(dist, 0.5))
expect_equal(density(sin(asin(dist)), 0.5), density(dist, 0.5))
expect_equal(density(tan(atan(dist)), 0.5), density(dist, 0.5))
expect_equal(density(cosh(acosh(dist + 1)) - 1, 0.5), density(dist, 0.5))
expect_equal(density(sinh(asinh(dist)), 0.5), density(dist, 0.5))
expect_equal(density(tanh(atanh(dist)), 0.5), density(dist, 0.5))

expect_equal(density(dist + 1 - 1, 0.5), density(dist, 0.5))
expect_equal(density(dist * 2 / 2, 0.5), density(dist, 0.5))

# inverting a gamma distribution
expect_equal(density(1/dist_gamma(4, 3), 0.5), density(dist_inverse_gamma(4, 1/3), 0.5))
expect_equal(density(1/(1/dist_gamma(4, 3)), 0.5), density(dist_gamma(4, 3), 0.5))

})