From 1ab409c1a7a57cce3a630d2abcd2037031b946a4 Mon Sep 17 00:00:00 2001 From: Julie G <43496356+julieg18@users.noreply.github.com> Date: Mon, 20 Mar 2023 09:47:50 -0500 Subject: [PATCH] Consolidate `collectCustomPlots` (#3466) --- extension/src/experiments/index.ts | 4 + extension/src/experiments/model/index.ts | 4 + extension/src/plots/index.ts | 9 +- extension/src/plots/model/collect.test.ts | 207 +++------- extension/src/plots/model/collect.ts | 372 ++++++------------ extension/src/plots/model/index.ts | 95 ++--- extension/src/plots/webview/contract.ts | 14 +- extension/src/plots/webview/messages.ts | 14 +- .../test/fixtures/expShow/base/customPlots.ts | 268 ++++++++----- 9 files changed, 393 insertions(+), 594 deletions(-) diff --git a/extension/src/experiments/index.ts b/extension/src/experiments/index.ts index 360af40e31..67f421641f 100644 --- a/extension/src/experiments/index.ts +++ b/extension/src/experiments/index.ts @@ -333,6 +333,10 @@ export class Experiments extends BaseRepository { return this.experiments.getExperimentCount() } + public getExperimentsWithCheckpoints() { + return this.experiments.getExperimentsWithCheckpoints() + } + public async selectExperiments() { const experiments = this.experiments.getExperimentsWithCheckpoints() diff --git a/extension/src/experiments/model/index.ts b/extension/src/experiments/model/index.ts index 0a29502689..1a7a386086 100644 --- a/extension/src/experiments/model/index.ts +++ b/extension/src/experiments/model/index.ts @@ -60,6 +60,10 @@ export type ExperimentWithCheckpoints = Experiment & { checkpoints?: Experiment[] } +export type ExperimentWithDefinedCheckpoints = Experiment & { + checkpoints: Experiment[] +} + export enum ExperimentType { WORKSPACE = 'workspace', COMMIT = 'commit', diff --git a/extension/src/plots/index.ts b/extension/src/plots/index.ts index e612fd456a..d61e04875e 100644 --- a/extension/src/plots/index.ts +++ b/extension/src/plots/index.ts @@ -13,7 +13,6 @@ import { Experiments } from '../experiments' import { Resource } from '../resourceLocator' import { InternalCommands } from '../commands/internal' import { definedAndNonEmpty } from '../util/array' -import { ExperimentsOutput } from '../cli/dvc/contract' import { TEMP_PLOTS_DIR } from '../cli/dvc/constants' import { removeDir } from '../fileSystem' import { Toast } from '../vscode/toast' @@ -173,7 +172,7 @@ export class Plots extends BaseRepository { waitForInitialExpData.dispose() this.data.setMetricFiles(data) this.setupExperimentsListener(experiments) - void this.initializeData(data) + void this.initializeData() } }) ) @@ -184,7 +183,7 @@ export class Plots extends BaseRepository { experiments.onDidChangeExperiments(async data => { if (data) { await Promise.all([ - this.plots.transformAndSetExperiments(data), + this.plots.transformAndSetExperiments(), this.data.setMetricFiles(data) ]) } @@ -200,8 +199,8 @@ export class Plots extends BaseRepository { ) } - private async initializeData(data: ExperimentsOutput) { - await this.plots.transformAndSetExperiments(data) + private async initializeData() { + await this.plots.transformAndSetExperiments() void this.data.managedUpdate() await Promise.all([ this.data.isReady(), diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index b5175eae4c..295beb3e65 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -1,18 +1,16 @@ import { join } from 'path' -import omit from 'lodash.omit' import isEmpty from 'lodash.isempty' import { collectData, collectTemplates, collectOverrideRevisionDetails, - collectCustomPlots, - collectCustomCheckpointPlots, - collectCustomPlotData + collectCustomPlots } from './collect' +import { isCheckpointPlot } from './custom' import plotsDiffFixture from '../../test/fixtures/plotsDiff/output' import customPlotsFixture, { customPlotsOrderFixture, - checkpointPlotsFixture + experimentsWithCheckpoints } from '../../test/fixtures/expShow/base/customPlots' import { ExperimentStatus, @@ -20,14 +18,13 @@ import { } from '../../cli/dvc/contract' import { sameContents } from '../../util/array' import { - CheckpointPlot, - CustomPlot, CustomPlotData, + CustomPlotType, + DEFAULT_NB_ITEMS_PER_ROW, + DEFAULT_PLOT_HEIGHT, TemplatePlot } from '../webview/contract' import { getCLICommitId } from '../../test/fixtures/plotsDiff/util' -import expShowFixture from '../../test/fixtures/expShow/base/output' -import modifiedFixture from '../../test/fixtures/expShow/modified/output' import { SelectedExperimentWithColor } from '../../experiments/model' import { Experiment } from '../../experiments/webview/contract' @@ -35,120 +32,62 @@ const logsLossPath = join('logs', 'loss.tsv') const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot -const getCustomPlotFromCustomPlotData = ({ - id, - metric, - param, - type, - values -}: CustomPlotData) => - ({ - id, - metric, - param, - type, - values - } as CustomPlot) - describe('collectCustomPlots', () => { + const defaultFuncArgs = { + experiments: experimentsWithCheckpoints, + hasCheckpoints: true, + height: DEFAULT_PLOT_HEIGHT, + nbItemsPerRow: DEFAULT_NB_ITEMS_PER_ROW, + plotsOrderValues: customPlotsOrderFixture, + selectedRevisions: customPlotsFixture.colors?.domain + } + it('should return the expected data from the test fixture', () => { - const expectedOutput: CustomPlot[] = customPlotsFixture.plots.map( - getCustomPlotFromCustomPlotData - ) - const data = collectCustomPlots( - customPlotsOrderFixture, - checkpointPlotsFixture, - [ - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.3724166750907898, - loss: 2.0205044746398926 - } - }, - name: 'exp-e7a67', - params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } - }, - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.4668000042438507, - loss: 1.9293040037155151 - } - }, - name: 'test-branch', - params: { 'params.yaml': { dropout: 0.122, epochs: 2 } } - }, - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.5926499962806702, - loss: 1.775016188621521 - } - }, - name: 'exp-83425', - params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } - } - ] + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots + const data = collectCustomPlots(defaultFuncArgs) + expect(data).toStrictEqual(expectedOutput) + }) + + it('should return only custom plots if there no selected revisions', () => { + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter( + plot => plot.type !== CustomPlotType.CHECKPOINT ) + const data = collectCustomPlots({ + ...defaultFuncArgs, + selectedRevisions: undefined + }) expect(data).toStrictEqual(expectedOutput) }) -}) -describe('collectCustomPlotData', () => { - it('should return the expected data from test fixture', () => { - const expectedMetricVsParamPlotData = customPlotsFixture.plots[0] - const expectedCheckpointsPlotData = customPlotsFixture.plots[2] - const metricVsParamPlot = getCustomPlotFromCustomPlotData( - expectedMetricVsParamPlotData - ) - const checkpointsPlot = getCustomPlotFromCustomPlotData( - expectedCheckpointsPlotData + it('should return only custom plots if checkpoints are not enabled', () => { + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter( + plot => plot.type !== CustomPlotType.CHECKPOINT ) + const data = collectCustomPlots({ + ...defaultFuncArgs, + hasCheckpoints: false + }) - const metricVsParamData = collectCustomPlotData( - metricVsParamPlot, - customPlotsFixture.colors, - customPlotsFixture.nbItemsPerRow, - customPlotsFixture.height - ) + expect(data).toStrictEqual(expectedOutput) + }) - const checkpointsData = collectCustomPlotData( - { - ...checkpointsPlot, - values: [ - ...checkpointsPlot.values, - { - group: 'exp-123', - iteration: 1, - y: 1.4534177053451538 - }, - { - group: 'exp-123', - iteration: 2, - y: 1.757687 - }, - { - group: 'exp-123', - iteration: 3, - y: 1.989894 - } - ] - } as CheckpointPlot, - customPlotsFixture.colors, - customPlotsFixture.nbItemsPerRow, - customPlotsFixture.height - ) + it('should return checkpoint plots with values only containing selected experiments data', () => { + const domain = customPlotsFixture.colors?.domain.slice(1) as string[] + + const expectedOutput = customPlotsFixture.plots.map(plot => ({ + ...plot, + values: isCheckpointPlot(plot) + ? plot.values.filter(value => domain.includes(value.group)) + : plot.values + })) - expect(metricVsParamData).toStrictEqual(expectedMetricVsParamPlotData) - expect(checkpointsData).toStrictEqual(expectedCheckpointsPlotData) + const data = collectCustomPlots({ + ...defaultFuncArgs, + selectedRevisions: domain + }) + + expect(data).toStrictEqual(expectedOutput) }) }) @@ -215,50 +154,6 @@ describe('collectData', () => { }) }) -describe('collectCustomCheckpointPlotsData', () => { - it('should return the expected data from the test fixture', () => { - const data = collectCustomCheckpointPlots(expShowFixture) - - expect(data).toStrictEqual(checkpointPlotsFixture) - }) - - it('should provide a continuous series for a modified experiment', () => { - const data = collectCustomCheckpointPlots(modifiedFixture) - - for (const { values } of Object.values(data)) { - const initialExperiment = values.filter( - point => point.group === 'exp-908bd' - ) - const modifiedExperiment = values.find( - point => point.group === 'exp-01b3a' - ) - - const lastIterationInitial = initialExperiment?.slice(-1)[0] - const firstIterationModified = modifiedExperiment - - expect(lastIterationInitial).not.toStrictEqual(firstIterationModified) - expect(omit(lastIterationInitial, 'group')).toStrictEqual( - omit(firstIterationModified, 'group') - ) - - const baseExperiment = values.filter(point => point.group === 'exp-920fc') - const restartedExperiment = values.find( - point => point.group === 'exp-9bc1b' - ) - - const iterationRestartedFrom = baseExperiment?.slice(5)[0] - const firstIterationAfterRestart = restartedExperiment - - expect(iterationRestartedFrom).not.toStrictEqual( - firstIterationAfterRestart - ) - expect(omit(iterationRestartedFrom, 'group')).toStrictEqual( - omit(firstIterationAfterRestart, 'group') - ) - } - }) -}) - describe('collectTemplates', () => { it('should return the expected output from the test fixture', () => { const { content } = logsLossPlot diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 64b7beccbd..7612e164ae 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -1,13 +1,10 @@ -import omit from 'lodash.omit' import get from 'lodash.get' import { TopLevelSpec } from 'vega-lite' import { VisualizationSpec } from 'react-vega' -import { CustomCheckpointPlots } from '.' import { getFullValuePath, CHECKPOINTS_PARAM, CustomPlotsOrderValue, - isCheckpointPlot, isCheckpointValue } from './custom' import { getRevisionFirstThreeColumns } from './util' @@ -22,35 +19,19 @@ import { TemplatePlotSection, PlotsType, Revision, - CustomPlotType, - CustomPlot, - MetricVsParamPlot, - CustomPlotData + CustomPlotData, + MetricVsParamPlotValues } from '../webview/contract' +import { EXPERIMENT_WORKSPACE_ID, PlotsOutput } from '../../cli/dvc/contract' import { - EXPERIMENT_WORKSPACE_ID, - ExperimentFieldsOrError, - ExperimentsOutput, - ExperimentStatus, - isValueTree, - PlotsOutput, - Value, - ValueTree -} from '../../cli/dvc/contract' -import { extractColumns } from '../../experiments/columns/extract' -import { - decodeColumn, - appendColumnToPath, splitColumnPath, FILE_SEPARATOR } from '../../experiments/columns/paths' import { ColumnType, Experiment, - isRunning, - MetricOrParamColumns + isRunning } from '../../experiments/webview/contract' -import { addToMapArray } from '../../util/map' import { TemplateOrder } from '../paths/collect' import { extendVegaSpec, @@ -67,231 +48,72 @@ import { unmergeConcatenatedFields } from '../multiSource/collect' import { StrokeDashEncoding } from '../multiSource/constants' -import { SelectedExperimentWithColor } from '../../experiments/model' +import { + ExperimentWithCheckpoints, + ExperimentWithDefinedCheckpoints, + SelectedExperimentWithColor +} from '../../experiments/model' import { Color } from '../../experiments/model/status/colors' -import { typedValueTreeEntries } from '../../experiments/columns/collect/metricsAndParams' - -type CheckpointPlotAccumulator = { - iterations: Record - plots: Map -} - -const collectFromMetricsFile = ( - acc: CheckpointPlotAccumulator, - name: string, - iteration: number, - key: string | undefined, - value: Value | ValueTree, - ancestors: string[] = [] -) => { - const pathArray = [...ancestors, key].filter(Boolean) as string[] - - if (isValueTree(value)) { - for (const [childKey, childValue] of typedValueTreeEntries(value)) { - collectFromMetricsFile( - acc, - name, - iteration, - childKey, - childValue, - pathArray - ) - } - return - } - - const path = appendColumnToPath(...pathArray) - - addToMapArray(acc.plots, path, { group: name, iteration, y: value }) -} - -type MetricsAndDetailsOrUndefined = - | { - checkpoint_parent: string | undefined - checkpoint_tip: string | undefined - metrics: MetricOrParamColumns | undefined - status: ExperimentStatus | undefined - } - | undefined - -const transformExperimentData = ( - experimentFieldsOrError: ExperimentFieldsOrError -): MetricsAndDetailsOrUndefined => { - const experimentFields = experimentFieldsOrError.data - if (!experimentFields) { - return - } - - const { checkpoint_tip, checkpoint_parent, status } = experimentFields - const { metrics } = extractColumns(experimentFields) - - return { checkpoint_parent, checkpoint_tip, metrics, status } -} - -type ValidData = { - checkpoint_parent: string - checkpoint_tip: string - metrics: MetricOrParamColumns - status: ExperimentStatus -} -const isValid = (data: MetricsAndDetailsOrUndefined): data is ValidData => - !!(data?.checkpoint_tip && data?.checkpoint_parent && data?.metrics) +export const getCustomPlotId = (metric: string, param = CHECKPOINTS_PARAM) => + `custom-${metric}-${param}` -const collectFromMetrics = ( - acc: CheckpointPlotAccumulator, - experimentName: string, - iteration: number, - metrics: MetricOrParamColumns -) => { - for (const file of Object.keys(metrics)) { - collectFromMetricsFile( - acc, - experimentName, - iteration, - undefined, - metrics[file], - [file] - ) - } -} +const getValueFromColumn = ( + path: string, + experiment: ExperimentWithCheckpoints +) => get(experiment, splitColumnPath(path)) as number | undefined -const getLastIteration = ( - acc: CheckpointPlotAccumulator, - checkpointParent: string -): number => acc.iterations[checkpointParent] || 0 - -const collectIteration = ( - acc: CheckpointPlotAccumulator, - sha: string, - checkpointParent: string -): number => { - const iteration = getLastIteration(acc, checkpointParent) + 1 - acc.iterations[sha] = iteration - return iteration -} +const isExperimentWithDefinedCheckpoints = ( + experiment: ExperimentWithCheckpoints +): experiment is ExperimentWithDefinedCheckpoints => !!experiment.checkpoints -const linkModified = ( - acc: CheckpointPlotAccumulator, - experimentName: string, - checkpointTip: string, - checkpointParent: string, - parent: ExperimentFieldsOrError | undefined +const collectCheckpointValuesFromExperiment = ( + values: CheckpointPlotValues, + exp: ExperimentWithDefinedCheckpoints, + metricPath: string ) => { - if (!parent) { - return - } + const group = exp.name || exp.label + const maxEpoch = exp.checkpoints.length + 1 - const parentData = transformExperimentData(parent) - if (!isValid(parentData) || parentData.checkpoint_tip === checkpointTip) { - return + const metricValue = getValueFromColumn(metricPath, exp) + if (metricValue !== undefined) { + values.push({ group, iteration: maxEpoch, y: metricValue }) } - const lastIteration = getLastIteration(acc, checkpointParent) - collectFromMetrics(acc, experimentName, lastIteration, parentData.metrics) -} - -const collectFromExperimentsObject = ( - acc: CheckpointPlotAccumulator, - experimentsObject: { [sha: string]: ExperimentFieldsOrError } -) => { - for (const [sha, experimentData] of Object.entries( - experimentsObject - ).reverse()) { - const data = transformExperimentData(experimentData) - - if (!isValid(data)) { - continue + for (const [ind, checkpoint] of exp.checkpoints.entries()) { + const metricValue = getValueFromColumn(metricPath, checkpoint) + if (metricValue !== undefined) { + values.push({ group, iteration: maxEpoch - ind - 1, y: metricValue }) } - const { - checkpoint_tip: checkpointTip, - checkpoint_parent: checkpointParent, - metrics - } = data - - const experimentName = experimentsObject[checkpointTip].data?.name - if (!experimentName) { - continue - } - - linkModified( - acc, - experimentName, - checkpointTip, - checkpointParent, - experimentsObject[checkpointParent] - ) - - const iteration = collectIteration(acc, sha, checkpointParent) - collectFromMetrics(acc, experimentName, iteration, metrics) } } -export const getCustomPlotId = (metric: string, param = CHECKPOINTS_PARAM) => - `custom-${metric}-${param}` - -export const collectCustomCheckpointPlots = ( - data: ExperimentsOutput -): CustomCheckpointPlots => { - const acc = { - iterations: {}, - plots: new Map() - } - - for (const { baseline, ...experimentsObject } of Object.values( - omit(data, EXPERIMENT_WORKSPACE_ID) - )) { - const commit = transformExperimentData(baseline) - - if (commit) { - collectFromExperimentsObject(acc, experimentsObject) - } - } - - const plotsData: CustomCheckpointPlots = {} - if (acc.plots.size === 0) { - return plotsData - } - - for (const [key, value] of acc.plots.entries()) { - const decodedMetric = decodeColumn(key) - plotsData[decodedMetric] = { - id: getCustomPlotId(decodedMetric), - metric: decodedMetric, - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: value +const getCheckpointValues = ( + experiments: ExperimentWithCheckpoints[], + metricPath: string +): CheckpointPlotValues => { + const values: CheckpointPlotValues = [] + for (const experiment of experiments) { + if (isExperimentWithDefinedCheckpoints(experiment)) { + collectCheckpointValuesFromExperiment(values, experiment, metricPath) } } - - return plotsData + return values } -const collectMetricVsParamPlot = ( - metric: string, - param: string, - experiments: Experiment[] -): MetricVsParamPlot => { - const splitUpMetricPath = splitColumnPath( - getFullValuePath(ColumnType.METRICS, metric, FILE_SEPARATOR) - ) - const splitUpParamPath = splitColumnPath( - getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) - ) - const plotData: MetricVsParamPlot = { - id: getCustomPlotId(metric, param), - metric, - param, - type: CustomPlotType.METRIC_VS_PARAM, - values: [] - } +const getMetricVsParamValues = ( + experiments: ExperimentWithCheckpoints[], + metricPath: string, + paramPath: string +): MetricVsParamPlotValues => { + const values: MetricVsParamPlotValues = [] for (const experiment of experiments) { - const metricValue = get(experiment, splitUpMetricPath) as number | undefined - const paramValue = get(experiment, splitUpParamPath) as number | undefined + const metricValue = getValueFromColumn(metricPath, experiment) + const paramValue = getValueFromColumn(paramPath, experiment) if (metricValue !== undefined && paramValue !== undefined) { - plotData.values.push({ + values.push({ expName: experiment.name || experiment.label, metric: metricValue, param: paramValue @@ -299,44 +121,78 @@ const collectMetricVsParamPlot = ( } } - return plotData + return values } -export const collectCustomPlots = ( - plotsOrderValues: CustomPlotsOrderValue[], - checkpointPlots: CustomCheckpointPlots, - experiments: Experiment[] -): CustomPlot[] => { - return plotsOrderValues - .map((plotOrderValue): CustomPlot => { - if (isCheckpointValue(plotOrderValue.type)) { - const { metric } = plotOrderValue - return checkpointPlots[metric] - } - const { metric, param } = plotOrderValue - return collectMetricVsParamPlot(metric, param, experiments) - }) - .filter(Boolean) +const getCustomPlotData = ( + orderValue: CustomPlotsOrderValue, + experiments: ExperimentWithCheckpoints[], + selectedRevisions: string[] | undefined = [], + height: number, + nbItemsPerRow: number +): CustomPlotData => { + const { metric, param, type } = orderValue + const metricPath = getFullValuePath( + ColumnType.METRICS, + metric, + FILE_SEPARATOR + ) + + const paramPath = getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) + + const selectedExperiments = experiments.filter(({ name, label }) => + selectedRevisions.includes(name || label) + ) + + const values = isCheckpointValue(type) + ? getCheckpointValues(selectedExperiments, metricPath) + : getMetricVsParamValues(experiments, metricPath, paramPath) + + return { + id: getCustomPlotId(metric, param), + metric, + param, + type, + values, + yTitle: truncateVerticalTitle(metric, nbItemsPerRow, height) as string + } as CustomPlotData } -export const collectCustomPlotData = ( - plot: CustomPlot, - colors: ColorScale | undefined, - nbItemsPerRow: number, +export const collectCustomPlots = ({ + plotsOrderValues, + experiments, + hasCheckpoints, + selectedRevisions, + height, + nbItemsPerRow +}: { + plotsOrderValues: CustomPlotsOrderValue[] + experiments: ExperimentWithCheckpoints[] + hasCheckpoints: boolean + selectedRevisions: string[] | undefined height: number -): CustomPlotData => { - const selectedExperiments = colors?.domain - const filteredValues = isCheckpointPlot(plot) - ? plot.values.filter(value => - (selectedExperiments as string[]).includes(value.group) + nbItemsPerRow: number +}): CustomPlotData[] => { + const plots = [] + const shouldSkipCheckpointPlots = !hasCheckpoints || !selectedRevisions + + for (const value of plotsOrderValues) { + if (shouldSkipCheckpointPlots && isCheckpointValue(value.type)) { + continue + } + + plots.push( + getCustomPlotData( + value, + experiments, + selectedRevisions, + height, + nbItemsPerRow ) - : plot.values + ) + } - return { - ...plot, - values: filteredValues, - yTitle: truncateVerticalTitle(plot.metric, nbItemsPerRow, height) as string - } as CustomPlotData + return plots } type RevisionPathData = { [path: string]: Record[] } diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index e8a8a1bd05..118c51de26 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -11,16 +11,10 @@ import { collectCommitRevisionDetails, collectOverrideRevisionDetails, collectCustomPlots, - getCustomPlotId, - collectCustomCheckpointPlots, - collectCustomPlotData + getCustomPlotId } from './collect' import { getRevisionFirstThreeColumns } from './util' -import { - cleanupOldOrderValue, - CustomPlotsOrderValue, - isCheckpointPlot -} from './custom' +import { cleanupOldOrderValue, CustomPlotsOrderValue } from './custom' import { CheckpointPlot, ComparisonPlots, @@ -32,14 +26,11 @@ import { SectionCollapsed, CustomPlotData, CustomPlotsData, - CustomPlot, - ColorScale, DEFAULT_HEIGHT, DEFAULT_NB_ITEMS_PER_ROW, PlotHeight } from '../webview/contract' import { - ExperimentsOutput, EXPERIMENT_WORKSPACE_ID, PlotsOutputOrError } from '../../cli/dvc/contract' @@ -80,9 +71,6 @@ export class PlotsModel extends ModelWithPersistence { private multiSourceVariations: MultiSourceVariations = {} private multiSourceEncoding: MultiSourceEncoding = {} - private customCheckpointPlots?: CustomCheckpointPlots - private customPlots?: CustomPlot[] - constructor( dvcRoot: string, experiments: Experiments, @@ -105,9 +93,7 @@ export class PlotsModel extends ModelWithPersistence { this.customPlotsOrder = this.revive(PersistenceKey.PLOTS_CUSTOM_ORDER, []) } - public transformAndSetExperiments(data: ExperimentsOutput) { - this.recreateCustomPlots(data) - + public transformAndSetExperiments() { return this.removeStaleData() } @@ -124,7 +110,6 @@ export class PlotsModel extends ModelWithPersistence { collectTemplates(data), collectMultiSourceVariations(data, this.multiSourceVariations) ]) - this.recreateCustomPlots() this.comparisonData = { ...this.comparisonData, @@ -153,7 +138,13 @@ export class PlotsModel extends ModelWithPersistence { } public getCustomPlots(): CustomPlotsData | undefined { - if (!this.customPlots) { + const experimentsWithNoCommitData = this.experiments.hasCheckpoints() + ? this.experiments + .getExperimentsWithCheckpoints() + .filter(({ checkpoints }) => !!checkpoints) + : this.experiments.getExperiments() + + if (experimentsWithNoCommitData.length === 0) { return } @@ -162,32 +153,31 @@ export class PlotsModel extends ModelWithPersistence { .getSelectedExperiments() .map(({ displayColor, id: revision }) => ({ displayColor, revision })) ) + const height = this.getHeight(PlotsSection.CUSTOM_PLOTS) + const nbItemsPerRow = this.getNbItemsPerRowOrWidth( + PlotsSection.CUSTOM_PLOTS + ) + const plotsOrderValues = this.getCustomPlotsOrder() + + const plots: CustomPlotData[] = collectCustomPlots({ + experiments: experimentsWithNoCommitData, + hasCheckpoints: this.experiments.hasCheckpoints(), + height, + nbItemsPerRow, + plotsOrderValues, + selectedRevisions: colors?.domain + }) - return { - colors, - height: this.getHeight(PlotsSection.CUSTOM_PLOTS), - nbItemsPerRow: this.getNbItemsPerRowOrWidth(PlotsSection.CUSTOM_PLOTS), - plots: this.getCustomPlotsData(this.customPlots, colors) - } - } - - public recreateCustomPlots(data?: ExperimentsOutput) { - if (data) { - this.customCheckpointPlots = collectCustomCheckpointPlots(data) + if (plots.length === 0 && plotsOrderValues.length > 0) { + return } - const experiments = this.experiments.getExperiments() - - if (experiments.length === 0) { - this.customPlots = undefined - return + return { + colors, + height, + nbItemsPerRow, + plots } - const customPlots: CustomPlot[] = collectCustomPlots( - this.getCustomPlotsOrder(), - this.customCheckpointPlots || {}, - experiments - ) - this.customPlots = customPlots } public getCustomPlotsOrder() { @@ -198,7 +188,6 @@ export class PlotsModel extends ModelWithPersistence { public updateCustomPlotsOrder(plotsOrder: CustomPlotsOrderValue[]) { this.customPlotsOrder = plotsOrder - this.recreateCustomPlots() } public setCustomPlotsOrder(plotsOrder: CustomPlotsOrderValue[]) { @@ -461,28 +450,6 @@ export class PlotsModel extends ModelWithPersistence { return this.commitRevisions[label] || label } - private getCustomPlotsData( - plots: CustomPlot[], - colors: ColorScale | undefined - ): CustomPlotData[] { - const selectedExperimentsExist = !!colors - const filteredPlots: CustomPlotData[] = [] - for (const plot of plots) { - if (!selectedExperimentsExist && isCheckpointPlot(plot)) { - continue - } - filteredPlots.push( - collectCustomPlotData( - plot, - colors, - this.getNbItemsPerRowOrWidth(PlotsSection.CUSTOM_PLOTS), - this.getHeight(PlotsSection.CUSTOM_PLOTS) - ) - ) - } - return filteredPlots - } - private getSelectedComparisonPlots( paths: string[], selectedRevisions: string[] diff --git a/extension/src/plots/webview/contract.ts b/extension/src/plots/webview/contract.ts index 679fa40f3d..6dcadff4d2 100644 --- a/extension/src/plots/webview/contract.ts +++ b/extension/src/plots/webview/contract.ts @@ -91,21 +91,21 @@ export type CheckpointPlotValues = { export type ColorScale = { domain: string[]; range: Color[] } -export type CheckpointPlot = { +type CustomPlotBase = { id: string - values: CheckpointPlotValues metric: string param: string - type: CustomPlotType.CHECKPOINT } +export type CheckpointPlot = { + values: CheckpointPlotValues + type: CustomPlotType.CHECKPOINT +} & CustomPlotBase + export type MetricVsParamPlot = { - id: string values: MetricVsParamPlotValues - metric: string - param: string type: CustomPlotType.METRIC_VS_PARAM -} +} & CustomPlotBase export type CustomPlot = MetricVsParamPlot | CheckpointPlot diff --git a/extension/src/plots/webview/messages.ts b/extension/src/plots/webview/messages.ts index 85f724aca8..4c64602ab4 100644 --- a/extension/src/plots/webview/messages.ts +++ b/extension/src/plots/webview/messages.ts @@ -37,6 +37,7 @@ import { doesCustomPlotAlreadyExist, isCheckpointValue } from '../model/custom' +import { getCustomPlotId } from '../model/collect' export class WebviewMessages { private readonly paths: PathsModel @@ -282,20 +283,23 @@ export class WebviewMessages { } private setCustomPlotsOrder(plotIds: string[]) { - const customPlots = this.plots.getCustomPlots()?.plots - if (!customPlots) { - return - } + const customPlotsOrderWithId = this.plots + .getCustomPlotsOrder() + .map(value => ({ + ...value, + id: getCustomPlotId(value.metric, value.param) + })) const newOrder: CustomPlotsOrderValue[] = reorderObjectList( plotIds, - customPlots, + customPlotsOrderWithId, 'id' ).map(({ metric, param, type }) => ({ metric, param, type })) + this.plots.setCustomPlotsOrder(newOrder) this.sendCustomPlotsAndEvent(EventName.VIEWS_REORDER_PLOTS_CUSTOM) } diff --git a/extension/src/test/fixtures/expShow/base/customPlots.ts b/extension/src/test/fixtures/expShow/base/customPlots.ts index f13bdaf77c..37536b2a2f 100644 --- a/extension/src/test/fixtures/expShow/base/customPlots.ts +++ b/extension/src/test/fixtures/expShow/base/customPlots.ts @@ -1,5 +1,5 @@ +import { ExperimentWithCheckpoints } from '../../../../experiments/model' import { copyOriginalColors } from '../../../../experiments/model/status/colors' -import { CustomCheckpointPlots } from '../../../../plots/model' import { CHECKPOINTS_PARAM, CustomPlotsOrderValue @@ -34,88 +34,158 @@ export const customPlotsOrderFixture: CustomPlotsOrderValue[] = [ } ] -export const checkpointPlotsFixture: CustomCheckpointPlots = { - 'summary.json:loss': { - id: 'custom-summary.json:loss-epoch', - metric: 'summary.json:loss', - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: [ - { group: 'exp-83425', iteration: 1, y: 1.9896177053451538 }, - { group: 'exp-83425', iteration: 2, y: 1.9329891204833984 }, - { group: 'exp-83425', iteration: 3, y: 1.8798457384109497 }, - { group: 'exp-83425', iteration: 4, y: 1.8261293172836304 }, - { group: 'exp-83425', iteration: 5, y: 1.775016188621521 }, - { group: 'exp-83425', iteration: 6, y: 1.775016188621521 }, - { group: 'test-branch', iteration: 1, y: 1.9882521629333496 }, - { group: 'test-branch', iteration: 2, y: 1.9293040037155151 }, - { group: 'test-branch', iteration: 3, y: 1.9293040037155151 }, - { group: 'exp-e7a67', iteration: 1, y: 2.020392894744873 }, - { group: 'exp-e7a67', iteration: 2, y: 2.0205044746398926 }, - { group: 'exp-e7a67', iteration: 3, y: 2.0205044746398926 } - ] - }, - 'summary.json:accuracy': { - id: 'custom-summary.json:accuracy-epoch', - metric: 'summary.json:accuracy', - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: [ - { group: 'exp-83425', iteration: 1, y: 0.40904998779296875 }, - { group: 'exp-83425', iteration: 2, y: 0.46094998717308044 }, - { group: 'exp-83425', iteration: 3, y: 0.5113166570663452 }, - { group: 'exp-83425', iteration: 4, y: 0.557449996471405 }, - { group: 'exp-83425', iteration: 5, y: 0.5926499962806702 }, - { group: 'exp-83425', iteration: 6, y: 0.5926499962806702 }, - { group: 'test-branch', iteration: 1, y: 0.4083833396434784 }, - { group: 'test-branch', iteration: 2, y: 0.4668000042438507 }, - { group: 'test-branch', iteration: 3, y: 0.4668000042438507 }, - { group: 'exp-e7a67', iteration: 1, y: 0.3723166584968567 }, - { group: 'exp-e7a67', iteration: 2, y: 0.3724166750907898 }, - { group: 'exp-e7a67', iteration: 3, y: 0.3724166750907898 } +export const experimentsWithCheckpoints: ExperimentWithCheckpoints[] = [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.3724166750907898, + loss: 2.0205044746398926 + } + }, + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } }, + checkpoints: [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.3724166750907898, + loss: 2.0205044746398926 + } + }, + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.3723166584968567, + loss: 2.020392894744873 + } + }, + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } + } ] }, - 'summary.json:val_loss': { - id: 'custom-summary.json:val_loss-epoch', - metric: 'summary.json:val_loss', - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: [ - { group: 'exp-83425', iteration: 1, y: 1.9391471147537231 }, - { group: 'exp-83425', iteration: 2, y: 1.8825950622558594 }, - { group: 'exp-83425', iteration: 3, y: 1.827923059463501 }, - { group: 'exp-83425', iteration: 4, y: 1.7749212980270386 }, - { group: 'exp-83425', iteration: 5, y: 1.7233840227127075 }, - { group: 'exp-83425', iteration: 6, y: 1.7233840227127075 }, - { group: 'test-branch', iteration: 1, y: 1.9363881349563599 }, - { group: 'test-branch', iteration: 2, y: 1.8770883083343506 }, - { group: 'test-branch', iteration: 3, y: 1.8770883083343506 }, - { group: 'exp-e7a67', iteration: 1, y: 1.9979370832443237 }, - { group: 'exp-e7a67', iteration: 2, y: 1.9979370832443237 }, - { group: 'exp-e7a67', iteration: 3, y: 1.9979370832443237 } + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.4668000042438507, + loss: 1.9293040037155151 + } + }, + name: 'test-branch', + params: { 'params.yaml': { dropout: 0.122, epochs: 2 } }, + checkpoints: [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.4668000042438507, + loss: 1.9293040037155151 + } + }, + name: 'test-branch', + params: { 'params.yaml': { dropout: 0.122, epochs: 2 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.4083833396434784, + loss: 1.9882521629333496 + } + }, + name: 'test-branch', + params: { 'params.yaml': { dropout: 0.122, epochs: 2 } } + } ] }, - 'summary.json:val_accuracy': { - id: 'custom-summary.json:val_accuracy-epoch', - metric: 'summary.json:val_accuracy', - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: [ - { group: 'exp-83425', iteration: 1, y: 0.49399998784065247 }, - { group: 'exp-83425', iteration: 2, y: 0.5550000071525574 }, - { group: 'exp-83425', iteration: 3, y: 0.6035000085830688 }, - { group: 'exp-83425', iteration: 4, y: 0.6414999961853027 }, - { group: 'exp-83425', iteration: 5, y: 0.6704000234603882 }, - { group: 'exp-83425', iteration: 6, y: 0.6704000234603882 }, - { group: 'test-branch', iteration: 1, y: 0.4970000088214874 }, - { group: 'test-branch', iteration: 2, y: 0.5608000159263611 }, - { group: 'test-branch', iteration: 3, y: 0.5608000159263611 }, - { group: 'exp-e7a67', iteration: 1, y: 0.4277999997138977 }, - { group: 'exp-e7a67', iteration: 2, y: 0.4277999997138977 }, - { group: 'exp-e7a67', iteration: 3, y: 0.4277999997138977 } + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.5926499962806702, + loss: 1.775016188621521 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } }, + checkpoints: [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.5926499962806702, + loss: 1.775016188621521 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.557449996471405, + loss: 1.8261293172836304 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.5113166570663452, + loss: 1.8798457384109497 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.46094998717308044, + loss: 1.9329891204833984 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.40904998779296875, + loss: 1.9896177053451538 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + } ] } -} +] const colors = copyOriginalColors() @@ -178,18 +248,18 @@ const data: CustomPlotsData = { metric: 'summary.json:loss', param: CHECKPOINTS_PARAM, values: [ - { group: 'exp-83425', iteration: 1, y: 1.9896177053451538 }, - { group: 'exp-83425', iteration: 2, y: 1.9329891204833984 }, - { group: 'exp-83425', iteration: 3, y: 1.8798457384109497 }, - { group: 'exp-83425', iteration: 4, y: 1.8261293172836304 }, - { group: 'exp-83425', iteration: 5, y: 1.775016188621521 }, - { group: 'exp-83425', iteration: 6, y: 1.775016188621521 }, - { group: 'test-branch', iteration: 1, y: 1.9882521629333496 }, - { group: 'test-branch', iteration: 2, y: 1.9293040037155151 }, - { group: 'test-branch', iteration: 3, y: 1.9293040037155151 }, - { group: 'exp-e7a67', iteration: 1, y: 2.020392894744873 }, + { group: 'exp-e7a67', iteration: 3, y: 2.0205044746398926 }, { group: 'exp-e7a67', iteration: 2, y: 2.0205044746398926 }, - { group: 'exp-e7a67', iteration: 3, y: 2.0205044746398926 } + { group: 'exp-e7a67', iteration: 1, y: 2.020392894744873 }, + { group: 'test-branch', iteration: 3, y: 1.9293040037155151 }, + { group: 'test-branch', iteration: 2, y: 1.9293040037155151 }, + { group: 'test-branch', iteration: 1, y: 1.9882521629333496 }, + { group: 'exp-83425', iteration: 6, y: 1.775016188621521 }, + { group: 'exp-83425', iteration: 5, y: 1.775016188621521 }, + { group: 'exp-83425', iteration: 4, y: 1.8261293172836304 }, + { group: 'exp-83425', iteration: 3, y: 1.8798457384109497 }, + { group: 'exp-83425', iteration: 2, y: 1.9329891204833984 }, + { group: 'exp-83425', iteration: 1, y: 1.9896177053451538 } ], type: CustomPlotType.CHECKPOINT, yTitle: 'summary.json:loss' @@ -199,18 +269,18 @@ const data: CustomPlotsData = { metric: 'summary.json:accuracy', param: CHECKPOINTS_PARAM, values: [ - { group: 'exp-83425', iteration: 1, y: 0.40904998779296875 }, - { group: 'exp-83425', iteration: 2, y: 0.46094998717308044 }, - { group: 'exp-83425', iteration: 3, y: 0.5113166570663452 }, - { group: 'exp-83425', iteration: 4, y: 0.557449996471405 }, - { group: 'exp-83425', iteration: 5, y: 0.5926499962806702 }, - { group: 'exp-83425', iteration: 6, y: 0.5926499962806702 }, - { group: 'test-branch', iteration: 1, y: 0.4083833396434784 }, - { group: 'test-branch', iteration: 2, y: 0.4668000042438507 }, - { group: 'test-branch', iteration: 3, y: 0.4668000042438507 }, - { group: 'exp-e7a67', iteration: 1, y: 0.3723166584968567 }, + { group: 'exp-e7a67', iteration: 3, y: 0.3724166750907898 }, { group: 'exp-e7a67', iteration: 2, y: 0.3724166750907898 }, - { group: 'exp-e7a67', iteration: 3, y: 0.3724166750907898 } + { group: 'exp-e7a67', iteration: 1, y: 0.3723166584968567 }, + { group: 'test-branch', iteration: 3, y: 0.4668000042438507 }, + { group: 'test-branch', iteration: 2, y: 0.4668000042438507 }, + { group: 'test-branch', iteration: 1, y: 0.4083833396434784 }, + { group: 'exp-83425', iteration: 6, y: 0.5926499962806702 }, + { group: 'exp-83425', iteration: 5, y: 0.5926499962806702 }, + { group: 'exp-83425', iteration: 4, y: 0.557449996471405 }, + { group: 'exp-83425', iteration: 3, y: 0.5113166570663452 }, + { group: 'exp-83425', iteration: 2, y: 0.46094998717308044 }, + { group: 'exp-83425', iteration: 1, y: 0.40904998779296875 } ], type: CustomPlotType.CHECKPOINT, yTitle: 'summary.json:accuracy'