Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove parameter values of unsatisfied dependencies #343

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 69 additions & 33 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,28 +132,45 @@ ParamSet = R6Class("ParamSet",
#' Return values `with_token`, `without_token` or `only_token`?
#' @param check_required (`logical(1)`)\cr
#' Check if all required parameters are set?
#' @param remove_dependencies (`logical(1)`)\cr
#' Determines if values of parameters with unsatisfied dependencies are removed.
#' @return Named `list()`.
get_values = function(class = NULL, is_bounded = NULL, tags = NULL,
type = "with_token", check_required = TRUE) {
get_values = function(class = NULL, is_bounded = NULL, tags = NULL, type = "with_token", check_required = TRUE,
remove_dependencies = TRUE) {
assert_choice(type, c("with_token", "without_token", "only_token"))
assert_flag(check_required)
assert_flag(remove_dependencies)
values = self$values
params = self$params_unid
ns = names(values)
deps = self$deps

if (type == "without_token") {
values = discard(values, is, "TuneToken")
} else if (type == "only_token") {
values = keep(values, is, "TuneToken")
}

if(check_required) {
if (check_required) {
required = setdiff(names(keep(params, function(p) "required" %in% p$tags)), ns)
if (length(required) > 0L) {
stop(sprintf("Missing required parameters: %s", str_collapse(required)))
}
}

if (remove_dependencies) {
if (nrow(deps)) {
for (j in seq_row(deps)) {
p1id = deps$id[j]
p2id = deps$on[j]
cond = deps$cond[[j]]
if (p1id %in% ns && !inherits(values[[p2id]], "TuneToken") && !isTRUE(cond$test(values[[p2id]]))) {
values[p1id] = NULL
}
}
}
}

values[intersect(names(values), self$ids(class = class, is_bounded = is_bounded, tags = tags))]
},

Expand Down Expand Up @@ -203,8 +220,11 @@ ParamSet = R6Class("ParamSet",
#' Params for which dependencies are not satisfied should not be part of `x`.
#'
#' @param xs (named `list()`).
#' @param check_strict (`logical(1)`)\cr
#' Determines if dependencies and required parameters are checked.
#' @return If successful `TRUE`, if not a string with the error message.
check = function(xs) {
check = function(xs, check_strict = FALSE) {
assert_flag(check_strict)

ok = check_list(xs, names = "unique")
if (!isTRUE(ok)) {
Expand All @@ -227,33 +247,41 @@ ParamSet = R6Class("ParamSet",
}
}

# check dependencies
deps = self$deps
if (nrow(deps)) {
for (j in seq_row(deps)) {
if (check_strict) {
# check required
required = setdiff(names(keep(params, function(p) "required" %in% p$tags)), ns)
if (length(required) > 0L) {
stop(sprintf("Missing required parameters: %s", str_collapse(required)))
}

p1id = deps$id[j]
p2id = deps$on[j]
if (inherits(xs[[p1id]], "TuneToken") || inherits(xs[[p2id]], "TuneToken")) {
next # be lenient with dependencies when any parameter involved is a TuneToken
}
# we are ONLY ok if:
# - if param is there, then parent must be there, then cond must be true
# - if param is not there
cond = deps$cond[[j]]
ok = (p1id %in% ns && p2id %in% ns && cond$test(xs[[p2id]])) ||
(p1id %nin% ns)
if (isFALSE(ok)) {
message = sprintf("The parameter '%s' can only be set if the following condition is met '%s'.",
p1id, cond$as_string(p2id))
val = xs[[p2id]]
if (is.null(val)) {
message = sprintf(paste("%s Instead the parameter value for '%s' is not set at all.",
"Try setting '%s' to a value that satisfies the condition"), message, p2id, p2id)
} else {
message = sprintf("%s Instead the current parameter value is: %s=%s", message, p2id, val)
# check dependencies
deps = self$deps
if (nrow(deps)) {
for (j in seq_row(deps)) {

p1id = deps$id[j]
p2id = deps$on[j]
if (inherits(xs[[p1id]], "TuneToken") || inherits(xs[[p2id]], "TuneToken")) {
next # be lenient with dependencies when any parameter involved is a TuneToken
}
# we are ONLY ok if:
# - if param is there, then parent must be there, then cond must be true
# - if param is not there
cond = deps$cond[[j]]
ok = (p1id %in% ns && p2id %in% ns && cond$test(xs[[p2id]])) ||
(p1id %nin% ns)
if (isFALSE(ok)) {
message = sprintf("The parameter '%s' can only be set if the following condition is met '%s'.",
p1id, cond$as_string(p2id))
val = xs[[p2id]]
if (is.null(val)) {
message = sprintf(paste("%s Instead the parameter value for '%s' is not set at all.",
"Try setting '%s' to a value that satisfies the condition"), message, p2id, p2id)
} else {
message = sprintf("%s Instead the current parameter value is: %s=%s", message, p2id, val)
}
return(message)
}
return(message)
}
}
}
Expand All @@ -268,8 +296,10 @@ ParamSet = R6Class("ParamSet",
#' Params for which dependencies are not satisfied should not be part of `x`.
#'
#' @param xs (named `list()`).
#' @param check_strict (`logical(1)`)\cr
#' Determines if dependencies and required parameters are checked.
#' @return If successful `TRUE`, if not `FALSE`.
test = function(xs) makeTest(res = self$check(xs)),
test = function(xs, check_strict = FALSE) makeTest(res = self$check(xs, check_strict)),

#' @description
#' \pkg{checkmate}-like assert-function. Takes a named list.
Expand All @@ -281,8 +311,12 @@ ParamSet = R6Class("ParamSet",
#' @param .var.name (`character(1)`)\cr
#' Name of the checked object to print in error messages.\cr
#' Defaults to the heuristic implemented in [vname][checkmate::vname].
#' @param check_strict (`logical(1)`)\cr
#' Determines if dependencies and required parameters are checked.
#' @return If successful `xs` invisibly, if not an error message.
assert = function(xs, .var.name = vname(xs)) makeAssertion(xs, self$check(xs), .var.name, NULL), # nolint
assert = function(xs, .var.name = vname(xs), check_strict = FALSE) {
makeAssertion(xs, self$check(xs, check_strict), .var.name, NULL) # nolint
},

#' @description
#' \pkg{checkmate}-like check-function. Takes a [data.table::data.table]
Expand All @@ -292,11 +326,13 @@ ParamSet = R6Class("ParamSet",
#' dependencies are not satisfied should be set to `NA` in `xdt`.
#'
#' @param xdt ([data.table::data.table] | `data.frame()`).
#' @param check_strict (`logical(1)`)\cr
#' Determines if dependencies and required parameters are checked.
#' @return If successful `TRUE`, if not a string with the error message.
check_dt = function(xdt) {
check_dt = function(xdt, check_strict = FALSE) {
xss = map(transpose_list(xdt), discard, is.na)
for (xs in xss) {
ok = self$check(xs)
ok = self$check(xs, check_strict)
if (!isTRUE(ok)) {
return(ok)
}
Expand Down
26 changes: 21 additions & 5 deletions man/ParamSet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

97 changes: 93 additions & 4 deletions tests/testthat/test_ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ test_that("ParamSet$check", {

ps = ParamLgl$new("x")$rep(2)
ps$add_dep("x_rep_1", "x_rep_2", CondEqual$new(TRUE))
expect_string(ps$check(list(x_rep_1 = FALSE, x_rep_2 = FALSE)), fixed = "x_rep_2 = TRUE")
expect_string(ps$check(list(x_rep_1 = FALSE, x_rep_2 = FALSE), check_strict = TRUE), fixed = "x_rep_2 = TRUE")
})

test_that("we cannot create ParamSet with non-strict R names", {
Expand Down Expand Up @@ -277,6 +277,95 @@ test_that("ParamSet$get_values", {
expect_equal(ps$get_values(), list(x = 1, y = 2))
expect_equal(ps$get_values(class = c("ParamInt", "ParamFct")), list(y = 2))
expect_equal(ps$get_values(is_bounded = TRUE), list(y = 2))

# 2 dependencies
pss = ps(
a = p_fct(c("b", "c")),
b = p_int(depends = a == "b"),
c = p_int(depends = a == "c")
)

pss$values$b = 1
expect_list(pss$get_values(), len = 0)
expect_equal(pss$get_values(remove_dependencies = FALSE), list(b = 1))

pss$values$a = "c"
expect_equal(pss$get_values(), list(a = "c"))
expect_equal(pss$get_values(remove_dependencies = FALSE), list(b = 1, a = "c"))

pss$values$a = "b"
expect_equal(pss$get_values(), list(b = 1, a = "b"))

pss$values$a = "b"
pss$values$b = 1
pss$values$c = 1

expect_equal(pss$get_values(), list(b = 1, a = "b"))
expect_equal(pss$get_values(remove_dependencies = FALSE), list(b = 1, a = "b", c = 1))

# 2 dependencies and tune token
pss = ps(
a = p_fct(c("b", "c")),
b = p_int(depends = a == "b"),
c = p_int(depends = a == "c")
)

pss$values$a = to_tune()
pss$values$b = 1
pss$values$c = 1

expect_equal(pss$get_values(), list(a = to_tune(), b = 1L, c = 1L))

# 3 dependencies
pss = ps(
a = p_fct(c("b", "c")),
b = p_int(depends = a == "b"),
c = p_int(depends = a == "c"),
d = p_lgl(),
e = p_int(depends = d == TRUE)
)

pss$values$a = "b"
pss$values$b = 1
pss$values$c = 1

expect_equal(pss$get_values(), list(a = "b", b = 1L))
expect_equal(pss$get_values(remove_dependencies = FALSE), list(a = "b", b = 1L, c = 1L))

pss$values$e = 1

expect_equal(pss$get_values(), list(a = "b", b = 1L))

pss$values$d = FALSE

expect_equal(pss$get_values(), list(a = "b", b = 1L, d = FALSE))

pss$values$d = TRUE

expect_equal(pss$get_values(), list(a = "b", b = 1L, e = 1, d = TRUE))

# nested dependencies
pss = ps(
a = p_fct(c("b", "c")),
b = p_int(depends = a == "b"),
c = p_int(depends = b == 1)
)

pss$values$c = 1
expect_list(pss$get_values(), len = 0)

pss$values$b = 1
expect_list(pss$get_values(), len = 0)

pss$values$a = "b"
expect_equal(pss$get_values(), list(c = 1, b = 1, a = "b"))

pss$values = list()
pss$values$b = 1
expect_list(pss$get_values(), len = 0)

pss$values$a = "b"
expect_equal(pss$get_values(), list(b = 1, a = "b"))
})

test_that("required tag", {
Expand Down Expand Up @@ -335,11 +424,11 @@ test_that("ParamSet$check_dt", {
ps = ParamLgl$new("x")$rep(2)
ps$add_dep("x_rep_2", "x_rep_1", CondEqual$new(TRUE))
xdt = data.table(x_rep_1 = c(TRUE, TRUE), x_rep_2 = c(FALSE, TRUE))
expect_true(ps$check_dt(xdt))
expect_true(ps$check_dt(xdt, check_strict = TRUE))
xdt = data.table(x_rep_1 = c(TRUE, TRUE, FALSE), x_rep_2 = c(FALSE, TRUE, FALSE))
expect_character(ps$check_dt(xdt), fixed = "x_rep_1 = TRUE")
expect_character(ps$check_dt(xdt, check_strict = TRUE), fixed = "x_rep_1 = TRUE")
xdt = data.table(x_rep_1 = c(TRUE, TRUE, FALSE), x_rep_2 = c(FALSE, TRUE, NA))
expect_true(ps$check_dt(xdt))
expect_true(ps$check_dt(xdt, check_strict = TRUE))
})

test_that("rd_info.ParamSet", {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_ParamSetCollection.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ test_that("deps", {
# check deps across sets
psc$add_dep("ps2.d", on = "ps1.f", CondEqual$new("a"))
expect_data_table(psc$deps, nrows = 2, ncols = 3)
expect_true(psc$check(list(ps1.f = "a", ps1.d = 0, ps2.d = 0)))
expect_string(psc$check(list(ps2.d = 0)))
expect_true(psc$check(list(ps1.f = "a", ps1.d = 0, ps2.d = 0), check_strict = TRUE))
expect_string(psc$check(list(ps2.d = 0), check_strict = TRUE))

# ps1 and ps2 should not be changed
expect_equal(ps1clone, ps1)
Expand Down
Loading