Skip to content

Commit

Permalink
matplotlib plot
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Jul 31, 2024
1 parent 276d481 commit 24c963a
Showing 1 changed file with 89 additions and 1 deletion.
90 changes: 89 additions & 1 deletion dacapo/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@
import bokeh.layouts
import bokeh.plotting
import numpy as np
from tqdm import tqdm

from collections import namedtuple
import itertools
from typing import List
import matplotlib.pyplot as plt


import os



RunInfo = namedtuple(
"RunInfo",
[
Expand Down Expand Up @@ -117,7 +124,7 @@ def get_runs_info(
return runs


def plot_runs(
def bokeh_plot_runs(
run_config_base_names,
smooth=100,
validation_scores=None,
Expand Down Expand Up @@ -384,3 +391,84 @@ def plot_runs(
else:
bokeh.plotting.output_file("performance_plots.html")
bokeh.plotting.save(plot)


def plot_runs(
run_config_base_names,
smooth=100,
validation_scores=None,
higher_is_betters=None,
plot_losses=None,
):
"""
Plot runs.
Args:
run_config_base_names: Names of run configs to plot
smooth: Smoothing factor
validation_scores: Validation scores to plot
higher_is_betters: Whether higher is better
plot_losses: Whether to plot losses
Returns:
None
"""
print("PLOTTING RUNS")
runs = get_runs_info(run_config_base_names, validation_scores, plot_losses)
print("GOT RUNS INFO")

colors = itertools.cycle(plt.cm.tab20.colors)
include_validation_figure = False
include_loss_figure = False

fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(15, 10))
loss_ax = axes[0]
validation_ax = axes[1]

for run, color in zip(runs, colors):
name = run.name

if run.plot_loss:
iterations = [stat.iteration for stat in run.training_stats.iteration_stats]
losses = [stat.loss for stat in run.training_stats.iteration_stats]

print(f"Run {run.name} has {len(losses)} iterations")

if run.plot_loss:
include_loss_figure = True
smooth = int(np.maximum(len(iterations) / 2500, 1))
print(f"smoothing: {smooth}")
x, _ = smooth_values(iterations, smooth, stride=smooth)
y, s = smooth_values(losses, smooth, stride=smooth)
print(x, y)
print(f"plotting {(len(x), len(y))} points")
loss_ax.plot(x, y, label=name, color=color)
print("LOSS PLOTTED")

if run.validation_score_name and run.validation_scores.validated_until() > 0:
validation_score_data = run.validation_scores.to_xarray().sel(
criteria=run.validation_score_name
)
colors_val = itertools.cycle(plt.cm.tab20.colors)
for dataset,color_v in zip(run.validation_scores.datasets,colors_val):
dataset_data = validation_score_data.sel(datasets=dataset)
include_validation_figure = True
x = [score.iteration for score in run.validation_scores.scores]
cc = next(colors_val)
for i in range(dataset_data.data.shape[1]):
current_name = f"{i}_{dataset.name}_{name}_{run.validation_score_name}"
validation_ax.plot(x, dataset_data.data[:,i] , label=current_name, color=cc, alpha=0.5+0.2*i)
print("VALIDATION PLOTTED")

if include_loss_figure:
loss_ax.set_title("Training")
loss_ax.set_xlabel("Iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

if include_validation_figure:
validation_ax.set_title("Validation")
validation_ax.set_xlabel("Iterations")
validation_ax.set_ylabel("Validation Score")
validation_ax.legend()

plt.tight_layout()
plt.show()

0 comments on commit 24c963a

Please sign in to comment.