Skip to content

Commit

Permalink
Merge pull request #78 from mjskay/trans-inverse-group
Browse files Browse the repository at this point in the history
Automatically supply inverses of known functions
  • Loading branch information
mitchelloharawild authored Jan 4, 2022
2 parents 0640bcf + cf3a2e5 commit 0476938
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 12 deletions.
89 changes: 85 additions & 4 deletions R/default.R
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,85 @@ dim.dist_default <- function(x){

invert_fail <- function(...) stop("Inverting transformations for distributions is not yet supported.")

#' Attempt to get the inverse of f(x) by name. Returns invert_fail
#' (a function that raises an error if called) if there is no known inverse.
#' @param f string. Name of a function.
#' @noRd
get_unary_inverse <- function(f) {
switch(f,
sqrt = function(x) x^2,
exp = log,
log = function(x, base = exp(1)) base ^ x,
log2 = function(x) 2^x,
log10 = function(x) 10^x,
expm1 = log1p,
log1p = expm1,
cos = acos,
sin = asin,
tan = atan,
acos = cos,
asin = sin,
atan = tan,
cosh = acosh,
sinh = asinh,
tanh = atanh,
acosh = cosh,
asinh = sinh,
atanh = tanh,

invert_fail
)
}

#' Attempt to get the inverse of f(x, constant) by name. Returns invert_fail
#' (a function that raises an error if called) if there is no known inverse.
#' @param f string. Name of a function.
#' @param constant a constant value
#' @noRd
get_binary_inverse_1 <- function(f, constant) {
force(constant)

switch(f,
`+` = function(x) x - constant,
`-` = function(x) x + constant,
`*` = function(x) x / constant,
`/` = function(x) x * constant,
`^` = function(x) x ^ (1/constant),

invert_fail
)
}

#' Attempt to get the inverse of f(constant, x) by name. Returns invert_fail
#' (a function that raises an error if called) if there is no known inverse.
#' @param f string. Name of a function.
#' @param constant a constant value
#' @noRd
get_binary_inverse_2 <- function(f, constant) {
force(constant)

switch(f,
`+` = function(x) x - constant,
`-` = function(x) constant - x,
`*` = function(x) x / constant,
`/` = function(x) constant / x,
`^` = function(x) log(x, base = constant),

invert_fail
)
}

#' @method Math dist_default
#' @export
Math.dist_default <- function(x, ...) {
if(dim(x) > 1) stop("Transformations of multivariate distributions are not yet supported.")

trans <- new_function(exprs(x = ), body = expr((!!sym(.Generic))(x, !!!dots_list(...))))
vec_data(dist_transformed(wrap_dist(list(x)), trans, invert_fail))[[1]]

inverse_fun <- get_unary_inverse(.Generic)
inverse <- new_function(exprs(x = ), body = expr((!!inverse_fun)(x, !!!dots_list(...))))

vec_data(dist_transformed(wrap_dist(list(x)), trans, inverse))[[1]]
}

#' @method Ops dist_default
Expand All @@ -205,10 +278,18 @@ Ops.dist_default <- function(e1, e2) {
stop(sprintf("The %s operation is not supported for <%s> and <%s>", .Generic, class(e1)[1], class(e2)[1]))
}
} else if(is_dist[1]){
new_function(exprs(x = ), body = expr((!!sym(.Generic))((!!e1$transform)(x), !!e2)))
new_function(exprs(x = ), body = expr((!!sym(.Generic))(x, !!e2)))
} else {
new_function(exprs(x = ), body = expr((!!sym(.Generic))(!!e1, x)))
}

inverse <- if(all(is_dist)) {
invert_fail
} else if(is_dist[1]){
get_binary_inverse_1(.Generic, e2)
} else {
new_function(exprs(x = ), body = expr((!!sym(.Generic))(!!e1, (!!e2$transform)(x))))
get_binary_inverse_2(.Generic, e1)
}

vec_data(dist_transformed(wrap_dist(list(e1,e2)[which(is_dist)]), trans, invert_fail))[[1]]
vec_data(dist_transformed(wrap_dist(list(e1,e2)[which(is_dist)]), trans, inverse))[[1]]
}
17 changes: 14 additions & 3 deletions R/dist_lognormal.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,21 @@ kurtosis.dist_lognormal <- function(x, ...) {
exp(4*s2) + 2*exp(3*s2) + 3*exp(2*s2) - 6
}

# make a normal distribution from a lognormal distribution using the
# specified base
normal_dist_with_base <- function(x, base = exp(1)) {
vec_data(dist_normal(x[["mu"]], x[["sigma"]]) / log(base))[[1]]
}

#' @method Math dist_lognormal
#' @export
Math.dist_lognormal <- function(x, ...) {
# Shortcut to get Normal distribution from log-normal.
if(.Generic == "log") return(vec_data(dist_normal(x[["mu"]], x[["sigma"]]))[[1]])
NextMethod()
switch(.Generic,
# Shortcuts to get Normal distribution from log-normal.
log = normal_dist_with_base(x, ...),
log2 = normal_dist_with_base(x, 2),
log10 = normal_dist_with_base(x, 10),

NextMethod()
)
}
1 change: 0 additions & 1 deletion R/hilo.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ vec_math.hilo <- function(.fn, .x, ...){
vec_restore(out, .x)
}

