Skip to content

Commit

Permalink
Merge pull request #5 from stijnh/main
Browse files Browse the repository at this point in the history
Add initial support for non-numeric tunable types
  • Loading branch information
benvanwerkhoven authored Apr 22, 2024
2 parents 79821b3 + 53a9478 commit be7ee36
Showing 1 changed file with 117 additions and 31 deletions.
148 changes: 117 additions & 31 deletions ktdashboard/ktdashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
import panel as pn
import panel.widgets as pnw
import pandas as pd
from bokeh.models import HoverTool, LinearColorMapper
import bokeh.palettes
from bokeh.models.ranges import FactorRange
from bokeh.transform import jitter
from bokeh.models import HoverTool, LinearColorMapper, CategoricalColorMapper
from bokeh.plotting import ColumnDataSource, figure


class KTdashboard:
""" Main object to instantiate to hold everything related to a running dashboard"""

def __init__(self, cache_file, demo=False):
def __init__(self, cache_file, demo=False, default_key=None):
self.demo = demo
self.cache_file = cache_file

Expand Down Expand Up @@ -44,56 +47,108 @@ def __init__(self, cache_file, demo=False):
if self.demo:
self.index = min(len(data), 1000)

all_tune_param_keys = cached_data["tune_params_keys"]
all_tune_params = dict()

for key in all_tune_param_keys:
values = cached_data["tune_params"][key]
for row in data:
if row[key] not in values:
values = sorted(values + [row[key]])

all_tune_params[key] = values

# figure out which keys are interesting
single_value_tune_param_keys = [key for key in cached_data["tune_params_keys"] if len(cached_data["tune_params"][key]) == 1]
tune_param_keys = [key for key in cached_data["tune_params_keys"] if key not in single_value_tune_param_keys]
single_value_keys = [key for key in data[0].keys() if not isinstance(data[0][key],list) and key not in single_value_tune_param_keys]
output_keys = [key for key in single_value_keys if key not in tune_param_keys]
single_value_tune_param_keys = [key for key in all_tune_param_keys if len(all_tune_params[key]) == 1]
tune_param_keys = [key for key in all_tune_param_keys if key not in single_value_tune_param_keys]
scalar_value_keys = [key for key in data[0].keys() if not isinstance(data[0][key],list) and key not in single_value_tune_param_keys]
output_keys = [key for key in scalar_value_keys if key not in tune_param_keys]
float_keys = [key for key in output_keys if isinstance(data[0][key], float)]

self.single_value_tune_param_keys = single_value_tune_param_keys
self.tune_param_keys = tune_param_keys
self.single_value_keys = single_value_keys
self.scalar_value_keys = scalar_value_keys
self.output_keys = output_keys
self.float_keys = float_keys

self.data_df = pd.DataFrame(data[:self.index])[single_value_keys]
self.source = ColumnDataSource(data=self.data_df)
# Convert to a data frame
data_df = pd.DataFrame(data[:self.index])[scalar_value_keys]

# Replace every object-dtype tunable-parameter column with an ordered categorical
for column, dtype in data_df.dtypes.items():
if column in tune_param_keys and dtype == "object":
data_df[column] = pd.Categorical(
data_df[column],
categories=all_tune_params[column],
ordered=True)

self.data = data
self.data_df = data_df
self.source = ColumnDataSource(data=self.data_df)

self.plot_width = 900
self.plot_height = 600
plot_options=dict(width=self.plot_width, min_width=self.plot_width, height=self.plot_height, min_height=self.plot_height)
plot_options['tools'] = [HoverTool(tooltips=[(k, "@{"+k+"}" + ("{0.00}" if k in float_keys else "")) for k in single_value_keys]), "box_select,box_zoom,save,reset"]
plot_options['tools'] = [HoverTool(tooltips=[(k, "@{"+k+"}" + ("{0.00}" if k in float_keys else "")) for k in scalar_value_keys]), "box_select,box_zoom,save,reset"]

self.plot_options = plot_options

# find default key
default_key = 'GFLOP/s'
if default_key not in single_value_keys:
default_key = 'time' # Check if time is defined
if default_key not in single_value_keys:
default_key = single_Value_keys[0]
if default_key is None:
default_key = 'GFLOP/s'
if default_key not in scalar_value_keys:
default_key = 'time' # Check if time is defined

if default_key not in scalar_value_keys:
default_key = scalar_value_keys[0]

# setup widgets
self.yvariable = pnw.Select(name='Y', value=default_key, options=single_value_keys)
self.xvariable = pnw.Select(name='X', value='index', options=['index']+single_value_keys)
self.colorvariable = pnw.Select(name='Color By', value=default_key, options=single_value_keys)
self.yvariable = pnw.Select(name='Y', value=default_key, options=scalar_value_keys)
self.xvariable = pnw.Select(name='X', value='index', options=['index']+scalar_value_keys)
self.colorvariable = pnw.Select(name='Color By', value=default_key, options=scalar_value_keys)
self.xscale = pnw.RadioButtonGroup(name="xscale", options=["linear", "log"])
self.yscale = pnw.RadioButtonGroup(name="yscale", options=["linear", "log"])

