diff --git a/ktdashboard/ktdashboard.py b/ktdashboard/ktdashboard.py index fda68ac..730020e 100644 --- a/ktdashboard/ktdashboard.py +++ b/ktdashboard/ktdashboard.py @@ -6,14 +6,17 @@ import panel as pn import panel.widgets as pnw import pandas as pd -from bokeh.models import HoverTool, LinearColorMapper +import bokeh.palettes +from bokeh.models.ranges import FactorRange +from bokeh.transform import jitter +from bokeh.models import HoverTool, LinearColorMapper, CategoricalColorMapper from bokeh.plotting import ColumnDataSource, figure class KTdashboard: """ Main object to instantiate to hold everything related to a running dashboard""" - def __init__(self, cache_file, demo=False): + def __init__(self, cache_file, demo=False, default_key=None): self.demo = demo self.cache_file = cache_file @@ -44,56 +47,108 @@ def __init__(self, cache_file, demo=False): if self.demo: self.index = min(len(data), 1000) + all_tune_param_keys = cached_data["tune_params_keys"] + all_tune_params = dict() + + for key in all_tune_param_keys: + values = cached_data["tune_params"][key] + for row in data: + if row[key] not in values: + values = sorted(values + [row[key]]) + + all_tune_params[key] = values + # figure out which keys are interesting - single_value_tune_param_keys = [key for key in cached_data["tune_params_keys"] if len(cached_data["tune_params"][key]) == 1] - tune_param_keys = [key for key in cached_data["tune_params_keys"] if key not in single_value_tune_param_keys] - single_value_keys = [key for key in data[0].keys() if not isinstance(data[0][key],list) and key not in single_value_tune_param_keys] - output_keys = [key for key in single_value_keys if key not in tune_param_keys] + single_value_tune_param_keys = [key for key in all_tune_param_keys if len(all_tune_params[key]) == 1] + tune_param_keys = [key for key in all_tune_param_keys if key not in single_value_tune_param_keys] + scalar_value_keys = [key for key in data[0].keys() if not isinstance(data[0][key],list) and key not 
in single_value_tune_param_keys] + output_keys = [key for key in scalar_value_keys if key not in tune_param_keys] float_keys = [key for key in output_keys if isinstance(data[0][key], float)] self.single_value_tune_param_keys = single_value_tune_param_keys self.tune_param_keys = tune_param_keys - self.single_value_keys = single_value_keys + self.scalar_value_keys = scalar_value_keys self.output_keys = output_keys self.float_keys = float_keys - self.data_df = pd.DataFrame(data[:self.index])[single_value_keys] - self.source = ColumnDataSource(data=self.data_df) + # Convert to a data frame + data_df = pd.DataFrame(data[:self.index])[scalar_value_keys] + + # Replace all columns that are objects by categoricals + for column, dtype in data_df.dtypes.items(): + if column in tune_param_keys and dtype == "object": + data_df[column] = pd.Categorical( + data_df[column], + categories=all_tune_params[column], + ordered=True) + self.data = data + self.data_df = data_df + self.source = ColumnDataSource(data=self.data_df) self.plot_width = 900 self.plot_height = 600 plot_options=dict(width=self.plot_width, min_width=self.plot_width, height=self.plot_height, min_height=self.plot_height) - plot_options['tools'] = [HoverTool(tooltips=[(k, "@{"+k+"}" + ("{0.00}" if k in float_keys else "")) for k in single_value_keys]), "box_select,box_zoom,save,reset"] + plot_options['tools'] = [HoverTool(tooltips=[(k, "@{"+k+"}" + ("{0.00}" if k in float_keys else "")) for k in scalar_value_keys]), "box_select,box_zoom,save,reset"] self.plot_options = plot_options # find default key - default_key = 'GFLOP/s' - if default_key not in single_value_keys: - default_key = 'time' # Check if time is defined - if default_key not in single_value_keys: - default_key = single_Value_keys[0] + if default_key is None: + default_key = 'GFLOP/s' + if default_key not in scalar_value_keys: + default_key = 'time' # Check if time is defined + + if default_key not in scalar_value_keys: + default_key = scalar_value_keys[0] # 
setup widgets - self.yvariable = pnw.Select(name='Y', value=default_key, options=single_value_keys) - self.xvariable = pnw.Select(name='X', value='index', options=['index']+single_value_keys) - self.colorvariable = pnw.Select(name='Color By', value=default_key, options=single_value_keys) + self.yvariable = pnw.Select(name='Y', value=default_key, options=scalar_value_keys) + self.xvariable = pnw.Select(name='X', value='index', options=['index']+scalar_value_keys) + self.colorvariable = pnw.Select(name='Color By', value=default_key, options=scalar_value_keys) + self.xscale = pnw.RadioButtonGroup(name="xscale", options=["linear", "log"]) + self.yscale = pnw.RadioButtonGroup(name="yscale", options=["linear", "log"]) # connect widgets with the function that draws the scatter plot - self.scatter = pn.bind(self.make_scatter, xvariable=self.xvariable, yvariable=self.yvariable, color_by=self.colorvariable) + self.scatter = pn.bind( + self.make_scatter, + xvariable=self.xvariable, + yvariable=self.yvariable, + color_by=self.colorvariable, + xscale=self.xscale, + yscale=self.yscale) # actually build up the dashboard self.dashboard = pn.template.BootstrapTemplate(title='Kernel Tuner Dashboard') - self.dashboard.sidebar.append(pn.Column(self.yvariable, self.xvariable, self.colorvariable, pn.layout.Divider())) self.dashboard.main.append(self.scatter) + self.dashboard.sidebar.append(pn.Column( + self.yvariable, + self.xvariable, + self.colorvariable)) + + self.dashboard.sidebar.append(pn.layout.Divider()) + + self.dashboard.sidebar.append(pn.Row( + pn.pane.Markdown("X axis"), + self.xscale + )) + + self.dashboard.sidebar.append(pn.Row( + pn.pane.Markdown("Y axis"), + self.yscale + )) + + self.dashboard.sidebar.append(pn.layout.Divider()) self.multi_choice = list() for tune_param in self.tune_param_keys: - values = list(set([d[tune_param] for d in data])) + values = all_tune_params[tune_param] + multi_choice = pnw.MultiChoice(name=tune_param, value=values, options=values) 
self.dashboard.sidebar.append(multi_choice) - pn.Row(pn.bind(self.update_data_selection, tune_param, multi_choice)) + + row = pn.bind(self.update_data_selection, tune_param, multi_choice) + self.dashboard.sidebar.append(row) def __del__(self): self.cache_file_handle.close() @@ -114,21 +169,52 @@ def update_data_selection(self, tune_param, multi_choice): self.source.data = data_df def update_colors(self, color_by): - color_mapper = LinearColorMapper(palette='Viridis256', low=min(self.data_df[color_by]), - high=max(self.data_df[color_by])) + dtype = self.data_df.dtypes[color_by] + + if dtype == "category": + factors = dtype.categories + if len(factors) < 10: + palette = bokeh.palettes.Category10[10] + else: + palette = bokeh.palettes.Category20[20] + + + color_mapper = CategoricalColorMapper(palette=palette, factors=factors) + + else: + color_mapper = LinearColorMapper(palette='Viridis256', low=min(self.data_df[color_by]), + high=max(self.data_df[color_by])) + color = {'field': color_by, 'transform': color_mapper} return color - def make_scatter(self, xvariable, yvariable, color_by): + def make_scatter(self, xvariable, yvariable, color_by, xscale, yscale): color = self.update_colors(color_by) x = xvariable y = yvariable - f = figure(**self.plot_options) - f.circle(x, y, size=5, color=color, alpha=0.5, source=self.source) - f.xaxis.axis_label = x - f.yaxis.axis_label = y + plot_options = dict(self.plot_options) + plot_options["x_axis_type"] = xscale + plot_options["y_axis_type"] = yscale + + # For categorical data, we add some jitter + dtype = self.data_df.dtypes.get(xvariable) + if dtype == "category": + plot_options["x_range"] = list(dtype.categories) + x = jitter(xvariable, width=0.02, distribution="normal", + range=FactorRange(*dtype.categories)) + + dtype = self.data_df.dtypes.get(yvariable) + if dtype == "category": + plot_options["y_range"] = list(dtype.categories) + y = jitter(yvariable, width=0.02, distribution="normal", + 
range=FactorRange(*dtype.categories)) + + f = figure(**plot_options) + f.scatter(x, y, size=5, color=color, alpha=0.5, source=self.source) + f.xaxis.axis_label = xvariable + f.yaxis.axis_label = yvariable bokeh_pane = pn.pane.Bokeh(object=f, min_width=self.plot_width, min_height=self.plot_height, max_width=self.plot_width, max_height=self.plot_height) @@ -137,7 +223,7 @@ def make_scatter(self, xvariable, yvariable, color_by): return pane def update_plot(self, i): - stream_dict = {k:[v] for k,v in dict(self.data[i], index=i).items() if k in ['index']+self.single_value_keys} + stream_dict = {k:[v] for k,v in dict(self.data[i], index=i).items() if k in ['index']+self.scalar_value_keys} self.source.stream(stream_dict) def update_data(self): @@ -151,7 +237,7 @@ def update_data(self): for i,element in enumerate(new_data): - stream_dict = {k:[v] for k,v in dict(element, index=self.index+i).items() if k in ['index']+self.single_value_keys} + stream_dict = {k:[v] for k,v in dict(element, index=self.index+i).items() if k in ['index']+self.scalar_value_keys} self.source.stream(stream_dict) self.index += len(new_data)