Skip to content

Commit

Permalink
Merge pull request #232 from adibender/cr
Browse files Browse the repository at this point in the history
Cr
  • Loading branch information
adibender authored Jul 11, 2023
2 parents e966432 + 77a4db0 commit 94e7ea1
Show file tree
Hide file tree
Showing 22 changed files with 332 additions and 104 deletions.
2 changes: 2 additions & 0 deletions CRAN-RELEASE
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
This package was submitted to CRAN on 2022-01-08.
Once it is accepted, delete this file and tag the release (commit b1af7ed).
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: pammtools
Title: Piece-Wise Exponential Additive Mixed Modeling Tools for Survival Analysis
Version: 0.5.91
Date: 2023-03-23
Version: 0.5.92
Date: 2023-07-09
Authors@R: c(
person("Andreas", "Bender", , "[email protected]", role = c("aut", "cre"), comment=c(ORCID = "0000-0001-5628-8611")),
person("Fabian", "Scheipl", , "[email protected]", role = c("aut"), comment = c(ORCID = "0000-0001-8172-3603")),
Expand Down
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ export(add_tdc)
export(add_term)
export(arrange)
export(as_ped)
export(as_ped_recurrent)
export(as_ped_multistate)
export(combine_df)
export(cumulative)
export(distinct)
Expand Down
38 changes: 19 additions & 19 deletions R/add-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ add_term <- function(
se <- unname(sqrt(rowSums( (X %*% cov.coefs) * X )))
newdata <- newdata %>%
mutate(
ci_lower = .data$fit - se_mult * se,
ci_upper = .data$fit + se_mult * se)
ci_lower = .data[["fit"]] - se_mult * se,
ci_upper = .data[["fit"]] + se_mult * se)
}

return(newdata)
Expand Down Expand Up @@ -252,7 +252,7 @@ get_hazard.default <- function(
add_ci(object, X, type = type, ci_type = ci_type, se_mult = se_mult, ...)
}
if (type == "response") {
newdata <- newdata %>% mutate(hazard = exp(.data$hazard))
newdata <- newdata %>% mutate(hazard = exp(.data[["hazard"]]))
}