# connect widgets with the function that draws the scatter plot
self.scatter = pn.bind(self.make_scatter, xvariable=self.xvariable, yvariable=self.yvariable, color_by=self.colorvariable)
self.scatter = pn.bind(
self.make_scatter,
xvariable=self.xvariable,
yvariable=self.yvariable,
color_by=self.colorvariable,
xscale=self.xscale,
yscale=self.yscale)

# actually build up the dashboard
self.dashboard = pn.template.BootstrapTemplate(title='Kernel Tuner Dashboard')
self.dashboard.sidebar.append(pn.Column(self.yvariable, self.xvariable, self.colorvariable, pn.layout.Divider()))
self.dashboard.main.append(self.scatter)
self.dashboard.sidebar.append(pn.Column(
self.yvariable,
self.xvariable,
self.colorvariable))

self.dashboard.sidebar.append(pn.layout.Divider())

self.dashboard.sidebar.append(pn.Row(
pn.pane.Markdown("X axis"),
self.xscale
))

self.dashboard.sidebar.append(pn.Row(
pn.pane.Markdown("Y axis"),
self.yscale
))

self.dashboard.sidebar.append(pn.layout.Divider())

self.multi_choice = list()
for tune_param in self.tune_param_keys:
values = list(set([d[tune_param] for d in data]))
values = all_tune_params[tune_param]

multi_choice = pnw.MultiChoice(name=tune_param, value=values, options=values)
self.dashboard.sidebar.append(multi_choice)
pn.Row(pn.bind(self.update_data_selection, tune_param, multi_choice))

row = pn.bind(self.update_data_selection, tune_param, multi_choice)
self.dashboard.sidebar.append(row)

def __del__(self):
self.cache_file_handle.close()
Expand All @@ -114,21 +169,52 @@ def update_data_selection(self, tune_param, multi_choice):
self.source.data = data_df

def update_colors(self, color_by):
    """Build the bokeh color specification for the column ``color_by``.

    Categorical columns are mapped with a ``CategoricalColorMapper`` that has
    one factor per category; every other column is mapped with a
    ``LinearColorMapper`` spanning the column's minimum and maximum.

    Returns a dict with the keys ``'field'`` and ``'transform'``.
    """
    dtype = self.data_df.dtypes[color_by]

    if dtype == "category":
        # Plain list, matching how make_scatter passes categories to bokeh.
        factors = list(dtype.categories)

        # Category10[10] provides ten colors, so ten factors still fit in it.
        if len(factors) <= 10:
            palette = bokeh.palettes.Category10[10]
        else:
            palette = bokeh.palettes.Category20[20]

        color_mapper = CategoricalColorMapper(palette=palette, factors=factors)
    else:
        column = self.data_df[color_by]
        color_mapper = LinearColorMapper(palette='Viridis256', low=min(column),
                                         high=max(column))

    return {'field': color_by, 'transform': color_mapper}

def make_scatter(self, xvariable, yvariable, color_by, xscale, yscale):
    """Build the scatter figure of ``yvariable`` against ``xvariable``.

    ``color_by`` names the column used to color the markers. ``xscale`` and
    ``yscale`` are forwarded to bokeh as ``x_axis_type`` / ``y_axis_type``.
    A categorical column is drawn on a factor range made of its categories
    and its coordinates are jittered.
    """
    color = self.update_colors(color_by)

    # Work on a copy so the shared default plot options are never modified.
    plot_options = dict(self.plot_options)
    plot_options["x_axis_type"] = xscale
    plot_options["y_axis_type"] = yscale

    # For categorical data, we add some jitter. Both axes go through the same
    # code path so each one only ever replaces its own coordinate.
    coords = {"x": xvariable, "y": yvariable}
    for axis, variable in (("x", xvariable), ("y", yvariable)):
        # .get() yields None for 'index', which is not a data frame column.
        dtype = self.data_df.dtypes.get(variable)
        if dtype == "category":
            categories = list(dtype.categories)
            plot_options[axis + "_range"] = categories
            coords[axis] = jitter(variable, width=0.02, distribution="normal",
                                  range=FactorRange(*categories))

    f = figure(**plot_options)
    f.scatter(coords["x"], coords["y"], size=5, color=color, alpha=0.5, source=self.source)
    f.xaxis.axis_label = xvariable
    f.yaxis.axis_label = yvariable

    bokeh_pane = pn.pane.Bokeh(object=f, min_width=self.plot_width, min_height=self.plot_height, max_width=self.plot_width, max_height=self.plot_height)

Expand All @@ -137,7 +223,7 @@ def make_scatter(self, xvariable, yvariable, color_by):
return pane

def update_plot(self, i):
stream_dict = {k:[v] for k,v in dict(self.data[i], index=i).items() if k in ['index']+self.single_value_keys}
stream_dict = {k:[v] for k,v in dict(self.data[i], index=i).items() if k in ['index']+self.scalar_value_keys}
self.source.stream(stream_dict)

def update_data(self):
Expand All @@ -151,7 +237,7 @@ def update_data(self):

for i,element in enumerate(new_data):

stream_dict = {k:[v] for k,v in dict(element, index=self.index+i).items() if k in ['index']+self.single_value_keys}
stream_dict = {k:[v] for k,v in dict(element, index=self.index+i).items() if k in ['index']+self.scalar_value_keys}
self.source.stream(stream_dict)

self.index += len(new_data)
Expand Down

0 comments on commit be7ee36

Please sign in to comment.