-
Notifications
You must be signed in to change notification settings - Fork 356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make sure pruning does prune #1014
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
import pickle | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changes to this files are unrelated, just a convention update since I stumbled on it in the process |
||
import numpy as np | ||
import pytest | ||
from nevergrad.parametrization import parameter as p | ||
import nevergrad as ng | ||
from nevergrad.common import testing | ||
from nevergrad.functions import ArtificialFunction | ||
import nevergrad.common.typing as tp | ||
|
@@ -20,12 +20,12 @@ def _arg_return(*args: tp.Any, **kwargs: tp.Any) -> float: | |
|
||
|
||
def test_experiment_function() -> None: | ||
param = p.Instrumentation( | ||
p.Choice([1, 12]), | ||
param = ng.p.Instrumentation( | ||
ng.p.Choice([1, 12]), | ||
"constant", | ||
p.Array(shape=(2, 2)), | ||
ng.p.Array(shape=(2, 2)), | ||
constkwarg="blublu", | ||
plop=p.Choice([3, 4]), | ||
plop=ng.p.Choice([3, 4]), | ||
) | ||
with pytest.raises(RuntimeError): | ||
base.ExperimentFunction(_arg_return, param) | ||
|
@@ -53,11 +53,11 @@ def test_experiment_function() -> None: | |
def test_instrumented_function_kwarg_order() -> None: | ||
ifunc = base.ExperimentFunction( | ||
_arg_return, | ||
p.Instrumentation( | ||
kw4=p.Choice([1, 0]), | ||
ng.p.Instrumentation( | ||
kw4=ng.p.Choice([1, 0]), | ||
kw2="constant", | ||
kw3=p.Array(shape=(2, 2)), | ||
kw1=p.Scalar(2.0).set_mutation(sigma=2.0), | ||
kw3=ng.p.Array(shape=(2, 2)), | ||
kw1=ng.p.Scalar(2.0).set_mutation(sigma=2.0), | ||
).set_name("test"), | ||
) | ||
np.testing.assert_equal(ifunc.dimension, 7) | ||
|
@@ -74,16 +74,16 @@ def __call__(self, x: float, y: float = 0) -> float: | |
|
||
|
||
def test_callable_parametrization() -> None: | ||
ifunc = base.ExperimentFunction(lambda x: x ** 2, p.Scalar(2).set_mutation(2).set_name("")) # type: ignore | ||
ifunc = base.ExperimentFunction(lambda x: x ** 2, ng.p.Scalar(2).set_mutation(2).set_name("")) # type: ignore | ||
np.testing.assert_equal(ifunc.descriptors["name"], "<lambda>") | ||
ifunc = base.ExperimentFunction(_Callable(), p.Scalar(2).set_mutation(sigma=2).set_name("")) | ||
ifunc = base.ExperimentFunction(_Callable(), ng.p.Scalar(2).set_mutation(sigma=2).set_name("")) | ||
np.testing.assert_equal(ifunc.descriptors["name"], "_Callable") | ||
# test automatic filling | ||
assert len(ifunc._auto_init) == 2 | ||
|
||
|
||
def test_packed_function() -> None: | ||
ifunc = base.ExperimentFunction(_Callable(), p.Scalar(1).set_name("")) | ||
ifunc = base.ExperimentFunction(_Callable(), ng.p.Scalar(1).set_name("")) | ||
with pytest.raises(AssertionError): | ||
base.MultiExperiment([ifunc, ifunc], [100, 100]) | ||
pfunc = base.MultiExperiment([ifunc, ifunc.copy()], [100, 100]) | ||
|
@@ -92,7 +92,7 @@ def test_packed_function() -> None: | |
|
||
|
||
def test_deterministic_data_setter() -> None: | ||
instru = p.Instrumentation(p.Choice([0, 1, 2, 3]), y=p.Choice([0, 1, 2, 3])).set_name("") | ||
instru = ng.p.Instrumentation(ng.p.Choice([0, 1, 2, 3]), y=ng.p.Choice([0, 1, 2, 3])).set_name("") | ||
ifunc = base.ExperimentFunction(_Callable(), instru) | ||
data = [0.01, 0, 0, 0, 0.01, 0, 0, 0] | ||
for _ in range(20): | ||
|
@@ -113,28 +113,28 @@ def test_deterministic_data_setter() -> None: | |
|
||
|
||
@testing.parametrized( | ||
floats=((p.Scalar(), p.Scalar(init=12.0)), True, False), | ||
array_int=((p.Scalar(), p.Array(shape=(1,)).set_integer_casting()), False, False), | ||
softmax_noisy=((p.Choice(["blue", "red"]), p.Array(shape=(1,))), True, True), | ||
floats=((ng.p.Scalar(), ng.p.Scalar(init=12.0)), True, False), | ||
array_int=((ng.p.Scalar(), ng.p.Array(shape=(1,)).set_integer_casting()), False, False), | ||
softmax_noisy=((ng.p.Choice(["blue", "red"]), ng.p.Array(shape=(1,))), True, True), | ||
softmax_deterministic=( | ||
(p.Choice(["blue", "red"], deterministic=True), p.Array(shape=(1,))), | ||
(ng.p.Choice(["blue", "red"], deterministic=True), ng.p.Array(shape=(1,))), | ||
False, | ||
False, | ||
), | ||
ordered_discrete=((p.TransitionChoice([True, False]), p.Array(shape=(1,))), False, False), | ||
ordered_discrete=((ng.p.TransitionChoice([True, False]), ng.p.Array(shape=(1,))), False, False), | ||
) | ||
def test_parametrization_continuous_noisy( | ||
variables: tp.Tuple[p.Parameter, ...], continuous: bool, noisy: bool | ||
variables: tp.Tuple[ng.p.Parameter, ...], continuous: bool, noisy: bool | ||
) -> None: | ||
instru = p.Instrumentation(*variables) | ||
instru = ng.p.Instrumentation(*variables) | ||
assert instru.descriptors.continuous == continuous | ||
assert instru.descriptors.deterministic != noisy | ||
|
||
|
||
class ExampleFunction(base.ExperimentFunction): | ||
def __init__(self, dimension: int, number: int, default: int = 12): # pylint: disable=unused-argument | ||
# unused argument is used to check that it is automatically added as descriptor | ||
super().__init__(self.oracle_call, p.Array(shape=(dimension,))) | ||
super().__init__(self.oracle_call, ng.p.Array(shape=(dimension,))) | ||
|
||
def oracle_call(self, x: np.ndarray) -> float: | ||
return float(x[0]) | ||
|
@@ -157,7 +157,7 @@ def test_function_descriptors_and_pickle() -> None: | |
class ExampleFunctionAllDefault(base.ExperimentFunction): | ||
def __init__(self, dimension: int = 2, default: int = 12): # pylint: disable=unused-argument | ||
# unused argument is used to check that it is automatically added as descriptor | ||
super().__init__(lambda x: 3.0, p.Array(shape=(dimension,))) | ||
super().__init__(lambda x: 3.0, ng.p.Array(shape=(dimension,))) | ||
|
||
|
||
def test_function_descriptors_all_default() -> None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -254,12 +254,18 @@ class Pruning: | |
def __init__(self, min_len: int, max_len: int): | ||
self.min_len = min_len | ||
self.max_len = max_len | ||
self._num_prunings = 0 # for testing it is not called too often | ||
|
||
def __call__(self, archive: Archive[MultiValue]) -> Archive[MultiValue]: | ||
if len(archive) < self.max_len: | ||
return archive | ||
return self._prune(archive) | ||
|
||
def _prune(self, archive: Archive[MultiValue]) -> Archive[MultiValue]: | ||
self._num_prunings += 1 | ||
# separate function to ease profiling | ||
quantiles: tp.Dict[str, float] = {} | ||
threshold = float(self.min_len) / len(archive) | ||
threshold = float(self.min_len + 1) / len(archive) | ||
names = ["optimistic", "pessimistic", "average"] | ||
for name in names: | ||
quantiles[name] = np.quantile( | ||
|
@@ -269,8 +275,9 @@ def __call__(self, archive: Archive[MultiValue]) -> Archive[MultiValue]: | |
new_archive.bytesdict = { | ||
b: v | ||
for b, v in archive.bytesdict.items() | ||
if any(v.get_estimation(n) <= quantiles[n] for n in names) | ||
} | ||
if any(v.get_estimation(n) < quantiles[n] for n in names) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the strict comparison is the big change that make it work. |
||
} # strict comparison to make sure we prune even for values repeated maaany times | ||
# this may remove all points though, but nevermind for now | ||
return new_archive | ||
|
||
@classmethod | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
new version because this is very impactful