newdata %>% arrange(.data[[time_var]], .by_group = TRUE)
Expand Down Expand Up @@ -323,7 +323,7 @@ get_cumu_hazard <- function(

interval_length <- sym(interval_length)

mutate_args <- list(cumu_hazard = quo(cumsum(.data$hazard *
mutate_args <- list(cumu_hazard = quo(cumsum(.data[["hazard"]] *
(!!interval_length))))
haz_vars_in_data <- map(c("hazard", "se", "ci_lower", "ci_upper"),
~ grep(.x, colnames(newdata), value = TRUE, fixed = TRUE)) %>% flatten_chr()
Expand All @@ -337,8 +337,8 @@ get_cumu_hazard <- function(
if (ci_type == "default") {
mutate_args <- mutate_args %>%
append(list(
cumu_lower = quo(cumsum(.data$ci_lower * (!!interval_length))),
cumu_upper = quo(cumsum(.data$ci_upper * (!!interval_length)))))
cumu_lower = quo(cumsum(.data[["ci_lower"]] * (!!interval_length))),
cumu_upper = quo(cumsum(.data[["ci_upper"]] * (!!interval_length)))))
} else {
# ci delta rule
newdata <- split(newdata, group_indices(newdata)) %>%
Expand Down Expand Up @@ -437,7 +437,7 @@ get_surv_prob <- function(

interval_length <- sym(interval_length)

mutate_args <- list(surv_prob = quo(exp(-cumsum(.data$hazard *
mutate_args <- list(surv_prob = quo(exp(-cumsum(.data[["hazard"]] *
(!!interval_length)))))
haz_vars_in_data <- map(c("hazard", "se", "ci_lower", "ci_upper"),
~grep(.x, colnames(newdata), value = TRUE, fixed = TRUE)) %>% flatten_chr()
Expand All @@ -451,8 +451,8 @@ get_surv_prob <- function(
if (ci_type == "default") {
mutate_args <- mutate_args %>%
append(list(
surv_upper = quo(exp(-cumsum(.data$ci_lower * (!!interval_length)))),
surv_lower = quo(exp(-cumsum(.data$ci_upper * (!!interval_length))))))
surv_upper = quo(exp(-cumsum(.data[["ci_lower"]] * (!!interval_length)))),
surv_lower = quo(exp(-cumsum(.data[["ci_upper"]] * (!!interval_length))))))
} else {
# ci delta rule
newdata <- split(newdata, group_indices(newdata)) %>%
Expand Down Expand Up @@ -506,16 +506,16 @@ add_ci <- function(
if (type == "link") {
newdata <- newdata %>%
mutate(
ci_lower = .data$hazard - se_mult * .data$se,
ci_upper = .data$hazard + se_mult * .data$se)
ci_lower = .data[["hazard"]] - se_mult * .data[["se"]],
ci_upper = .data[["hazard"]] + se_mult * .data[["se"]])
}

if (type != "link") {
if (ci_type == "default") {
newdata <- newdata %>%
mutate(
ci_lower = exp(.data$hazard - se_mult * .data$se),
ci_upper = exp(.data$hazard + se_mult * .data$se))
ci_lower = exp(.data[["hazard"]] - se_mult * .data[["se"]]),
ci_upper = exp(.data[["hazard"]] + se_mult * .data[["se"]]))
} else {
if (ci_type == "delta") {
newdata <- split(newdata, group_indices(newdata)) %>%
Expand All @@ -539,8 +539,8 @@ add_delta_ci <- function(newdata, object, se_mult = 2, ...) {
newdata %>%
mutate(
se = sqrt(rowSums( (Jacobi %*% V) * Jacobi )),
ci_lower = exp(.data$hazard) - .data$se * se_mult,
ci_upper = exp(.data$hazard) + .data$se * se_mult)
ci_lower = exp(.data[["hazard"]]) - .data[["se"]] * se_mult,
ci_upper = exp(.data[["hazard"]]) + .data[["se"]] * se_mult)

}

Expand All @@ -554,8 +554,8 @@ add_delta_ci_cumu <- function(newdata, object, se_mult = 2, ...) {
newdata %>%
mutate(
se = sqrt(rowSums( (LHS %*% V) * LHS )),
cumu_lower = cumsum(.data$intlen * .data$hazard) - .data$se * se_mult,
cumu_upper = cumsum(.data$intlen * .data$hazard) + .data$se * se_mult)
cumu_lower = cumsum(.data[["intlen"]] * .data[["hazard"]]) - .data[["se"]] * se_mult,
cumu_upper = cumsum(.data[["intlen"]] * .data[["hazard"]]) + .data[["se"]] * se_mult)

}

Expand All @@ -570,8 +570,8 @@ add_delta_ci_surv <- function(newdata, object, se_mult = 2, ...) {
newdata %>%
mutate(
se = sqrt(rowSums( (LHS %*% V) * LHS)),
surv_lower = exp(-cumsum(.data$hazard * .data$intlen)) - .data$se * se_mult,
surv_upper = exp(-cumsum(.data$hazard * .data$intlen)) + .data$se * se_mult)
surv_lower = exp(-cumsum(.data[["hazard"]] * .data[["intlen"]])) - .data[["se"]] * se_mult,
surv_upper = exp(-cumsum(.data[["hazard"]] * .data[["intlen"]])) + .data[["se"]] * se_mult)

}

Expand Down
54 changes: 33 additions & 21 deletions R/as-ped.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ as_ped.data.frame <- function(
min_events = 1L,
...) {

status_error(data, formula)
status_error(data, formula, censor_code)
assert_subset(tdc_specials, c("concurrent", "cumulative"))

if (test_character(transition, min.chars = 1L, min.len = 1L)) {
ped <- as_ped_recurrent(data = data, formula = formula, cut = cut,
ped <- as_ped_multistate(data = data, formula = formula, cut = cut,
max_time = max_time, tdc_specials = tdc_specials, censor_code = censor_code,
transition = transition, timescale = timescale, min_events = min_events, ... )
return(ped)
Expand Down Expand Up @@ -100,9 +100,10 @@ as_ped.data.frame <- function(

#' @rdname as_ped
#' @export
as_ped.nested_fdf <- function(data, formula, ...) {

status_error(data, formula)
as_ped.nested_fdf <- function(
data,
formula,
...) {

dots <- list(...)
# update interval break points (if necessary)
Expand All @@ -113,14 +114,19 @@ as_ped.nested_fdf <- function(data, formula, ...) {
ccr_breaks <- attr(data, "ccr_breaks")
cut <- union(cut, ccr_breaks[ccr_breaks <= max(cut)]) %>% sort()

ped <- data %>%
select_if(is.atomic) %>%
as.data.frame() %>%
as_ped(
formula = formula,
id = dots$id,
cut = cut,
max_time = dots$max_time)
# ped <- data %>%
# select_if(is.atomic) %>%
# as.data.frame() %>%
# as_ped(
# formula = formula,
# id = dots$id,
# cut = cut,
# max_time = dots$max_time,
# ...)
dots$formula <- formula
dots$data <- as.data.frame(select_if(data, is.atomic))
dots$cut <- cut
ped <- do.call(as_ped, dots)

# replace updated attributes
attr(data, "breaks") <- attr(ped, "breaks")
Expand Down Expand Up @@ -156,12 +162,13 @@ as_ped.list <- function(
data,
formula,
tdc_specials = c("concurrent", "cumulative"),
censor_code = 0L,
...) {

assert_class(data, "list")
assert_class(formula, "formula")

status_error(data[[1]], formula)
status_error(data[[1]], formula, censor_code)

nl <- length(data)
# form <- Formula(formula)
Expand Down Expand Up @@ -225,8 +232,10 @@ as_ped.pamm <- function(data, newdata, ...) {

}

## Competing risks

#' Competing risks trafo
#'
#' @inherit as_ped
#' @importFrom rlang .env
#'
Expand All @@ -242,7 +251,9 @@ as_ped_cr <- function(
...) {

lhs_vars <- get_lhs_vars(formula)
n_lhs <- length(lhs_vars)
event_types <- get_event_types(data, formula, censor_code)
n_events <- sum(event_types != censor_code)

cut <- map2(
event_types,
Expand All @@ -260,7 +271,7 @@ as_ped_cr <- function(
cut,
function(.event, .cut) {
ped_i <- data %>%
mutate(!!lhs_vars[2] := 1L * (.data[[lhs_vars[2]]] == .env[[".event"]])) %>%
mutate(!!lhs_vars[n_lhs] := 1L * (.data[[lhs_vars[n_lhs]]] == .env[[".event"]])) %>%
as_ped(
formula = formula,
cut = .cut,
Expand Down Expand Up @@ -319,15 +330,16 @@ get_event_types <- function(data, formula, censor_code) {
#' cgd2 <- cgd %>%
#' select(id, tstart, tstop, enum, status, age) %>%
#' filter(enum %in% c(1:2))
#' ped_re <- as_ped_recurrent(
#' ped_re <- as_ped_multistate(
#' formula = Surv(tstart, tstop, status) ~ age + enum,
#' data = cgd2,
#' transition = "enum")
#' transition = "enum",
#' timescale = "calendar")
#' }
#' @rdname as_ped
#' @export
#' @keywords internal
as_ped_recurrent <- function(
as_ped_multistate <- function(
data,
formula,
cut = NULL,
Expand All @@ -344,7 +356,7 @@ as_ped_recurrent <- function(
len = 1L)
assert_integer(min_events, lower = 1L, len = 1L)

status_error(data, formula)
status_error(data, formula, censor_code)
assert_subset(tdc_specials, c("concurrent", "cumulative"))

rhs_vars <- get_rhs_vars(formula)
Expand All @@ -355,13 +367,13 @@ as_ped_recurrent <- function(
dots <- list(...)
dots$data <- data
dots$formula <- get_ped_form(formula, data = data, tdc_specials = tdc_specials)
dots$cut <- cut
dots$cut <- sort(unique(cut))
dots$max_time <- max_time
dots$transition <- transition
dots$min_events <- min_events
dots$timescale <- timescale

ped <- do.call(split_data_recurrent, dots)
ped <- do.call(split_data_multistate, dots)
attr(ped, "time_var") <- get_lhs_vars(dots$formula)[1]

return(ped)
Expand Down
4 changes: 2 additions & 2 deletions R/cumulative-coefficient.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ get_cumu_coef.aalen <- function(model, data = NULL, terms, ci = TRUE, ...) {
cumu_coef <- model[["cum"]] %>%
as_tibble() %>%
select(one_of(terms)) %>%
gather("variable", "cumu_hazard", -.data$time)
gather("variable", "cumu_hazard", -.data[["time"]])
cumu_var <- model[["var.cum"]] %>%
as_tibble() %>%
select(terms) %>%
gather("variable", "cumu_var", -.data$time)
gather("variable", "cumu_var", -.data[["time"]])

suppressMessages(
left_join(cumu_coef, cumu_var) %>%
Expand Down
26 changes: 15 additions & 11 deletions R/formula-specials.R
Original file line number Diff line number Diff line change
Expand Up @@ -310,35 +310,39 @@ make_mat_names <- function(x, ...) {

#' @keywords internal
make_mat_names.default <- function(
col_vars,
x,
latency_var = NULL,
tz_var = NULL,
suffix = NULL,
nfunc = 1) {
nfunc = 1,
...) {

if (!is.null(suffix)) {
return(paste(col_vars, suffix, sep = "_"))
return(paste(x, suffix, sep = "_"))
} else {
if (!is.null(tz_var) & nfunc > 1) {
tz_ind <- col_vars == tz_var
col_vars[!tz_ind] <- paste(col_vars[!tz_ind], tz_var, sep = "_")
tz_ind <- x == tz_var
x[!tz_ind] <- paste(x[!tz_ind], tz_var, sep = "_")
}
if (!is.null(latency_var)) {
latency_ind <- col_vars == latency_var
col_vars[latency_ind] <- paste(col_vars[latency_ind], "latency",
latency_ind <- x == latency_var
x[latency_ind] <- paste(x[latency_ind], "latency",
sep = "_")
}
}

return(col_vars)
return(x)

}

#' @keywords internal
make_mat_names.list <- function(func_list, time_var) {
hist_names <- map(func_list, ~ make_mat_names(c(.x[["col_vars"]], "LL"),
make_mat_names.list <- function(
x,
time_var,
...) {
hist_names <- map(x, ~ make_mat_names(c(.x[["col_vars"]], "LL"),
.x[["latency_var"]], .x[["tz_var"]], .x[["suffix"]],
nfunc = length(func_list)))
nfunc = length(x)))

time_mat_ind <- map(hist_names, ~grepl(time_var, .))
for (i in seq_along(time_mat_ind)) {
Expand Down
3 changes: 2 additions & 1 deletion R/get-cut-points.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ get_cut.default <- function(
}
# sort interval cut points in case they are not (so that interval factor
# variables will be in correct ordering)
sort(cut)
sort(unique(cut))

}

Expand Down Expand Up @@ -71,5 +71,6 @@ get_cut.list <- function (
)

cuts <- Reduce(union, cuts)
sort(unique(cuts))

}
Loading

0 comments on commit 94e7ea1

Please sign in to comment.