#' @rdname vctrs-compat
#' @method vec_arith hilo
#' @export
vec_arith.hilo <- function(op, x, y, ...){
Expand Down
2 changes: 1 addition & 1 deletion R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#' @param x The distribution(s) to plot.
#' @param ... Unused.
#'
#' @keyword internal
#' @keywords internal
#'
#' @export
autoplot.distribution <- function(x, ...){
Expand Down
21 changes: 18 additions & 3 deletions R/transformed.R
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ format.dist_transformed <- function(x, ...){

#' @export
density.dist_transformed <- function(x, at, ...){
density(x[["dist"]], x[["inverse"]](at))*vapply(at, numDeriv::jacobian, numeric(1L), func = x[["inverse"]])
density(x[["dist"]], x[["inverse"]](at))*abs(vapply(at, numDeriv::jacobian, numeric(1L), func = x[["inverse"]]))
}

#' @export
Expand Down Expand Up @@ -83,7 +83,11 @@ covariance.dist_transformed <- function(x, ...){
#' @export
Math.dist_transformed <- function(x, ...) {
trans <- new_function(exprs(x = ), body = expr((!!sym(.Generic))((!!x$transform)(x), !!!dots_list(...))))
vec_data(dist_transformed(wrap_dist(list(x[["dist"]])), trans, invert_fail))[[1]]

inverse_fun <- get_unary_inverse(.Generic)
inverse <- new_function(exprs(x = ), body = expr((!!x$inverse)((!!inverse_fun)(x, !!!dots_list(...)))))

vec_data(dist_transformed(wrap_dist(list(x[["dist"]])), trans, inverse))[[1]]
}

#' @method Ops dist_transformed
Expand All @@ -101,5 +105,16 @@ Ops.dist_transformed <- function(e1, e2) {
} else {
new_function(exprs(x = ), body = expr((!!sym(.Generic))(!!e1, (!!e2$transform)(x))))
}
vec_data(dist_transformed(wrap_dist(list(list(e1,e2)[[which(is_dist)[1]]][["dist"]])), trans, invert_fail))[[1]]

inverse <- if(all(is_dist)) {
invert_fail
} else if(is_dist[1]){
inverse_fun <- get_binary_inverse_1(.Generic, e2)
new_function(exprs(x = ), body = expr((!!e1$inverse)((!!inverse_fun)(x))))
} else {
inverse_fun <- get_binary_inverse_2(.Generic, e1)
new_function(exprs(x = ), body = expr((!!e2$inverse)((!!inverse_fun)(x))))
}

vec_data(dist_transformed(wrap_dist(list(list(e1,e2)[[which(is_dist)[1]]][["dist"]])), trans, inverse))[[1]]
}
1 change: 1 addition & 0 deletions man/autoplot.distribution.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 47 additions & 0 deletions tests/testthat/test-transformations.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ test_that("LogNormal distributions", {
dist_normal(0, 0.5)
)

# Test log() shortcut with different bases
expect_equal(log(dist_lognormal(0, log(3)), base = 3), dist_normal(0, 1))
expect_equal(log2(dist_lognormal(0, log(2))), dist_normal(0, 1))
expect_equal(log10(dist_lognormal(0, log(10))), dist_normal(0, 1))

# format
expect_equal(format(dist), sprintf("t(%s)", format(dist_normal(0, 0.5))))

Expand Down Expand Up @@ -71,3 +76,45 @@ test_that("LogNormal distributions", {
expect_equal(mean(dist), exp(0.25/2), tolerance = 0.01)
expect_equal(variance(dist), (exp(0.25) - 1)*exp(0.25), tolerance = 0.1)
})

test_that("inverses are applied automatically", {
dist <- dist_gamma(1,1)
log2dist <- log(dist, base = 2)
log2dist_t <- dist_transformed(dist, log2, function(x) 2 ^ x)

expect_equal(density(log2dist, 0.5), density(log2dist_t, 0.5))
expect_equal(cdf(log2dist, 0.5), cdf(log2dist_t, 0.5))
expect_equal(quantile(log2dist, 0.5), quantile(log2dist_t, 0.5))

# test multiple transformations that get stacked together by dist_transformed
explogdist <- exp(log(dist))
expect_equal(density(dist, 0.5), density(explogdist, 0.5))
expect_equal(cdf(dist, 0.5), cdf(explogdist, 0.5))
expect_equal(quantile(dist, 0.5), quantile(explogdist, 0.5))

# test multiple transformations created by operators (via Ops)
explog2dist <- 2 ^ log2dist
expect_equal(density(dist, 0.5), density(explog2dist, 0.5))
expect_equal(cdf(dist, 0.5), cdf(explog2dist, 0.5))
expect_equal(quantile(dist, 0.5), quantile(explog2dist, 0.5))

# basic set of inverses
expect_equal(density(sqrt(dist^2), 0.5), density(dist, 0.5))
expect_equal(density(exp(log(dist)), 0.5), density(dist, 0.5))
expect_equal(density(10^(log10(dist)), 0.5), density(dist, 0.5))
expect_equal(density(expm1(log1p(dist)), 0.5), density(dist, 0.5))
expect_equal(density(cos(acos(dist)), 0.5), density(dist, 0.5))
expect_equal(density(sin(asin(dist)), 0.5), density(dist, 0.5))
expect_equal(density(tan(atan(dist)), 0.5), density(dist, 0.5))
expect_equal(density(cosh(acosh(dist + 1)) - 1, 0.5), density(dist, 0.5))
expect_equal(density(sinh(asinh(dist)), 0.5), density(dist, 0.5))
expect_equal(density(tanh(atanh(dist)), 0.5), density(dist, 0.5))

expect_equal(density(dist + 1 - 1, 0.5), density(dist, 0.5))
expect_equal(density(dist * 2 / 2, 0.5), density(dist, 0.5))

# inverting a gamma distribution
expect_equal(density(1/dist_gamma(4, 3), 0.5), density(dist_inverse_gamma(4, 1/3), 0.5))
expect_equal(density(1/(1/dist_gamma(4, 3)), 0.5), density(dist_gamma(4, 3), 0.5))

})

0 comments on commit 0476938

Please sign in to comment.