From 873e2a505a5b9350fa8a77a24e53668d404db998 Mon Sep 17 00:00:00 2001 From: Zak Vendeiro Date: Mon, 13 Jun 2022 02:28:17 -0700 Subject: [PATCH] Simplified the color generation for cost vs run plots. Previously there was a static mapping of controller type to color. That doesn't jive well with supporting third party controllers, since there's no way to know how many there are. Now colors for cost vs run plots are generated on the fly based on how many different colors are needed. --- mloop/controllers.py | 9 ------- mloop/visualizations.py | 59 +++++++++++++++++++++-------------------- 2 files changed, 30 insertions(+), 38 deletions(-) diff --git a/mloop/controllers.py b/mloop/controllers.py index 6c7b6f9..1116d98 100644 --- a/mloop/controllers.py +++ b/mloop/controllers.py @@ -17,15 +17,6 @@ import mloop.learners as mll import mloop.interfaces as mli -controller_dict = { - 'random': 1, - 'nelder_mead': 2, - 'gaussian_process': 3, - 'differential_evolution': 4, - 'neural_net': 5, - 'third_party': 6, -} -number_of_controllers = len(controller_dict) default_controller_archive_filename = 'controller_archive' default_controller_archive_file_type = 'txt' diff --git a/mloop/visualizations.py b/mloop/visualizations.py index 636cfc6..a75d549 100644 --- a/mloop/visualizations.py +++ b/mloop/visualizations.py @@ -280,24 +280,11 @@ def create_learner_visualizations(filename, ) visualizer.create_visualizations(**learner_visualization_kwargs) -def _color_from_controller_name(controller_name): +def _color_list_from_num_options(num_of_params): ''' - Gives a color (as a number between zero an one) corresponding to each controller name string. - ''' - global cmap - # If controller_name isn't in the mlc.controller_dict dictionary, assume it - # is the name of a third party controller provided by an external package. - if controller_name not in mlc.controller_dict: - controller_name = 'third_party' - - # Determine the color. - index = mlc.controller_dict[controller_name] - fraction = float(index) / float(mlc.number_of_controllers) - return cmap(fraction) - -def _color_list_from_num_of_params(num_of_params): - ''' - Gives a list of colors based on the number of parameters. + Gives a list of colors based on a number of options. + + A distinct color will be generated for each option. ''' global cmap return [cmap(float(x)/num_of_params) for x in range(num_of_params)] @@ -413,8 +400,14 @@ def __init__(self, filename, else: self.finite_flag = False - self.unique_types = set(self.out_type) - self.cost_colors = [_color_from_controller_name(x) for x in self.out_type] + self.unique_types = list(set(self.out_type)) + out_type_colors = _color_list_from_num_options( + len(self.unique_types), + ) + self.out_type_color_mapping = { + self.unique_types[j]: out_type_colors[j] for j in range(len(self.unique_types)) + } + self.cost_colors = [self.out_type_color_mapping[x] for x in self.out_type] self.in_numbers = np.arange(1,self.num_in_costs+1) self.out_numbers = np.arange(1,self.num_out_params+1) self.param_numbers = np.arange(self.num_params) @@ -485,7 +478,15 @@ def plot_cost_vs_run(self): plt.title('Controller: Cost vs run number.') artists = [] for ut in self.unique_types: - artists.append(plt.Line2D((0,1),(0,0), color=_color_from_controller_name(ut), marker='o', linestyle='')) + artists.append( + plt.Line2D( + (0,1), + (0,0), + color=self.out_type_color_mapping[ut], + marker='o', + linestyle='', + ) + ) plt.legend(artists,self.unique_types,loc=legend_loc) def _ensure_parameter_subset_valid(self, parameter_subset): @@ -515,7 +516,7 @@ def plot_parameters_vs_run(self, parameter_subset=None): # Generate set of distinct colors for plotting. num_params = len(parameter_subset) - param_colors = _color_list_from_num_of_params(num_params) + param_colors = _color_list_from_num_options(num_params) global figure_counter, run_label, scale_param_label, legend_loc figure_counter += 1 @@ -576,7 +577,7 @@ def plot_parameters_vs_cost(self, parameter_subset=None): # Generate set of distinct colors for plotting. num_params = len(parameter_subset) - param_colors = _color_list_from_num_of_params(num_params) + param_colors = _color_list_from_num_options(num_params) global figure_counter, run_label, run_label, scale_param_label, legend_loc figure_counter += 1 @@ -718,7 +719,7 @@ def __init__(self, filename, self.param_numbers = np.arange(self.num_params) self.gen_numbers = np.arange(1,self.num_generations+1) - self.param_colors = _color_list_from_num_of_params(self.num_params) + self.param_colors = _color_list_from_num_options(self.num_params) self.gen_plot = np.array([np.full(self.num_population_members, ind, dtype=int) for ind in self.gen_numbers]).flatten() def create_visualizations(self, @@ -795,7 +796,7 @@ def plot_params_vs_generations(self, parameter_subset=None): # Generate set of distinct colors for plotting. num_params = len(parameter_subset) - param_colors = _color_list_from_num_of_params(num_params) + param_colors = _color_list_from_num_options(num_params) if self.params_generations.size == 0: self.log.warning('Unable to plot DE: params vs generations as the initial generation did not complete.') @@ -1125,7 +1126,7 @@ def plot_cross_sections(self, parameter_subset=None): # Generate set of distinct colors for plotting. num_params = len(parameter_subset) - param_colors = _color_list_from_num_of_params(num_params) + param_colors = _color_list_from_num_options(num_params) global figure_counter, legend_loc figure_counter += 1 @@ -1246,7 +1247,7 @@ def plot_hyperparameters_vs_fit(self, parameter_subset=None): # Generate set of distinct colors for plotting. num_params = len(parameter_subset) - param_colors = _color_list_from_num_of_params(num_params) + param_colors = _color_list_from_num_options(num_params) global figure_counter, fit_label, legend_loc, log_length_scale_label figure_counter += 1 @@ -1562,7 +1563,7 @@ def do_cross_sections(self, parameter_subset=None, # Generate set of distinct colors for plotting. num_params = len(parameter_subset) - param_colors = _color_list_from_num_of_params(num_params) + param_colors = _color_list_from_num_options(num_params) # Generate labels for legends. legend_labels = mlu._generate_legend_labels( @@ -1702,7 +1703,7 @@ def plot_losses(self): # Generate set of distinct colors for plotting. num_nets = len(all_losses) - net_colors = _color_list_from_num_of_params(num_nets) + net_colors = _color_list_from_num_options(num_nets) artists=[] legend_labels=[] @@ -1750,7 +1751,7 @@ def plot_regularization_history(self): # Generate set of distinct colors for plotting. num_nets = len(regularization_histories) - net_colors = _color_list_from_num_of_params(num_nets) + net_colors = _color_list_from_num_options(num_nets) artists=[] legend_labels=[]