diff --git a/R/ParamSet.R b/R/ParamSet.R index f5ede7c9..26944879 100644 --- a/R/ParamSet.R +++ b/R/ParamSet.R @@ -230,7 +230,34 @@ ParamSet = R6Class("ParamSet", private$.deps = rbind(private$.deps, data.table(id = id, on = on, cond = list(cond))) invisible(self) }, - + learnerside = function(last = TRUE) { + if (!self$has_interface) + return(self) + if (last) + return(private$.learnerside$learnerside(last = TRUE)) + private$.learnerside + }, + add_interface = function(param_set) { + private$.learnerside = self$clone(deep = TRUE) + private$copy_param_set(param_set) + }, + remove_interface = function(param_set, all = FALSE) { + if (!self$has_interface) + stop("no interface to remove") + replace_with = self$learnerside(last = all) + private$copy_param_set(replace_with) + private$.learnerside = replace_with$.learnerside + }, + get_values = function(class = NULL, tags = NULL, learnerside = FALSE, env) { + if (learnerside && self$has_interface) { + private$.learnerside$values = self$trafo(x = self$values, env = env) + return(private$.learnerside$get_values( + class = class, tags = tags, learnerside = learnerside, env = env + )) + } + values = self$values + values[intersect(names(values), self$ids(class = class, tags = tags))] + }, # printer, prints the set as a datatable, with the option to hide some cols print = function(..., hide.cols = c("nlevels", "is_bounded", "special_vals", "tags", "storage_type")) { catf("ParamSet: %s", self$set_id) @@ -290,7 +317,10 @@ ParamSet = R6Class("ParamSet", if (missing(f)) { private$.trafo } else { - assert_function(f, args = c("x", "param_set"), null.ok = TRUE) + assert( + check_function(f, args = c("x", "param_set"), null.ok = TRUE), + check_function(f, args = c("x", "env"), null.ok = TRUE) + ) private$.trafo = f } }, @@ -304,6 +334,7 @@ ParamSet = R6Class("ParamSet", if (length(xs) == 0L) xs = named_list() private$.values = xs }, + has_interface = function() !is.null(private$.learnerside), has_deps = function() nrow(private$.deps) > 0L ), @@ -315,7 +346,14 @@ ParamSet = R6Class("ParamSet", .deps = data.table(id = character(0L), on = character(0L), cond = list()), # return a slot / AB, as a named vec, named with id (and can enfore a certain vec-type) get_member_with_idnames = function(member, astype) set_names(astype(map(self$params, member)), self$ids()), - + .learnerside = NULL, + copy_param_set = function(param_set) { + private$.params = param_set$params + private$.deps = param_set$deps + private$.values = param_set$values + private$.trafo = param_set$trafo + invisible(self) + }, deep_clone = function(name, value) { switch(name, ".params" = map(value, function(x) x$clone(deep = TRUE)), diff --git a/tests/testthat/test_interface.R b/tests/testthat/test_interface.R new file mode 100644 index 00000000..4be8bb09 --- /dev/null +++ b/tests/testthat/test_interface.R @@ -0,0 +1,22 @@ +context("Interface") + +test_that("interface use case works", { + ps_orig = ParamSet$new(list(ParamInt$new("mtry", 0))) + ps = ps_orig$clone() + ps$subset(setdiff(ps$ids(), "mtry")) + ps$add(ParamDbl$new("mtry.pexp", 0, 1)) + ps$trafo = function(x, env) { + x$mtry = round(env$task$ncol ^ x$mtry.pexp) + x$mtry.pexp = NULL + x + } + + ps_orig$values$mtry = 3 + ps_orig$add_interface(ps) + ps_orig$values$mtry.pexp = 0.7 + expect_error({ps_orig$values$mtry = 3}) + + expect_equal(ps_orig$get_values(), list(mtry.pexp = 0.7)) + expect_equal(ps_orig$get_values(learnerside = TRUE, + env = list(task = list(ncol = 200))), list(mtry = round(200^0.7))) +})