From 371a561ddb15dd01804facf5a1d0d629d0faef8a Mon Sep 17 00:00:00 2001 From: Nikita Malinin Date: Mon, 18 Sep 2023 13:21:18 +0200 Subject: [PATCH] Patch broken method --- .../torch/nas/bootstrapNAS/search/search.py | 12 +++++++++++- setup.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/nncf/experimental/torch/nas/bootstrapNAS/search/search.py b/nncf/experimental/torch/nas/bootstrapNAS/search/search.py index 3858d2eeef5..a20faf6eb43 100644 --- a/nncf/experimental/torch/nas/bootstrapNAS/search/search.py +++ b/nncf/experimental/torch/nas/bootstrapNAS/search/search.py @@ -48,6 +48,16 @@ ValFnType = Callable[[TModel, DataLoaderType], float] +class FixIntegerRandomSampling(IntegerRandomSampling): + """ + Wrapper for the IntegerRandomSampling with the fix for https://github.com/anyoptimization/pymoo/issues/388. + """ + + def _do(self, problem, n_samples, **kwargs): + n, (xl, xu) = problem.n_var, problem.bounds() + return np.column_stack([np.random.randint(xl[k], xu[k] + 1, size=(n_samples)) for k in range(n)]) + + class EvolutionaryAlgorithms(Enum): NSGA2 = "NSGA2" @@ -207,7 +217,7 @@ def __init__( if evo_algo == EvolutionaryAlgorithms.NSGA2.value: self._algorithm = NSGA2( pop_size=self.search_params.population, - sampling=IntegerRandomSampling(), + sampling=FixIntegerRandomSampling(), crossover=SBX( prob=self.search_params.crossover_prob, eta=self.search_params.crossover_eta, diff --git a/setup.py b/setup.py index 2f6064206c3..4e843b40219 100644 --- a/setup.py +++ b/setup.py @@ -112,7 +112,7 @@ def find_version(*file_paths): "pandas>=1.1.5,<2.1", "psutil", "pydot>=1.4.1", - "pymoo==0.6.0", + "pymoo==0.6.0.1", # The recent pyparsing major version update seems to break # integration with networkx - the graphs parsed from current .dot # reference files no longer match against the graphs produced in tests.