diff --git a/R/ParamSet.R b/R/ParamSet.R index 4f7e5c24..4448ce88 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -132,14 +132,18 @@ ParamSet = R6Class("ParamSet", #' Return values `with_token`, `without_token` or `only_token`? #' @param check_required (`logical(1)`)\cr #' Check if all required parameters are set? + #' @param remove_dependencies (`logical(1)`)\cr + #' Determines if values of parameters with unsatisfied dependencies are removed. #' @return Named `list()`. - get_values = function(class = NULL, is_bounded = NULL, tags = NULL, - type = "with_token", check_required = TRUE) { + get_values = function(class = NULL, is_bounded = NULL, tags = NULL, type = "with_token", check_required = TRUE, + remove_dependencies = TRUE) { assert_choice(type, c("with_token", "without_token", "only_token")) assert_flag(check_required) + assert_flag(remove_dependencies) values = self$values params = self$params_unid ns = names(values) + deps = self$deps if (type == "without_token") { values = discard(values, is, "TuneToken") @@ -147,13 +151,26 @@ ParamSet = R6Class("ParamSet", values = keep(values, is, "TuneToken") } - if(check_required) { + if (check_required) { required = setdiff(names(keep(params, function(p) "required" %in% p$tags)), ns) if (length(required) > 0L) { stop(sprintf("Missing required parameters: %s", str_collapse(required))) } } + if (remove_dependencies) { + if (nrow(deps)) { + for (j in seq_row(deps)) { + p1id = deps$id[j] + p2id = deps$on[j] + cond = deps$cond[[j]] + if (p1id %in% ns && !inherits(values[[p2id]], "TuneToken") && !isTRUE(cond$test(values[[p2id]]))) { + values[p1id] = NULL + } + } + } + } + values[intersect(names(values), self$ids(class = class, is_bounded = is_bounded, tags = tags))] }, @@ -203,8 +220,11 @@ ParamSet = R6Class("ParamSet", #' Params for which dependencies are not satisfied should not be part of `x`. #' #' @param xs (named `list()`). + #' @param check_strict (`logical(1)`)\cr + #' Determines if dependencies and required parameters are checked. #' @return If successful `TRUE`, if not a string with the error message. - check = function(xs) { + check = function(xs, check_strict = FALSE) { + assert_flag(check_strict) ok = check_list(xs, names = "unique") if (!isTRUE(ok)) { @@ -227,33 +247,41 @@ ParamSet = R6Class("ParamSet", } } - # check dependencies - deps = self$deps - if (nrow(deps)) { - for (j in seq_row(deps)) { + if (check_strict) { + # check required + required = setdiff(names(keep(params, function(p) "required" %in% p$tags)), ns) + if (length(required) > 0L) { + stop(sprintf("Missing required parameters: %s", str_collapse(required))) + } - p1id = deps$id[j] - p2id = deps$on[j] - if (inherits(xs[[p1id]], "TuneToken") || inherits(xs[[p2id]], "TuneToken")) { - next # be lenient with dependencies when any parameter involved is a TuneToken - } - # we are ONLY ok if: - # - if param is there, then parent must be there, then cond must be true - # - if param is not there - cond = deps$cond[[j]] - ok = (p1id %in% ns && p2id %in% ns && cond$test(xs[[p2id]])) || - (p1id %nin% ns) - if (isFALSE(ok)) { - message = sprintf("The parameter '%s' can only be set if the following condition is met '%s'.", - p1id, cond$as_string(p2id)) - val = xs[[p2id]] - if (is.null(val)) { - message = sprintf(paste("%s Instead the parameter value for '%s' is not set at all.", - "Try setting '%s' to a value that satisfies the condition"), message, p2id, p2id) - } else { - message = sprintf("%s Instead the current parameter value is: %s=%s", message, p2id, val) + # check dependencies + deps = self$deps + if (nrow(deps)) { + for (j in seq_row(deps)) { + + p1id = deps$id[j] + p2id = deps$on[j] + if (inherits(xs[[p1id]], "TuneToken") || inherits(xs[[p2id]], "TuneToken")) { + next # be lenient with dependencies when any parameter involved is a TuneToken + } + # we are ONLY ok if: + # - if param is there, then parent must be there, then cond must be true + # - if param is not there + cond = deps$cond[[j]] + ok = (p1id %in% ns && p2id %in% ns && cond$test(xs[[p2id]])) || + (p1id %nin% ns) + if (isFALSE(ok)) { + message = sprintf("The parameter '%s' can only be set if the following condition is met '%s'.", + p1id, cond$as_string(p2id)) + val = xs[[p2id]] + if (is.null(val)) { + message = sprintf(paste("%s Instead the parameter value for '%s' is not set at all.", + "Try setting '%s' to a value that satisfies the condition"), message, p2id, p2id) + } else { + message = sprintf("%s Instead the current parameter value is: %s=%s", message, p2id, val) + } + return(message) } - return(message) } } } @@ -268,8 +296,10 @@ ParamSet = R6Class("ParamSet", #' Params for which dependencies are not satisfied should not be part of `x`. #' #' @param xs (named `list()`). + #' @param check_strict (`logical(1)`)\cr + #' Determines if dependencies and required parameters are checked. #' @return If successful `TRUE`, if not `FALSE`. - test = function(xs) makeTest(res = self$check(xs)), + test = function(xs, check_strict = FALSE) makeTest(res = self$check(xs, check_strict)), #' @description #' \pkg{checkmate}-like assert-function. Takes a named list. @@ -281,8 +311,12 @@ ParamSet = R6Class("ParamSet", #' @param .var.name (`character(1)`)\cr #' Name of the checked object to print in error messages.\cr #' Defaults to the heuristic implemented in [vname][checkmate::vname]. + #' @param check_strict (`logical(1)`)\cr + #' Determines if dependencies and required parameters are checked. #' @return If successful `xs` invisibly, if not an error message. - assert = function(xs, .var.name = vname(xs)) makeAssertion(xs, self$check(xs), .var.name, NULL), # nolint + assert = function(xs, .var.name = vname(xs), check_strict = FALSE) { + makeAssertion(xs, self$check(xs, check_strict), .var.name, NULL) # nolint + }, #' @description #' \pkg{checkmate}-like check-function. Takes a [data.table::data.table] @@ -292,11 +326,13 @@ ParamSet = R6Class("ParamSet", #' dependencies are not satisfied should be set to `NA` in `xdt`. #' #' @param xdt ([data.table::data.table] | `data.frame()`). + #' @param check_strict (`logical(1)`)\cr + #' Determines if dependencies and required parameters are checked. #' @return If successful `TRUE`, if not a string with the error message. - check_dt = function(xdt) { + check_dt = function(xdt, check_strict = FALSE) { xss = map(transpose_list(xdt), discard, is.na) for (xs in xss) { - ok = self$check(xs) + ok = self$check(xs, check_strict) if (!isTRUE(ok)) { return(ok) } diff --git a/man/ParamSet.Rd b/man/ParamSet.Rd index 14fca66c..f8a35988 100644 --- a/man/ParamSet.Rd +++ b/man/ParamSet.Rd @@ -259,7 +259,8 @@ Only returns values of parameters that satisfy all conditions. is_bounded = NULL, tags = NULL, type = "with_token", - check_required = TRUE + check_required = TRUE, + remove_dependencies = TRUE )}\if{html}{\out{}} } @@ -277,6 +278,9 @@ Return values \code{with_token}, \code{without_token} or \code{only_token}?} \item{\code{check_required}}{(\code{logical(1)})\cr Check if all required parameters are set?} + +\item{\code{remove_dependencies}}{(\code{logical(1)})\cr +Determines if values of parameters with unsatisfied dependencies are removed.} } \if{html}{\out{}} } @@ -327,13 +331,16 @@ A point x is feasible, if it configures a subset of params, all individual param constraints are satisfied and all dependencies are satisfied. Params for which dependencies are not satisfied should not be part of \code{x}. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ParamSet$check(xs)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{ParamSet$check(xs, check_strict = FALSE)}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ \item{\code{xs}}{(named \code{list()}).} + +\item{\code{check_strict}}{(\code{logical(1)})\cr +Determines if dependencies and required parameters are checked.} } \if{html}{\out{
}} } @@ -350,13 +357,16 @@ A point x is feasible, if it configures a subset of params, all individual param constraints are satisfied and all dependencies are satisfied. Params for which dependencies are not satisfied should not be part of \code{x}. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ParamSet$test(xs)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{ParamSet$test(xs, check_strict = FALSE)}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ \item{\code{xs}}{(named \code{list()}).} + +\item{\code{check_strict}}{(\code{logical(1)})\cr +Determines if dependencies and required parameters are checked.} } \if{html}{\out{
}} } @@ -373,7 +383,7 @@ A point x is feasible, if it configures a subset of params, all individual param constraints are satisfied and all dependencies are satisfied. Params for which dependencies are not satisfied should not be part of \code{x}. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ParamSet$assert(xs, .var.name = vname(xs))}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{ParamSet$assert(xs, .var.name = vname(xs), check_strict = FALSE)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -384,6 +394,9 @@ Params for which dependencies are not satisfied should not be part of \code{x}. \item{\code{.var.name}}{(\code{character(1)})\cr Name of the checked object to print in error messages.\cr Defaults to the heuristic implemented in \link[checkmate:vname]{vname}.} + +\item{\code{check_strict}}{(\code{logical(1)})\cr +Determines if dependencies and required parameters are checked.} } \if{html}{\out{}} } @@ -401,13 +414,16 @@ if it configures a subset of params, all individual param constraints are satisfied and all dependencies are satisfied. Params for which dependencies are not satisfied should be set to \code{NA} in \code{xdt}. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ParamSet$check_dt(xdt)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{ParamSet$check_dt(xdt, check_strict = FALSE)}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ \item{\code{xdt}}{(\link[data.table:data.table]{data.table::data.table} | \code{data.frame()}).} + +\item{\code{check_strict}}{(\code{logical(1)})\cr +Determines if dependencies and required parameters are checked.} } \if{html}{\out{
}} } diff --git a/tests/testthat/test_ParamSet.R b/tests/testthat/test_ParamSet.R index 2785b3ef..d78bfac8 100644 --- a/tests/testthat/test_ParamSet.R +++ b/tests/testthat/test_ParamSet.R @@ -104,7 +104,7 @@ test_that("ParamSet$check", { ps = ParamLgl$new("x")$rep(2) ps$add_dep("x_rep_1", "x_rep_2", CondEqual$new(TRUE)) - expect_string(ps$check(list(x_rep_1 = FALSE, x_rep_2 = FALSE)), fixed = "x_rep_2 = TRUE") + expect_string(ps$check(list(x_rep_1 = FALSE, x_rep_2 = FALSE), check_strict = TRUE), fixed = "x_rep_2 = TRUE") }) test_that("we cannot create ParamSet with non-strict R names", { @@ -277,6 +277,95 @@ test_that("ParamSet$get_values", { expect_equal(ps$get_values(), list(x = 1, y = 2)) expect_equal(ps$get_values(class = c("ParamInt", "ParamFct")), list(y = 2)) expect_equal(ps$get_values(is_bounded = TRUE), list(y = 2)) + + # 2 dependencies + pss = ps( + a = p_fct(c("b", "c")), + b = p_int(depends = a == "b"), + c = p_int(depends = a == "c") + ) + + pss$values$b = 1 + expect_list(pss$get_values(), len = 0) + expect_equal(pss$get_values(remove_dependencies = FALSE), list(b = 1)) + + pss$values$a = "c" + expect_equal(pss$get_values(), list(a = "c")) + expect_equal(pss$get_values(remove_dependencies = FALSE), list(b = 1, a = "c")) + + pss$values$a = "b" + expect_equal(pss$get_values(), list(b = 1, a = "b")) + + pss$values$a = "b" + pss$values$b = 1 + pss$values$c = 1 + + expect_equal(pss$get_values(), list(b = 1, a = "b")) + expect_equal(pss$get_values(remove_dependencies = FALSE), list(b = 1, a = "b", c = 1)) + + # 2 dependencies and tune token + pss = ps( + a = p_fct(c("b", "c")), + b = p_int(depends = a == "b"), + c = p_int(depends = a == "c") + ) + + pss$values$a = to_tune() + pss$values$b = 1 + pss$values$c = 1 + + expect_equal(pss$get_values(), list(a = to_tune(), b = 1L, c = 1L)) + + # 3 dependencies + pss = ps( + a = p_fct(c("b", "c")), + b = p_int(depends = a == "b"), + c = p_int(depends = a == "c"), + d = p_lgl(), + e = p_int(depends = d == TRUE) + ) + + pss$values$a = "b" + pss$values$b = 1 + pss$values$c = 1 + + expect_equal(pss$get_values(), list(a = "b", b = 1L)) + expect_equal(pss$get_values(remove_dependencies = FALSE), list(a = "b", b = 1L, c = 1L)) + + pss$values$e = 1 + + expect_equal(pss$get_values(), list(a = "b", b = 1L)) + + pss$values$d = FALSE + + expect_equal(pss$get_values(), list(a = "b", b = 1L, d = FALSE)) + + pss$values$d = TRUE + + expect_equal(pss$get_values(), list(a = "b", b = 1L, e = 1, d = TRUE)) + + # nested dependencies + pss = ps( + a = p_fct(c("b", "c")), + b = p_int(depends = a == "b"), + c = p_int(depends = b == 1) + ) + + pss$values$c = 1 + expect_list(pss$get_values(), len = 0) + + pss$values$b = 1 + expect_list(pss$get_values(), len = 0) + + pss$values$a = "b" + expect_equal(pss$get_values(), list(c = 1, b = 1, a = "b")) + + pss$values = list() + pss$values$b = 1 + expect_list(pss$get_values(), len = 0) + + pss$values$a = "b" + expect_equal(pss$get_values(), list(b = 1, a = "b")) }) test_that("required tag", { @@ -335,11 +424,11 @@ test_that("ParamSet$check_dt", { ps = ParamLgl$new("x")$rep(2) ps$add_dep("x_rep_2", "x_rep_1", CondEqual$new(TRUE)) xdt = data.table(x_rep_1 = c(TRUE, TRUE), x_rep_2 = c(FALSE, TRUE)) - expect_true(ps$check_dt(xdt)) + expect_true(ps$check_dt(xdt, check_strict = TRUE)) xdt = data.table(x_rep_1 = c(TRUE, TRUE, FALSE), x_rep_2 = c(FALSE, TRUE, FALSE)) - expect_character(ps$check_dt(xdt), fixed = "x_rep_1 = TRUE") + expect_character(ps$check_dt(xdt, check_strict = TRUE), fixed = "x_rep_1 = TRUE") xdt = data.table(x_rep_1 = c(TRUE, TRUE, FALSE), x_rep_2 = c(FALSE, TRUE, NA)) - expect_true(ps$check_dt(xdt)) + expect_true(ps$check_dt(xdt, check_strict = TRUE)) }) test_that("rd_info.ParamSet", { diff --git a/tests/testthat/test_ParamSetCollection.R b/tests/testthat/test_ParamSetCollection.R index 541707d4..71bdf5e6 100644 --- a/tests/testthat/test_ParamSetCollection.R +++ b/tests/testthat/test_ParamSetCollection.R @@ -114,8 +114,8 @@ test_that("deps", { # check deps across sets psc$add_dep("ps2.d", on = "ps1.f", CondEqual$new("a")) expect_data_table(psc$deps, nrows = 2, ncols = 3) - expect_true(psc$check(list(ps1.f = "a", ps1.d = 0, ps2.d = 0))) - expect_string(psc$check(list(ps2.d = 0))) + expect_true(psc$check(list(ps1.f = "a", ps1.d = 0, ps2.d = 0), check_strict = TRUE)) + expect_string(psc$check(list(ps2.d = 0), check_strict = TRUE)) # ps1 and ps2 should not be changed expect_equal(ps1clone, ps1) diff --git a/tests/testthat/test_deps.R b/tests/testthat/test_deps.R index 0d3ab275..176066bb 100644 --- a/tests/testthat/test_deps.R +++ b/tests/testthat/test_deps.R @@ -6,19 +6,19 @@ test_that("basic example works", { ps$add_dep("th_param_int", on = "th_param_fct", CondEqual$new("a")) expect_true(ps$has_deps) x = list(th_param_int = 1) - expect_string(ps$check(x), fixed = "The parameter 'th_param_int' can only be set") + expect_string(ps$check(x, check_strict = TRUE), fixed = "The parameter 'th_param_int' can only be set") x = list(th_param_int = 1, th_param_fct = "a") - expect_true(ps$check(x)) + expect_true(ps$check(x, check_strict = TRUE)) x = list(th_param_int = 1, th_param_fct = "b") - expect_string(ps$check(x), fixed = "The parameter 'th_param_int' can only be set") + expect_string(ps$check(x, check_strict = TRUE), fixed = "The parameter 'th_param_int' can only be set") x = list(th_param_int = NA, th_param_fct = "b") - expect_string(ps$check(x), fixed = "May not be NA") + expect_string(ps$check(x, check_strict = TRUE), fixed = "May not be NA") x = list(th_param_fct = "a") - expect_true(ps$check(x)) + expect_true(ps$check(x, check_strict = TRUE)) x = list(th_param_fct = "b") - expect_true(ps$check(x)) + expect_true(ps$check(x, check_strict = TRUE)) x = list(th_param_dbl = 1.3) - expect_true(ps$check(x)) + expect_true(ps$check(x, check_strict = TRUE)) # test printer, with 2 deps ps = th_paramset_full() @@ -40,17 +40,17 @@ test_that("nested deps work", { ps$add_dep("th_param_lgl", on = "th_param_fct", CondEqual$new("c")) x1 = list(th_param_int = 1) - expect_string(ps$check(x1), fixed = "The parameter 'th_param_int' can only be set") + expect_string(ps$check(x1, check_strict = TRUE), fixed = "The parameter 'th_param_int' can only be set") x2 = list(th_param_int = 1, th_param_fct = "b") - expect_true(ps$check(x2)) + expect_true(ps$check(x2, check_strict = TRUE)) x3 = list(th_param_int = 1, th_param_fct = "c") - expect_string(ps$check(x3), fixed = "The parameter 'th_param_int' can only be set") + expect_string(ps$check(x3, check_strict = TRUE), fixed = "The parameter 'th_param_int' can only be set") x4 = list(th_param_fct = "a") - expect_true(ps$check(x4)) + expect_true(ps$check(x4, check_strict = TRUE)) x5 = list(th_param_dbl = 1.3) - expect_string(ps$check(x5), fixed = "The parameter 'th_param_dbl' can only be set") + expect_string(ps$check(x5, check_strict = TRUE), fixed = "The parameter 'th_param_dbl' can only be set") x6 = list(th_param_fct = "c", th_param_lgl = TRUE, th_param_dbl = 3) - expect_true(ps$check(x6)) + expect_true(ps$check(x6, check_strict = TRUE)) }) @@ -72,11 +72,11 @@ test_that("adding 2 sets with deps works", { expect_true(ps1$has_deps) expect_data_table(ps1$deps, nrows = 2) # do a few feasibility checks on larger set - expect_true(ps1$test(list(x1 = "a", y1 = 1, x2 = "a", y2 = 1))) - expect_true(ps1$test(list(x1 = "a", y1 = 1))) - expect_false(ps1$test(list(x1 = "b", y1 = 1))) - expect_true(ps1$test(list(x2 = "a", y2 = 1))) - expect_false(ps1$test(list(x2 = "b", y2 = 1))) + expect_true(ps1$test(list(x1 = "a", y1 = 1, x2 = "a", y2 = 1), check_strict = TRUE)) + expect_true(ps1$test(list(x1 = "a", y1 = 1), check_strict = TRUE)) + expect_false(ps1$test(list(x1 = "b", y1 = 1), check_strict = TRUE)) + expect_true(ps1$test(list(x2 = "a", y2 = 1), check_strict = TRUE)) + expect_false(ps1$test(list(x2 = "b", y2 = 1), check_strict = TRUE)) }) test_that("subsetting with deps works", { @@ -109,8 +109,8 @@ test_that("we can also dep on integer", { )) ps$add_dep("d", on = "i", CondAnyOf$new(1:3)) - expect_true(ps$check(list(i = 2, d = 5))) - expect_string(ps$check(list(i = 5, d = 5))) + expect_true(ps$check(list(i = 2, d = 5), check_strict = TRUE)) + expect_string(ps$check(list(i = 5, d = 5), check_strict = TRUE)) }) test_that("deps make sense", {