From 2db949de20ec6dd58f6a7fd94cb5d410f271ab95 Mon Sep 17 00:00:00 2001 From: julieg18 Date: Tue, 14 Mar 2023 14:46:06 -0500 Subject: [PATCH 01/12] WIP on consolidating custom data collection --- extension/src/experiments/index.ts | 4 + extension/src/plots/index.ts | 9 +- extension/src/plots/model/collect.test.ts | 55 +--- extension/src/plots/model/collect.ts | 255 ++++-------------- extension/src/plots/model/index.ts | 24 +- .../test/fixtures/expShow/base/customPlots.ts | 223 +++++++++------ 6 files changed, 216 insertions(+), 354 deletions(-) diff --git a/extension/src/experiments/index.ts b/extension/src/experiments/index.ts index b9c3744481..a0c1aaa50a 100644 --- a/extension/src/experiments/index.ts +++ b/extension/src/experiments/index.ts @@ -333,6 +333,10 @@ export class Experiments extends BaseRepository { return this.experiments.getExperimentCount() } + public getRowData() { + return this.experiments.getExperimentsWithCheckpoints() + } + public async selectExperiments() { const experiments = this.experiments.getExperimentsWithCheckpoints() diff --git a/extension/src/plots/index.ts b/extension/src/plots/index.ts index 97913382bd..1a1dff877c 100644 --- a/extension/src/plots/index.ts +++ b/extension/src/plots/index.ts @@ -13,7 +13,6 @@ import { Experiments } from '../experiments' import { Resource } from '../resourceLocator' import { InternalCommands } from '../commands/internal' import { definedAndNonEmpty } from '../util/array' -import { ExperimentsOutput } from '../cli/dvc/contract' import { TEMP_PLOTS_DIR } from '../cli/dvc/constants' import { removeDir } from '../fileSystem' import { Toast } from '../vscode/toast' @@ -173,7 +172,7 @@ export class Plots extends BaseRepository { waitForInitialExpData.dispose() this.data.setMetricFiles(data) this.setupExperimentsListener(experiments) - void this.initializeData(data) + void this.initializeData() } }) ) @@ -184,7 +183,7 @@ export class Plots extends BaseRepository { experiments.onDidChangeExperiments(async data => { if (data) { await Promise.all([ - this.plots.transformAndSetExperiments(data), + this.plots.transformAndSetExperiments(), this.data.setMetricFiles(data) ]) } @@ -200,8 +199,8 @@ export class Plots extends BaseRepository { ) } - private async initializeData(data: ExperimentsOutput) { - await this.plots.transformAndSetExperiments(data) + private async initializeData() { + await this.plots.transformAndSetExperiments() void this.data.managedUpdate() await Promise.all([ this.data.isReady(), diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index b5175eae4c..b212f2c37b 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -1,18 +1,16 @@ import { join } from 'path' -import omit from 'lodash.omit' import isEmpty from 'lodash.isempty' import { collectData, collectTemplates, collectOverrideRevisionDetails, collectCustomPlots, - collectCustomCheckpointPlots, collectCustomPlotData } from './collect' import plotsDiffFixture from '../../test/fixtures/plotsDiff/output' import customPlotsFixture, { customPlotsOrderFixture, - checkpointPlotsFixture + experimentsWithCheckpoints } from '../../test/fixtures/expShow/base/customPlots' import { ExperimentStatus, @@ -26,8 +24,6 @@ import { TemplatePlot } from '../webview/contract' import { getCLICommitId } from '../../test/fixtures/plotsDiff/util' -import expShowFixture from '../../test/fixtures/expShow/base/output' -import modifiedFixture from '../../test/fixtures/expShow/modified/output' import { SelectedExperimentWithColor } from '../../experiments/model' import { Experiment } from '../../experiments/webview/contract' @@ -57,7 +53,6 @@ describe('collectCustomPlots', () => { ) const data = collectCustomPlots( customPlotsOrderFixture, - checkpointPlotsFixture, [ { id: '12345', @@ -95,9 +90,9 @@ describe('collectCustomPlots', () => { name: 'exp-83425', params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } } - ] + ], + experimentsWithCheckpoints ) - expect(data).toStrictEqual(expectedOutput) }) }) @@ -215,50 +210,6 @@ describe('collectData', () => { }) }) -describe('collectCustomCheckpointPlotsData', () => { - it('should return the expected data from the test fixture', () => { - const data = collectCustomCheckpointPlots(expShowFixture) - - expect(data).toStrictEqual(checkpointPlotsFixture) - }) - - it('should provide a continuous series for a modified experiment', () => { - const data = collectCustomCheckpointPlots(modifiedFixture) - - for (const { values } of Object.values(data)) { - const initialExperiment = values.filter( - point => point.group === 'exp-908bd' - ) - const modifiedExperiment = values.find( - point => point.group === 'exp-01b3a' - ) - - const lastIterationInitial = initialExperiment?.slice(-1)[0] - const firstIterationModified = modifiedExperiment - - expect(lastIterationInitial).not.toStrictEqual(firstIterationModified) - expect(omit(lastIterationInitial, 'group')).toStrictEqual( - omit(firstIterationModified, 'group') - ) - - const baseExperiment = values.filter(point => point.group === 'exp-920fc') - const restartedExperiment = values.find( - point => point.group === 'exp-9bc1b' - ) - - const iterationRestartedFrom = baseExperiment?.slice(5)[0] - const firstIterationAfterRestart = restartedExperiment - - expect(iterationRestartedFrom).not.toStrictEqual( - firstIterationAfterRestart - ) - expect(omit(iterationRestartedFrom, 'group')).toStrictEqual( - omit(firstIterationAfterRestart, 'group') - ) - } - }) -}) - describe('collectTemplates', () => { it('should return the expected output from the test fixture', () => { const { content } = logsLossPlot diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 64b7beccbd..87ec8cd1f8 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -1,8 +1,6 @@ -import omit from 'lodash.omit' import get from 'lodash.get' import { TopLevelSpec } from 'vega-lite' import { VisualizationSpec } from 'react-vega' -import { CustomCheckpointPlots } from '.' import { getFullValuePath, CHECKPOINTS_PARAM, @@ -25,32 +23,19 @@ import { CustomPlotType, CustomPlot, MetricVsParamPlot, - CustomPlotData + CustomPlotData, + CheckpointPlot } from '../webview/contract' +import { EXPERIMENT_WORKSPACE_ID, PlotsOutput } from '../../cli/dvc/contract' import { - EXPERIMENT_WORKSPACE_ID, - ExperimentFieldsOrError, - ExperimentsOutput, - ExperimentStatus, - isValueTree, - PlotsOutput, - Value, - ValueTree -} from '../../cli/dvc/contract' -import { extractColumns } from '../../experiments/columns/extract' -import { - decodeColumn, - appendColumnToPath, splitColumnPath, FILE_SEPARATOR } from '../../experiments/columns/paths' import { ColumnType, Experiment, - isRunning, - MetricOrParamColumns + isRunning } from '../../experiments/webview/contract' -import { addToMapArray } from '../../util/map' import { TemplateOrder } from '../paths/collect' import { extendVegaSpec, @@ -67,204 +52,59 @@ import { unmergeConcatenatedFields } from '../multiSource/collect' import { StrokeDashEncoding } from '../multiSource/constants' -import { SelectedExperimentWithColor } from '../../experiments/model' +import { + ExperimentWithCheckpoints, + SelectedExperimentWithColor +} from '../../experiments/model' import { Color } from '../../experiments/model/status/colors' -import { typedValueTreeEntries } from '../../experiments/columns/collect/metricsAndParams' - -type CheckpointPlotAccumulator = { - iterations: Record - plots: Map -} - -const collectFromMetricsFile = ( - acc: CheckpointPlotAccumulator, - name: string, - iteration: number, - key: string | undefined, - value: Value | ValueTree, - ancestors: string[] = [] -) => { - const pathArray = [...ancestors, key].filter(Boolean) as string[] - - if (isValueTree(value)) { - for (const [childKey, childValue] of typedValueTreeEntries(value)) { - collectFromMetricsFile( - acc, - name, - iteration, - childKey, - childValue, - pathArray - ) - } - return - } - - const path = appendColumnToPath(...pathArray) - - addToMapArray(acc.plots, path, { group: name, iteration, y: value }) -} - -type MetricsAndDetailsOrUndefined = - | { - checkpoint_parent: string | undefined - checkpoint_tip: string | undefined - metrics: MetricOrParamColumns | undefined - status: ExperimentStatus | undefined - } - | undefined - -const transformExperimentData = ( - experimentFieldsOrError: ExperimentFieldsOrError -): MetricsAndDetailsOrUndefined => { - const experimentFields = experimentFieldsOrError.data - if (!experimentFields) { - return - } - const { checkpoint_tip, checkpoint_parent, status } = experimentFields - const { metrics } = extractColumns(experimentFields) - - return { checkpoint_parent, checkpoint_tip, metrics, status } -} - -type ValidData = { - checkpoint_parent: string - checkpoint_tip: string - metrics: MetricOrParamColumns - status: ExperimentStatus -} - -const isValid = (data: MetricsAndDetailsOrUndefined): data is ValidData => - !!(data?.checkpoint_tip && data?.checkpoint_parent && data?.metrics) - -const collectFromMetrics = ( - acc: CheckpointPlotAccumulator, - experimentName: string, - iteration: number, - metrics: MetricOrParamColumns -) => { - for (const file of Object.keys(metrics)) { - collectFromMetricsFile( - acc, - experimentName, - iteration, - undefined, - metrics[file], - [file] - ) - } -} - -const getLastIteration = ( - acc: CheckpointPlotAccumulator, - checkpointParent: string -): number => acc.iterations[checkpointParent] || 0 - -const collectIteration = ( - acc: CheckpointPlotAccumulator, - sha: string, - checkpointParent: string -): number => { - const iteration = getLastIteration(acc, checkpointParent) + 1 - acc.iterations[sha] = iteration - return iteration -} +export const getCustomPlotId = (metric: string, param = CHECKPOINTS_PARAM) => + `custom-${metric}-${param}` -const linkModified = ( - acc: CheckpointPlotAccumulator, - experimentName: string, - checkpointTip: string, - checkpointParent: string, - parent: ExperimentFieldsOrError | undefined +const getExperimentValues = ( + allValues: CheckpointPlotValues, + exp: ExperimentWithCheckpoints, + metric: string ) => { - if (!parent) { - return - } + const splitMetric = splitColumnPath( + getFullValuePath(ColumnType.METRICS, metric, FILE_SEPARATOR) + ) + const values = [] + const group = exp.name || exp.label + const expEpochLength = (exp.checkpoints as Experiment[]).length + 1 - const parentData = transformExperimentData(parent) - if (!isValid(parentData) || parentData.checkpoint_tip === checkpointTip) { - return + const y = get(exp, splitMetric) as number | undefined + if (y !== undefined) { + values.push({ group, iteration: expEpochLength, y }) } - const lastIteration = getLastIteration(acc, checkpointParent) - collectFromMetrics(acc, experimentName, lastIteration, parentData.metrics) -} - -const collectFromExperimentsObject = ( - acc: CheckpointPlotAccumulator, - experimentsObject: { [sha: string]: ExperimentFieldsOrError } -) => { - for (const [sha, experimentData] of Object.entries( - experimentsObject - ).reverse()) { - const data = transformExperimentData(experimentData) - - if (!isValid(data)) { - continue + for (const [ind, checkpoint] of (exp.checkpoints as Experiment[]).entries()) { + const y = get(checkpoint, splitMetric) as number | undefined + if (y !== undefined) { + values.push({ group, iteration: expEpochLength - ind - 1, y }) } - const { - checkpoint_tip: checkpointTip, - checkpoint_parent: checkpointParent, - metrics - } = data - - const experimentName = experimentsObject[checkpointTip].data?.name - if (!experimentName) { - continue - } - - linkModified( - acc, - experimentName, - checkpointTip, - checkpointParent, - experimentsObject[checkpointParent] - ) - - const iteration = collectIteration(acc, sha, checkpointParent) - collectFromMetrics(acc, experimentName, iteration, metrics) } + values.reverse() + allValues.push(...values) } -export const getCustomPlotId = (metric: string, param = CHECKPOINTS_PARAM) => - `custom-${metric}-${param}` - -export const collectCustomCheckpointPlots = ( - data: ExperimentsOutput -): CustomCheckpointPlots => { - const acc = { - iterations: {}, - plots: new Map() - } - - for (const { baseline, ...experimentsObject } of Object.values( - omit(data, EXPERIMENT_WORKSPACE_ID) - )) { - const commit = transformExperimentData(baseline) - - if (commit) { - collectFromExperimentsObject(acc, experimentsObject) - } - } - - const plotsData: CustomCheckpointPlots = {} - if (acc.plots.size === 0) { - return plotsData +const collectCheckpointPlot = ( + metric: string, + experiments: ExperimentWithCheckpoints[] +): CheckpointPlot => { + const plotData: CheckpointPlot = { + id: getCustomPlotId(metric), + metric, + param: CHECKPOINTS_PARAM, + type: CustomPlotType.CHECKPOINT, + values: [] } - - for (const [key, value] of acc.plots.entries()) { - const decodedMetric = decodeColumn(key) - plotsData[decodedMetric] = { - id: getCustomPlotId(decodedMetric), - metric: decodedMetric, - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: value - } + const fullValues: CheckpointPlotValues = [] + for (const experiment of experiments) { + getExperimentValues(fullValues, experiment, metric) } - - return plotsData + plotData.values = fullValues + return plotData } const collectMetricVsParamPlot = ( @@ -304,14 +144,15 @@ const collectMetricVsParamPlot = ( export const collectCustomPlots = ( plotsOrderValues: CustomPlotsOrderValue[], - checkpointPlots: CustomCheckpointPlots, - experiments: Experiment[] + experiments: Experiment[], + selectedExperiments: ExperimentWithCheckpoints[] ): CustomPlot[] => { return plotsOrderValues .map((plotOrderValue): CustomPlot => { if (isCheckpointValue(plotOrderValue.type)) { const { metric } = plotOrderValue - return checkpointPlots[metric] + + return collectCheckpointPlot(metric, selectedExperiments) } const { metric, param } = plotOrderValue return collectMetricVsParamPlot(metric, param, experiments) diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index 2fd2f22ad7..c54f9d5b28 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -12,7 +12,6 @@ import { collectOverrideRevisionDetails, collectCustomPlots, getCustomPlotId, - collectCustomCheckpointPlots, collectCustomPlotData } from './collect' import { getRevisionFirstThreeColumns } from './util' @@ -39,7 +38,6 @@ import { PlotHeight } from '../webview/contract' import { - ExperimentsOutput, EXPERIMENT_WORKSPACE_ID, PlotsOutputOrError } from '../../cli/dvc/contract' @@ -80,7 +78,6 @@ export class PlotsModel extends ModelWithPersistence { private multiSourceVariations: MultiSourceVariations = {} private multiSourceEncoding: MultiSourceEncoding = {} - private customCheckpointPlots?: CustomCheckpointPlots private customPlots?: CustomPlot[] constructor( @@ -105,8 +102,8 @@ export class PlotsModel extends ModelWithPersistence { this.customPlotsOrder = this.revive(PersistenceKey.PLOTS_CUSTOM_ORDER, []) } - public transformAndSetExperiments(data: ExperimentsOutput) { - this.recreateCustomPlots(data) + public transformAndSetExperiments() { + this.recreateCustomPlots() return this.removeStaleData() } @@ -171,21 +168,20 @@ export class PlotsModel extends ModelWithPersistence { } } - public recreateCustomPlots(data?: ExperimentsOutput) { - if (data) { - this.customCheckpointPlots = collectCustomCheckpointPlots(data) - } - - const experiments = this.experiments.getExperiments() + public recreateCustomPlots() { + const allExperiments = this.experiments.getExperiments() + const experimentsWithCheckpoints = this.experiments + .getRowData() + .filter(({ checkpoints }) => !!checkpoints) - if (experiments.length === 0) { + if (allExperiments.length === 0) { this.customPlots = undefined return } const customPlots: CustomPlot[] = collectCustomPlots( this.getCustomPlotsOrder(), - this.customCheckpointPlots || {}, - experiments + experimentsWithCheckpoints, + this.experiments.getRowData().filter(({ checkpoints }) => !!checkpoints) ) this.customPlots = customPlots } diff --git a/extension/src/test/fixtures/expShow/base/customPlots.ts b/extension/src/test/fixtures/expShow/base/customPlots.ts index f13bdaf77c..4a1b85ce14 100644 --- a/extension/src/test/fixtures/expShow/base/customPlots.ts +++ b/extension/src/test/fixtures/expShow/base/customPlots.ts @@ -1,3 +1,4 @@ +import { ExperimentWithCheckpoints } from '../../../../experiments/model' import { copyOriginalColors } from '../../../../experiments/model/status/colors' import { CustomCheckpointPlots } from '../../../../plots/model' import { @@ -34,88 +35,158 @@ export const customPlotsOrderFixture: CustomPlotsOrderValue[] = [ } ] -export const checkpointPlotsFixture: CustomCheckpointPlots = { - 'summary.json:loss': { - id: 'custom-summary.json:loss-epoch', - metric: 'summary.json:loss', - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: [ - { group: 'exp-83425', iteration: 1, y: 1.9896177053451538 }, - { group: 'exp-83425', iteration: 2, y: 1.9329891204833984 }, - { group: 'exp-83425', iteration: 3, y: 1.8798457384109497 }, - { group: 'exp-83425', iteration: 4, y: 1.8261293172836304 }, - { group: 'exp-83425', iteration: 5, y: 1.775016188621521 }, - { group: 'exp-83425', iteration: 6, y: 1.775016188621521 }, - { group: 'test-branch', iteration: 1, y: 1.9882521629333496 }, - { group: 'test-branch', iteration: 2, y: 1.9293040037155151 }, - { group: 'test-branch', iteration: 3, y: 1.9293040037155151 }, - { group: 'exp-e7a67', iteration: 1, y: 2.020392894744873 }, - { group: 'exp-e7a67', iteration: 2, y: 2.0205044746398926 }, - { group: 'exp-e7a67', iteration: 3, y: 2.0205044746398926 } - ] - }, - 'summary.json:accuracy': { - id: 'custom-summary.json:accuracy-epoch', - metric: 'summary.json:accuracy', - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: [ - { group: 'exp-83425', iteration: 1, y: 0.40904998779296875 }, - { group: 'exp-83425', iteration: 2, y: 0.46094998717308044 }, - { group: 'exp-83425', iteration: 3, y: 0.5113166570663452 }, - { group: 'exp-83425', iteration: 4, y: 0.557449996471405 }, - { group: 'exp-83425', iteration: 5, y: 0.5926499962806702 }, - { group: 'exp-83425', iteration: 6, y: 0.5926499962806702 }, - { group: 'test-branch', iteration: 1, y: 0.4083833396434784 }, - { group: 'test-branch', iteration: 2, y: 0.4668000042438507 }, - { group: 'test-branch', iteration: 3, y: 0.4668000042438507 }, - { group: 'exp-e7a67', iteration: 1, y: 0.3723166584968567 }, - { group: 'exp-e7a67', iteration: 2, y: 0.3724166750907898 }, - { group: 'exp-e7a67', iteration: 3, y: 0.3724166750907898 } +export const experimentsWithCheckpoints: ExperimentWithCheckpoints[] = [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.5926499962806702, + loss: 1.775016188621521 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } }, + checkpoints: [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.5926499962806702, + loss: 1.775016188621521 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.557449996471405, + loss: 1.8261293172836304 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.5113166570663452, + loss: 1.8798457384109497 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.46094998717308044, + loss: 1.9329891204833984 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.40904998779296875, + loss: 1.9896177053451538 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + } ] }, - 'summary.json:val_loss': { - id: 'custom-summary.json:val_loss-epoch', - metric: 'summary.json:val_loss', - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: [ - { group: 'exp-83425', iteration: 1, y: 1.9391471147537231 }, - { group: 'exp-83425', iteration: 2, y: 1.8825950622558594 }, - { group: 'exp-83425', iteration: 3, y: 1.827923059463501 }, - { group: 'exp-83425', iteration: 4, y: 1.7749212980270386 }, - { group: 'exp-83425', iteration: 5, y: 1.7233840227127075 }, - { group: 'exp-83425', iteration: 6, y: 1.7233840227127075 }, - { group: 'test-branch', iteration: 1, y: 1.9363881349563599 }, - { group: 'test-branch', iteration: 2, y: 1.8770883083343506 }, - { group: 'test-branch', iteration: 3, y: 1.8770883083343506 }, - { group: 'exp-e7a67', iteration: 1, y: 1.9979370832443237 }, - { group: 'exp-e7a67', iteration: 2, y: 1.9979370832443237 }, - { group: 'exp-e7a67', iteration: 3, y: 1.9979370832443237 } + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.4668000042438507, + loss: 1.9293040037155151 + } + }, + name: 'test-branch', + params: { 'params.yaml': { dropout: 0.122, epochs: 2 } }, + checkpoints: [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.4668000042438507, + loss: 1.9293040037155151 + } + }, + name: 'test-branch', + params: { 'params.yaml': { dropout: 0.122, epochs: 2 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.4083833396434784, + loss: 1.9882521629333496 + } + }, + name: 'test-branch', + params: { 'params.yaml': { dropout: 0.122, epochs: 2 } } + } ] }, - 'summary.json:val_accuracy': { - id: 'custom-summary.json:val_accuracy-epoch', - metric: 'summary.json:val_accuracy', - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: [ - { group: 'exp-83425', iteration: 1, y: 0.49399998784065247 }, - { group: 'exp-83425', iteration: 2, y: 0.5550000071525574 }, - { group: 'exp-83425', iteration: 3, y: 0.6035000085830688 }, - { group: 'exp-83425', iteration: 4, y: 0.6414999961853027 }, - { group: 'exp-83425', iteration: 5, y: 0.6704000234603882 }, - { group: 'exp-83425', iteration: 6, y: 0.6704000234603882 }, - { group: 'test-branch', iteration: 1, y: 0.4970000088214874 }, - { group: 'test-branch', iteration: 2, y: 0.5608000159263611 }, - { group: 'test-branch', iteration: 3, y: 0.5608000159263611 }, - { group: 'exp-e7a67', iteration: 1, y: 0.4277999997138977 }, - { group: 'exp-e7a67', iteration: 2, y: 0.4277999997138977 }, - { group: 'exp-e7a67', iteration: 3, y: 0.4277999997138977 } + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.3724166750907898, + loss: 2.0205044746398926 + } + }, + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } }, + checkpoints: [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.3724166750907898, + loss: 2.0205044746398926 + } + }, + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.3723166584968567, + loss: 2.020392894744873 + } + }, + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } + } ] } -} +] const colors = copyOriginalColors() From 981091f8aae806fc0894def9ae43f5971d7bf62d Mon Sep 17 00:00:00 2001 From: julieg18 Date: Tue, 14 Mar 2023 18:58:05 -0500 Subject: [PATCH 02/12] do some more consolidation and get working tests --- extension/src/experiments/model/index.ts | 4 + extension/src/plots/model/collect.test.ts | 40 +---- extension/src/plots/model/collect.ts | 43 ++--- extension/src/plots/model/index.ts | 6 +- .../test/fixtures/expShow/base/customPlots.ts | 165 +++++++++--------- 5 files changed, 114 insertions(+), 144 deletions(-) diff --git a/extension/src/experiments/model/index.ts b/extension/src/experiments/model/index.ts index 0a29502689..1a7a386086 100644 --- a/extension/src/experiments/model/index.ts +++ b/extension/src/experiments/model/index.ts @@ -60,6 +60,10 @@ export type ExperimentWithCheckpoints = Experiment & { checkpoints?: Experiment[] } +export type ExperimentWithDefinedCheckpoints = Experiment & { + checkpoints: Experiment[] +} + export enum ExperimentType { WORKSPACE = 'workspace', COMMIT = 'commit', diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index b212f2c37b..ce53fa017e 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -30,7 +30,7 @@ import { Experiment } from '../../experiments/webview/contract' const logsLossPath = join('logs', 'loss.tsv') const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot - +// missing tests here (collect function are still a WIP) const getCustomPlotFromCustomPlotData = ({ id, metric, @@ -53,44 +53,6 @@ describe('collectCustomPlots', () => { ) const data = collectCustomPlots( customPlotsOrderFixture, - [ - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.3724166750907898, - loss: 2.0205044746398926 - } - }, - name: 'exp-e7a67', - params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } - }, - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.4668000042438507, - loss: 1.9293040037155151 - } - }, - name: 'test-branch', - params: { 'params.yaml': { dropout: 0.122, epochs: 2 } } - }, - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.5926499962806702, - loss: 1.775016188621521 - } - }, - name: 'exp-83425', - params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } - } - ], experimentsWithCheckpoints ) expect(data).toStrictEqual(expectedOutput) diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 87ec8cd1f8..2274e81afd 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -54,7 +54,8 @@ import { import { StrokeDashEncoding } from '../multiSource/constants' import { ExperimentWithCheckpoints, - SelectedExperimentWithColor + SelectedExperimentWithColor, + ExperimentWithDefinedCheckpoints } from '../../experiments/model' import { Color } from '../../experiments/model/status/colors' @@ -62,49 +63,52 @@ export const getCustomPlotId = (metric: string, param = CHECKPOINTS_PARAM) => `custom-${metric}-${param}` const getExperimentValues = ( - allValues: CheckpointPlotValues, - exp: ExperimentWithCheckpoints, + values: CheckpointPlotValues, + exp: ExperimentWithDefinedCheckpoints, metric: string ) => { const splitMetric = splitColumnPath( getFullValuePath(ColumnType.METRICS, metric, FILE_SEPARATOR) ) - const values = [] const group = exp.name || exp.label - const expEpochLength = (exp.checkpoints as Experiment[]).length + 1 + const expEpochLength = exp.checkpoints.length + 1 const y = get(exp, splitMetric) as number | undefined if (y !== undefined) { values.push({ group, iteration: expEpochLength, y }) } - for (const [ind, checkpoint] of (exp.checkpoints as Experiment[]).entries()) { + for (const [ind, checkpoint] of exp.checkpoints.entries()) { const y = get(checkpoint, splitMetric) as number | undefined if (y !== undefined) { values.push({ group, iteration: expEpochLength - ind - 1, y }) } } - values.reverse() - allValues.push(...values) } - +// I belive we can further combine these +// leaving the only separate thing being the +// creation of the values const collectCheckpointPlot = ( metric: string, experiments: ExperimentWithCheckpoints[] ): CheckpointPlot => { - const plotData: CheckpointPlot = { + const fullValues: CheckpointPlotValues = [] + for (const experiment of experiments) { + if (experiment.checkpoints) { + getExperimentValues( + fullValues, + experiment as ExperimentWithDefinedCheckpoints, + metric + ) + } + } + return { id: getCustomPlotId(metric), metric, param: CHECKPOINTS_PARAM, type: CustomPlotType.CHECKPOINT, - values: [] - } - const fullValues: CheckpointPlotValues = [] - for (const experiment of experiments) { - getExperimentValues(fullValues, experiment, metric) + values: fullValues } - plotData.values = fullValues - return plotData } const collectMetricVsParamPlot = ( @@ -144,15 +148,14 @@ const collectMetricVsParamPlot = ( export const collectCustomPlots = ( plotsOrderValues: CustomPlotsOrderValue[], - experiments: Experiment[], - selectedExperiments: ExperimentWithCheckpoints[] + experiments: Experiment[] ): CustomPlot[] => { return plotsOrderValues .map((plotOrderValue): CustomPlot => { if (isCheckpointValue(plotOrderValue.type)) { const { metric } = plotOrderValue - return collectCheckpointPlot(metric, selectedExperiments) + return collectCheckpointPlot(metric, experiments) } const { metric, param } = plotOrderValue return collectMetricVsParamPlot(metric, param, experiments) diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index c54f9d5b28..eb847dbcd5 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -180,8 +180,7 @@ export class PlotsModel extends ModelWithPersistence { } const customPlots: CustomPlot[] = collectCustomPlots( this.getCustomPlotsOrder(), - experimentsWithCheckpoints, - this.experiments.getRowData().filter(({ checkpoints }) => !!checkpoints) + experimentsWithCheckpoints ) this.customPlots = customPlots } @@ -461,6 +460,9 @@ export class PlotsModel extends ModelWithPersistence { const selectedExperimentsExist = !!colors const filteredPlots: CustomPlotData[] = [] for (const plot of plots) { + // i wonder if we could filter this when + // were doing the collection part now + // instead of after... if (!selectedExperimentsExist && isCheckpointPlot(plot)) { continue } diff --git a/extension/src/test/fixtures/expShow/base/customPlots.ts b/extension/src/test/fixtures/expShow/base/customPlots.ts index 4a1b85ce14..37536b2a2f 100644 --- a/extension/src/test/fixtures/expShow/base/customPlots.ts +++ b/extension/src/test/fixtures/expShow/base/customPlots.ts @@ -1,6 +1,5 @@ import { ExperimentWithCheckpoints } from '../../../../experiments/model' import { copyOriginalColors } from '../../../../experiments/model/status/colors' -import { CustomCheckpointPlots } from '../../../../plots/model' import { CHECKPOINTS_PARAM, CustomPlotsOrderValue @@ -41,72 +40,36 @@ export const experimentsWithCheckpoints: ExperimentWithCheckpoints[] = [ label: '123', metrics: { 'summary.json': { - accuracy: 0.5926499962806702, - loss: 1.775016188621521 + accuracy: 0.3724166750907898, + loss: 2.0205044746398926 } }, - name: 'exp-83425', - params: { 'params.yaml': { dropout: 0.124, epochs: 5 } }, + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } }, checkpoints: [ { id: '12345', label: '123', metrics: { 'summary.json': { - accuracy: 0.5926499962806702, - loss: 1.775016188621521 - } - }, - name: 'exp-83425', - params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } - }, - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.557449996471405, - loss: 1.8261293172836304 - } - }, - name: 'exp-83425', - params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } - }, - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.5113166570663452, - loss: 1.8798457384109497 - } - }, - name: 'exp-83425', - params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } - }, - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.46094998717308044, - loss: 1.9329891204833984 + accuracy: 0.3724166750907898, + loss: 2.0205044746398926 } }, - name: 'exp-83425', - params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } }, { id: '12345', label: '123', metrics: { 'summary.json': { - accuracy: 0.40904998779296875, - loss: 1.9896177053451538 + accuracy: 0.3723166584968567, + loss: 2.020392894744873 } }, - name: 'exp-83425', - params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } } ] }, @@ -153,36 +116,72 @@ export const experimentsWithCheckpoints: ExperimentWithCheckpoints[] = [ label: '123', metrics: { 'summary.json': { - accuracy: 0.3724166750907898, - loss: 2.0205044746398926 + accuracy: 0.5926499962806702, + loss: 1.775016188621521 } }, - name: 'exp-e7a67', - params: { 'params.yaml': { dropout: 0.15, epochs: 2 } }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } }, checkpoints: [ { id: '12345', label: '123', metrics: { 'summary.json': { - accuracy: 0.3724166750907898, - loss: 2.0205044746398926 + accuracy: 0.5926499962806702, + loss: 1.775016188621521 } }, - name: 'exp-e7a67', - params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } }, { id: '12345', label: '123', metrics: { 'summary.json': { - accuracy: 0.3723166584968567, - loss: 2.020392894744873 + accuracy: 0.557449996471405, + loss: 1.8261293172836304 } }, - name: 'exp-e7a67', - params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.5113166570663452, + loss: 1.8798457384109497 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.46094998717308044, + loss: 1.9329891204833984 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.40904998779296875, + loss: 1.9896177053451538 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } } ] } @@ -249,18 +248,18 @@ const data: CustomPlotsData = { metric: 'summary.json:loss', param: CHECKPOINTS_PARAM, values: [ - { group: 'exp-83425', iteration: 1, y: 1.9896177053451538 }, - { group: 'exp-83425', iteration: 2, y: 1.9329891204833984 }, - { group: 'exp-83425', iteration: 3, y: 1.8798457384109497 }, - { group: 'exp-83425', iteration: 4, y: 1.8261293172836304 }, - { group: 'exp-83425', iteration: 5, y: 1.775016188621521 }, - { group: 'exp-83425', iteration: 6, y: 1.775016188621521 }, - { group: 'test-branch', iteration: 1, y: 1.9882521629333496 }, - { group: 'test-branch', iteration: 2, y: 1.9293040037155151 }, - { group: 'test-branch', iteration: 3, y: 1.9293040037155151 }, - { group: 'exp-e7a67', iteration: 1, y: 2.020392894744873 }, + { group: 'exp-e7a67', iteration: 3, y: 2.0205044746398926 }, { group: 'exp-e7a67', iteration: 2, y: 2.0205044746398926 }, - { group: 'exp-e7a67', iteration: 3, y: 2.0205044746398926 } + { group: 'exp-e7a67', iteration: 1, y: 2.020392894744873 }, + { group: 'test-branch', iteration: 3, y: 1.9293040037155151 }, + { group: 'test-branch', iteration: 2, y: 1.9293040037155151 }, + { group: 'test-branch', iteration: 1, y: 1.9882521629333496 }, + { group: 'exp-83425', iteration: 6, y: 1.775016188621521 }, + { group: 'exp-83425', iteration: 5, y: 1.775016188621521 }, + { group: 'exp-83425', iteration: 4, y: 1.8261293172836304 }, + { group: 'exp-83425', iteration: 3, y: 1.8798457384109497 }, + { group: 'exp-83425', iteration: 2, y: 1.9329891204833984 }, + { group: 'exp-83425', iteration: 1, y: 1.9896177053451538 } ], type: CustomPlotType.CHECKPOINT, yTitle: 'summary.json:loss' @@ -270,18 +269,18 @@ const data: CustomPlotsData = { metric: 'summary.json:accuracy', param: CHECKPOINTS_PARAM, values: [ - { group: 'exp-83425', iteration: 1, y: 0.40904998779296875 }, - { group: 'exp-83425', iteration: 2, y: 0.46094998717308044 }, - { group: 'exp-83425', iteration: 3, y: 0.5113166570663452 }, - { group: 'exp-83425', iteration: 4, y: 0.557449996471405 }, - { group: 'exp-83425', iteration: 5, y: 0.5926499962806702 }, - { group: 'exp-83425', iteration: 6, y: 0.5926499962806702 }, - { group: 'test-branch', iteration: 1, y: 0.4083833396434784 }, - { group: 'test-branch', iteration: 2, y: 0.4668000042438507 }, - { group: 'test-branch', iteration: 3, y: 0.4668000042438507 }, - { group: 'exp-e7a67', iteration: 1, y: 0.3723166584968567 }, + { group: 'exp-e7a67', iteration: 3, y: 0.3724166750907898 }, { group: 'exp-e7a67', iteration: 2, y: 0.3724166750907898 }, - { group: 'exp-e7a67', iteration: 3, y: 0.3724166750907898 } + { group: 'exp-e7a67', iteration: 1, y: 0.3723166584968567 }, + { group: 'test-branch', iteration: 3, y: 0.4668000042438507 }, + { group: 'test-branch', iteration: 2, y: 0.4668000042438507 }, + { group: 'test-branch', iteration: 1, y: 0.4083833396434784 }, + { group: 'exp-83425', iteration: 6, y: 0.5926499962806702 }, + { group: 'exp-83425', iteration: 5, y: 0.5926499962806702 }, + { group: 'exp-83425', iteration: 4, y: 0.557449996471405 }, + { group: 'exp-83425', iteration: 3, y: 0.5113166570663452 }, + { group: 'exp-83425', iteration: 2, y: 0.46094998717308044 }, + { group: 'exp-83425', iteration: 1, y: 0.40904998779296875 } ], type: CustomPlotType.CHECKPOINT, yTitle: 'summary.json:accuracy' From 5132b838a4a61057f49a9596cf9e1217a7d84743 Mon Sep 17 00:00:00 2001 From: julieg18 Date: Wed, 15 Mar 2023 09:20:07 -0500 Subject: [PATCH 03/12] Consolidate some more * have the code just have separate funcs for values creation --- extension/src/plots/model/collect.test.ts | 3 +- extension/src/plots/model/collect.ts | 110 +++++++++++----------- extension/src/plots/model/index.ts | 3 - 3 files changed, 56 insertions(+), 60 deletions(-) diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index ce53fa017e..b568a1fb5d 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -30,7 +30,8 @@ import { Experiment } from '../../experiments/webview/contract' const logsLossPath = join('logs', 'loss.tsv') const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot -// missing tests here (collect function are still a WIP) +// missing function tests here collection function are still a WIP) + const getCustomPlotFromCustomPlotData = ({ id, metric, diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 2274e81afd..0a72c67639 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -20,11 +20,9 @@ import { TemplatePlotSection, PlotsType, Revision, - CustomPlotType, CustomPlot, - MetricVsParamPlot, CustomPlotData, - CheckpointPlot + MetricVsParamPlotValues } from '../webview/contract' import { EXPERIMENT_WORKSPACE_ID, PlotsOutput } from '../../cli/dvc/contract' import { @@ -65,77 +63,54 @@ export const getCustomPlotId = (metric: string, param = CHECKPOINTS_PARAM) => const getExperimentValues = ( values: CheckpointPlotValues, exp: ExperimentWithDefinedCheckpoints, - metric: string + splitUpMetricPath: string[] ) => { - const splitMetric = splitColumnPath( - getFullValuePath(ColumnType.METRICS, metric, FILE_SEPARATOR) - ) const group = exp.name || exp.label const expEpochLength = exp.checkpoints.length + 1 - const y = get(exp, splitMetric) as number | undefined + const y = get(exp, splitUpMetricPath) as number | undefined if (y !== undefined) { values.push({ group, iteration: expEpochLength, y }) } for (const [ind, checkpoint] of exp.checkpoints.entries()) { - const y = get(checkpoint, splitMetric) as number | undefined + const y = get(checkpoint, splitUpMetricPath) as number | undefined if (y !== undefined) { values.push({ group, iteration: expEpochLength - ind - 1, y }) } } } -// I belive we can further combine these -// leaving the only separate thing being the -// creation of the values -const collectCheckpointPlot = ( - metric: string, - experiments: ExperimentWithCheckpoints[] -): CheckpointPlot => { + +const collectCheckpointValues = ( + experiments: ExperimentWithCheckpoints[], + splitUpMetricPath: string[] +): CheckpointPlotValues => { const fullValues: CheckpointPlotValues = [] for (const experiment of experiments) { if (experiment.checkpoints) { getExperimentValues( fullValues, experiment as ExperimentWithDefinedCheckpoints, - metric + splitUpMetricPath ) } } - return { - id: getCustomPlotId(metric), - metric, - param: CHECKPOINTS_PARAM, - type: CustomPlotType.CHECKPOINT, - values: fullValues - } + return fullValues } -const collectMetricVsParamPlot = ( - metric: string, - param: string, - experiments: Experiment[] -): MetricVsParamPlot => { - const splitUpMetricPath = splitColumnPath( - getFullValuePath(ColumnType.METRICS, metric, FILE_SEPARATOR) - ) - const splitUpParamPath = splitColumnPath( - getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) - ) - const plotData: MetricVsParamPlot = { - id: getCustomPlotId(metric, param), - metric, - param, - type: CustomPlotType.METRIC_VS_PARAM, - values: [] - } +const collectMetricVsParamValues = ( + experiments: ExperimentWithCheckpoints[], + splitUpMetricPath: string[], + splitUpParamPath: string[] +): MetricVsParamPlotValues => { + const fullValues: MetricVsParamPlotValues = [] for (const experiment of experiments) { const metricValue = get(experiment, splitUpMetricPath) as number | undefined const paramValue = get(experiment, splitUpParamPath) as number | undefined if (metricValue !== undefined && paramValue !== undefined) { - plotData.values.push({ + fullValues.push({ expName: experiment.name || experiment.label, metric: metricValue, param: paramValue @@ -143,24 +118,47 @@ const collectMetricVsParamPlot = ( } } - return plotData + return fullValues +} + +const collectCustomPlot = ( + orderValue: CustomPlotsOrderValue, + experiments: ExperimentWithCheckpoints[] +): CustomPlot => { + const { metric, param, type } = orderValue + const splitUpMetricPath = splitColumnPath( + getFullValuePath(ColumnType.METRICS, metric, FILE_SEPARATOR) + ) + const splitUpParamPath = splitColumnPath( + getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) + ) + // were looping over the experiments each time with each plot + // I wonder if we could loop over them once and update an object of custom plots + // all at once vs creating each plot one at a time + const values = isCheckpointValue(type) + ? collectCheckpointValues(experiments, splitUpMetricPath) + : collectMetricVsParamValues( + experiments, + splitUpMetricPath, + splitUpParamPath + ) + + return { + id: getCustomPlotId(metric, param), + metric, + param, + type, + values + } as CustomPlot } export const collectCustomPlots = ( plotsOrderValues: CustomPlotsOrderValue[], - experiments: Experiment[] + experiments: ExperimentWithCheckpoints[] ): CustomPlot[] => { - return plotsOrderValues - .map((plotOrderValue): CustomPlot => { - if (isCheckpointValue(plotOrderValue.type)) { - const { metric } = plotOrderValue - - return collectCheckpointPlot(metric, experiments) - } - const { metric, param } = plotOrderValue - return collectMetricVsParamPlot(metric, param, experiments) - }) - .filter(Boolean) + return plotsOrderValues.map(plotOrderValue => + collectCustomPlot(plotOrderValue, experiments) + ) } export const collectCustomPlotData = ( diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index eb847dbcd5..b5a5a0b3ca 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -460,9 +460,6 @@ export class PlotsModel extends ModelWithPersistence { const selectedExperimentsExist = !!colors const filteredPlots: CustomPlotData[] = [] for (const plot of plots) { - // i wonder if we could filter this when - // were doing the collection part now - // instead of after... if (!selectedExperimentsExist && isCheckpointPlot(plot)) { continue } From ee1b7389de35c0da08952e96bf31eb757c369278 Mon Sep 17 00:00:00 2001 From: julieg18 Date: Wed, 15 Mar 2023 10:42:25 -0500 Subject: [PATCH 04/12] Try to make naming more clear --- extension/src/plots/model/collect.ts | 38 ++++++++++++---------------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 0a72c67639..e04cc528cf 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -60,35 +60,35 @@ import { Color } from '../../experiments/model/status/colors' export const getCustomPlotId = (metric: string, param = CHECKPOINTS_PARAM) => `custom-${metric}-${param}` -const getExperimentValues = ( +const collectCheckpointValuesFromExperiment = ( values: CheckpointPlotValues, exp: ExperimentWithDefinedCheckpoints, splitUpMetricPath: string[] ) => { const group = exp.name || exp.label - const expEpochLength = exp.checkpoints.length + 1 + const maxEpoch = exp.checkpoints.length + 1 - const y = get(exp, splitUpMetricPath) as number | undefined - if (y !== undefined) { - values.push({ group, iteration: expEpochLength, y }) + const metricValue = get(exp, splitUpMetricPath) as number | undefined + if (metricValue !== undefined) { + values.push({ group, iteration: maxEpoch, y: metricValue }) } for (const [ind, checkpoint] of exp.checkpoints.entries()) { - const y = get(checkpoint, splitUpMetricPath) as number | undefined - if (y !== undefined) { - values.push({ group, iteration: expEpochLength - ind - 1, y }) + const metricValue = get(checkpoint, splitUpMetricPath) as number | undefined + if (metricValue !== undefined) { + values.push({ group, iteration: maxEpoch - ind - 1, y: metricValue }) } } } -const collectCheckpointValues = ( +const getCheckpointValues = ( experiments: ExperimentWithCheckpoints[], splitUpMetricPath: string[] ): CheckpointPlotValues => { const fullValues: CheckpointPlotValues = [] for (const experiment of experiments) { if (experiment.checkpoints) { - getExperimentValues( + collectCheckpointValuesFromExperiment( fullValues, experiment as ExperimentWithDefinedCheckpoints, splitUpMetricPath @@ -98,7 +98,7 @@ const collectCheckpointValues = ( return fullValues } -const collectMetricVsParamValues = ( +const getMetricVsParamValues = ( experiments: ExperimentWithCheckpoints[], splitUpMetricPath: string[], splitUpParamPath: string[] @@ -121,7 +121,7 @@ const collectMetricVsParamValues = ( return fullValues } -const collectCustomPlot = ( +const getCustomPlot = ( orderValue: CustomPlotsOrderValue, experiments: ExperimentWithCheckpoints[] ): CustomPlot => { @@ -132,16 +132,10 @@ const collectCustomPlot = ( const splitUpParamPath = splitColumnPath( getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) ) - // were looping over the experiments each time with each plot - // I wonder if we could loop over them once and update an object of custom plots - // all at once vs creating each plot one at a time + const values = isCheckpointValue(type) - ? collectCheckpointValues(experiments, splitUpMetricPath) - : collectMetricVsParamValues( - experiments, - splitUpMetricPath, - splitUpParamPath - ) + ? getCheckpointValues(experiments, splitUpMetricPath) + : getMetricVsParamValues(experiments, splitUpMetricPath, splitUpParamPath) return { id: getCustomPlotId(metric, param), @@ -157,7 +151,7 @@ export const collectCustomPlots = ( experiments: ExperimentWithCheckpoints[] ): CustomPlot[] => { return plotsOrderValues.map(plotOrderValue => - collectCustomPlot(plotOrderValue, experiments) + getCustomPlot(plotOrderValue, experiments) ) } From 3e80fb88c180890c2952644fba71fb184d390517 Mon Sep 17 00:00:00 2001 From: julieg18 Date: Wed, 15 Mar 2023 10:59:16 -0500 Subject: [PATCH 05/12] Rewrite CustomPlot types --- extension/src/plots/webview/contract.ts | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/extension/src/plots/webview/contract.ts b/extension/src/plots/webview/contract.ts index e512a6dadb..852feecad2 100644 --- a/extension/src/plots/webview/contract.ts +++ b/extension/src/plots/webview/contract.ts @@ -89,21 +89,21 @@ export type CheckpointPlotValues = { export type ColorScale = { domain: string[]; range: Color[] } -export type CheckpointPlot = { +type CustomPlotBase = { id: string - values: CheckpointPlotValues metric: string param: string - type: CustomPlotType.CHECKPOINT } +export type CheckpointPlot = { + values: CheckpointPlotValues + type: CustomPlotType.CHECKPOINT +} & CustomPlotBase + export type MetricVsParamPlot = { - id: string values: MetricVsParamPlotValues - metric: string - param: string type: CustomPlotType.METRIC_VS_PARAM -} +} & CustomPlotBase export type CustomPlot = MetricVsParamPlot | CheckpointPlot From 468ac5a67728458579ebfd3cc466b98b533d032a Mon Sep 17 00:00:00 2001 From: julieg18 Date: Wed, 15 Mar 2023 11:26:28 -0500 Subject: [PATCH 06/12] Fix broken no checkpoint experiments --- extension/src/plots/model/index.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index b5a5a0b3ca..401174f7d4 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -169,18 +169,18 @@ export class PlotsModel extends ModelWithPersistence { } public recreateCustomPlots() { - const allExperiments = this.experiments.getExperiments() - const experimentsWithCheckpoints = this.experiments - .getRowData() - .filter(({ checkpoints }) => !!checkpoints) + const experiments = this.experiments.hasCheckpoints() + ? this.experiments.getRowData().filter(({ checkpoints }) => !!checkpoints) + : this.experiments.getExperiments() - if (allExperiments.length === 0) { + if (experiments.length === 0) { this.customPlots = undefined return } + const customPlots: CustomPlot[] = collectCustomPlots( this.getCustomPlotsOrder(), - experimentsWithCheckpoints + experiments ) this.customPlots = customPlots } From 51be88a069a316b3e36bc8fa0f3db86c1817a3b3 Mon Sep 17 00:00:00 2001 From: julieg18 Date: Thu, 16 Mar 2023 19:25:17 -0500 Subject: [PATCH 07/12] Fix typo --- extension/src/experiments/index.ts | 2 +- extension/src/plots/model/index.ts | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/extension/src/experiments/index.ts b/extension/src/experiments/index.ts index a0c1aaa50a..7ed8b1ee3f 100644 --- a/extension/src/experiments/index.ts +++ b/extension/src/experiments/index.ts @@ -333,7 +333,7 @@ export class Experiments extends BaseRepository { return this.experiments.getExperimentCount() } - public getRowData() { + public getExperimentsWithCheckpoints() { return this.experiments.getExperimentsWithCheckpoints() } diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index 401174f7d4..2c61ba2ec9 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -170,7 +170,9 @@ export class PlotsModel extends ModelWithPersistence { public recreateCustomPlots() { const experiments = this.experiments.hasCheckpoints() - ? this.experiments.getRowData().filter(({ checkpoints }) => !!checkpoints) + ? this.experiments + .getExperimentsWithCheckpoints() + .filter(({ checkpoints }) => !!checkpoints) : this.experiments.getExperiments() if (experiments.length === 0) { From de74af96ea8ba6fb805a256e1688ad059aabcfd9 Mon Sep 17 00:00:00 2001 From: julieg18 Date: Thu, 16 Mar 2023 19:45:30 -0500 Subject: [PATCH 08/12] Clean up a bit --- extension/src/plots/model/collect.test.ts | 1 - extension/src/plots/model/collect.ts | 38 +++++++++++++---------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index b568a1fb5d..c31bf79162 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -30,7 +30,6 @@ import { Experiment } from '../../experiments/webview/contract' const logsLossPath = join('logs', 'loss.tsv') const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot -// missing function tests here collection function are still a WIP) const getCustomPlotFromCustomPlotData = ({ id, diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index e04cc528cf..9227662732 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -60,21 +60,26 @@ import { Color } from '../../experiments/model/status/colors' export const getCustomPlotId = (metric: string, param = CHECKPOINTS_PARAM) => `custom-${metric}-${param}` +export const getValueFromColumn = ( + path: string, + experiment: ExperimentWithCheckpoints +) => get(experiment, splitColumnPath(path)) as number | undefined + const collectCheckpointValuesFromExperiment = ( values: CheckpointPlotValues, exp: ExperimentWithDefinedCheckpoints, - splitUpMetricPath: string[] + metricPath: string ) => { const group = exp.name || exp.label const maxEpoch = exp.checkpoints.length + 1 - const metricValue = get(exp, splitUpMetricPath) as number | undefined + const metricValue = getValueFromColumn(metricPath, exp) if (metricValue !== undefined) { values.push({ group, iteration: maxEpoch, y: metricValue }) } for (const [ind, checkpoint] of exp.checkpoints.entries()) { - const metricValue = get(checkpoint, splitUpMetricPath) as number | undefined + const metricValue = getValueFromColumn(metricPath, checkpoint) if (metricValue !== undefined) { values.push({ group, iteration: maxEpoch - ind - 1, y: metricValue }) } @@ -83,7 +88,7 @@ const collectCheckpointValuesFromExperiment = ( const getCheckpointValues = ( experiments: ExperimentWithCheckpoints[], - splitUpMetricPath: string[] + metricPath: string ): CheckpointPlotValues => { const fullValues: CheckpointPlotValues = [] for (const experiment of experiments) { @@ -91,7 +96,7 @@ const getCheckpointValues = ( collectCheckpointValuesFromExperiment( fullValues, experiment as ExperimentWithDefinedCheckpoints, - splitUpMetricPath + metricPath ) } } @@ -100,14 +105,14 @@ const getCheckpointValues = ( const getMetricVsParamValues = ( experiments: ExperimentWithCheckpoints[], - splitUpMetricPath: string[], - splitUpParamPath: string[] + metricPath: string, + paramPath: string ): MetricVsParamPlotValues => { const fullValues: MetricVsParamPlotValues = [] for (const experiment of experiments) { - const metricValue = get(experiment, splitUpMetricPath) as number | undefined - const paramValue = get(experiment, splitUpParamPath) as number | undefined + const metricValue = getValueFromColumn(metricPath, experiment) + const paramValue = getValueFromColumn(paramPath, experiment) if (metricValue !== undefined && paramValue !== undefined) { fullValues.push({ @@ -126,16 +131,17 @@ const getCustomPlot = ( experiments: ExperimentWithCheckpoints[] ): CustomPlot => { const { metric, param, type } = orderValue - const splitUpMetricPath = splitColumnPath( - getFullValuePath(ColumnType.METRICS, metric, FILE_SEPARATOR) - ) - const splitUpParamPath = splitColumnPath( - getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) + const metricPath = getFullValuePath( + ColumnType.METRICS, + metric, + FILE_SEPARATOR ) + const paramPath = getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) + const values = isCheckpointValue(type) - ? getCheckpointValues(experiments, splitUpMetricPath) - : getMetricVsParamValues(experiments, splitUpMetricPath, splitUpParamPath) + ? getCheckpointValues(experiments, metricPath) + : getMetricVsParamValues(experiments, metricPath, paramPath) return { id: getCustomPlotId(metric, param), From 877e1e81aafb6e3f38cecb24766175c560f3c98d Mon Sep 17 00:00:00 2001 From: julieg18 Date: Fri, 17 Mar 2023 08:42:55 -0500 Subject: [PATCH 09/12] Refactor based off comments * fullValues to values * use typeguard on ExperimentsWithDefinedCheckpoints --- extension/src/plots/model/collect.ts | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 9227662732..3493545017 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -52,19 +52,23 @@ import { import { StrokeDashEncoding } from '../multiSource/constants' import { ExperimentWithCheckpoints, - SelectedExperimentWithColor, - ExperimentWithDefinedCheckpoints + ExperimentWithDefinedCheckpoints, + SelectedExperimentWithColor } from '../../experiments/model' import { Color } from '../../experiments/model/status/colors' export const getCustomPlotId = (metric: string, param = CHECKPOINTS_PARAM) => `custom-${metric}-${param}` -export const getValueFromColumn = ( +const getValueFromColumn = ( path: string, experiment: ExperimentWithCheckpoints ) => get(experiment, splitColumnPath(path)) as number | undefined +const isExperimentWithDefinedCheckpoints = ( + experiment: ExperimentWithCheckpoints +): experiment is ExperimentWithDefinedCheckpoints => !!experiment.checkpoints + const collectCheckpointValuesFromExperiment = ( values: CheckpointPlotValues, exp: ExperimentWithDefinedCheckpoints, @@ -90,17 +94,13 @@ const getCheckpointValues = ( experiments: ExperimentWithCheckpoints[], metricPath: string ): CheckpointPlotValues => { - const fullValues: CheckpointPlotValues = [] + const values: CheckpointPlotValues = [] for (const experiment of experiments) { - if (experiment.checkpoints) { - collectCheckpointValuesFromExperiment( - fullValues, - experiment as ExperimentWithDefinedCheckpoints, - metricPath - ) + if (isExperimentWithDefinedCheckpoints(experiment)) { + collectCheckpointValuesFromExperiment(values, experiment, metricPath) } } - return fullValues + return values } const getMetricVsParamValues = ( @@ -108,14 +108,14 @@ const getMetricVsParamValues = ( metricPath: string, paramPath: string ): MetricVsParamPlotValues => { - const fullValues: MetricVsParamPlotValues = [] + const values: MetricVsParamPlotValues = [] for (const experiment of experiments) { const metricValue = getValueFromColumn(metricPath, experiment) const paramValue = getValueFromColumn(paramPath, experiment) if (metricValue !== undefined && paramValue !== undefined) { - fullValues.push({ + values.push({ expName: experiment.name || experiment.label, metric: metricValue, param: paramValue @@ -123,7 +123,7 @@ const getMetricVsParamValues = ( } } - return fullValues + return values } const getCustomPlot = ( From 97aba851ad3a998424c8daa4ed90e862e4a20a07 Mon Sep 17 00:00:00 2001 From: julieg18 Date: Fri, 17 Mar 2023 09:04:28 -0500 Subject: [PATCH 10/12] Make commit/workspace exp filtering more clear --- extension/src/plots/model/index.ts | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index 2c61ba2ec9..072886546f 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -121,7 +121,6 @@ export class PlotsModel extends ModelWithPersistence { collectTemplates(data), collectMultiSourceVariations(data, this.multiSourceVariations) ]) - this.recreateCustomPlots() this.comparisonData = { ...this.comparisonData, @@ -169,11 +168,9 @@ export class PlotsModel extends ModelWithPersistence { } public recreateCustomPlots() { - const experiments = this.experiments.hasCheckpoints() - ? this.experiments - .getExperimentsWithCheckpoints() - .filter(({ checkpoints }) => !!checkpoints) - : this.experiments.getExperiments() + const experiments = this.experiments + .getExperimentsWithCheckpoints() + .filter(({ commit, id }) => !commit && id !== EXPERIMENT_WORKSPACE_ID) if (experiments.length === 0) { this.customPlots = undefined From 9de1c424c03548249fc8848e9525d73f65c591ca Mon Sep 17 00:00:00 2001 From: julieg18 Date: Fri, 17 Mar 2023 11:07:42 -0500 Subject: [PATCH 11/12] Undo change that breaks vscode data --- extension/src/plots/model/index.ts | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index 072886546f..d4431d0e7d 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -168,18 +168,20 @@ export class PlotsModel extends ModelWithPersistence { } public recreateCustomPlots() { - const experiments = this.experiments - .getExperimentsWithCheckpoints() - .filter(({ commit, id }) => !commit && id !== EXPERIMENT_WORKSPACE_ID) + const experimentsWithNoCommitData = this.experiments.hasCheckpoints() + ? this.experiments + .getExperimentsWithCheckpoints() + .filter(({ checkpoints }) => !!checkpoints) + : this.experiments.getExperiments() - if (experiments.length === 0) { + if (experimentsWithNoCommitData.length === 0) { this.customPlots = undefined return } const customPlots: CustomPlot[] = collectCustomPlots( this.getCustomPlotsOrder(), - experiments + experimentsWithNoCommitData ) this.customPlots = customPlots } From 0f438d87d57d629543a7dad84938f063e64951b7 Mon Sep 17 00:00:00 2001 From: Julie G <43496356+julieg18@users.noreply.github.com> Date: Mon, 20 Mar 2023 09:08:59 -0500 Subject: [PATCH 12/12] Create custom plots when plots are requested (#3491) --- extension/src/plots/model/collect.test.ts | 120 +++++++++------------- extension/src/plots/model/collect.ts | 76 ++++++++------ extension/src/plots/model/index.ts | 87 +++++----------- extension/src/plots/webview/messages.ts | 14 ++- 4 files changed, 133 insertions(+), 164 deletions(-) diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index c31bf79162..295beb3e65 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -4,9 +4,9 @@ import { collectData, collectTemplates, collectOverrideRevisionDetails, - collectCustomPlots, - collectCustomPlotData + collectCustomPlots } from './collect' +import { isCheckpointPlot } from './custom' import plotsDiffFixture from '../../test/fixtures/plotsDiff/output' import customPlotsFixture, { customPlotsOrderFixture, @@ -18,9 +18,10 @@ import { } from '../../cli/dvc/contract' import { sameContents } from '../../util/array' import { - CheckpointPlot, - CustomPlot, CustomPlotData, + CustomPlotType, + DEFAULT_NB_ITEMS_PER_ROW, + DEFAULT_PLOT_HEIGHT, TemplatePlot } from '../webview/contract' import { getCLICommitId } from '../../test/fixtures/plotsDiff/util' @@ -31,81 +32,62 @@ const logsLossPath = join('logs', 'loss.tsv') const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot -const getCustomPlotFromCustomPlotData = ({ - id, - metric, - param, - type, - values -}: CustomPlotData) => - ({ - id, - metric, - param, - type, - values - } as CustomPlot) - describe('collectCustomPlots', () => { + const defaultFuncArgs = { + experiments: experimentsWithCheckpoints, + hasCheckpoints: true, + height: DEFAULT_PLOT_HEIGHT, + nbItemsPerRow: DEFAULT_NB_ITEMS_PER_ROW, + plotsOrderValues: customPlotsOrderFixture, + selectedRevisions: customPlotsFixture.colors?.domain + } + it('should return the expected data from the test fixture', () => { - const expectedOutput: CustomPlot[] = customPlotsFixture.plots.map( - getCustomPlotFromCustomPlotData - ) - const data = collectCustomPlots( - customPlotsOrderFixture, - experimentsWithCheckpoints - ) + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots + const data = collectCustomPlots(defaultFuncArgs) expect(data).toStrictEqual(expectedOutput) }) -}) -describe('collectCustomPlotData', () => { - it('should return the expected data from test fixture', () => { - const expectedMetricVsParamPlotData = customPlotsFixture.plots[0] - const expectedCheckpointsPlotData = customPlotsFixture.plots[2] - const metricVsParamPlot = getCustomPlotFromCustomPlotData( - expectedMetricVsParamPlotData - ) - const checkpointsPlot = getCustomPlotFromCustomPlotData( - expectedCheckpointsPlotData + it('should return only custom plots if there no selected revisions', () => { + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter( + plot => plot.type !== CustomPlotType.CHECKPOINT ) + const data = collectCustomPlots({ + ...defaultFuncArgs, + selectedRevisions: undefined + }) - const metricVsParamData = collectCustomPlotData( - metricVsParamPlot, - customPlotsFixture.colors, - customPlotsFixture.nbItemsPerRow, - customPlotsFixture.height - ) + expect(data).toStrictEqual(expectedOutput) + }) - const checkpointsData = collectCustomPlotData( - { - ...checkpointsPlot, - values: [ - ...checkpointsPlot.values, - { - group: 'exp-123', - iteration: 1, - y: 1.4534177053451538 - }, - { - group: 'exp-123', - iteration: 2, - y: 1.757687 - }, - { - group: 'exp-123', - iteration: 3, - y: 1.989894 - } - ] - } as CheckpointPlot, - customPlotsFixture.colors, - customPlotsFixture.nbItemsPerRow, - customPlotsFixture.height + it('should return only custom plots if checkpoints are not enabled', () => { + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter( + plot => plot.type !== CustomPlotType.CHECKPOINT ) + const data = collectCustomPlots({ + ...defaultFuncArgs, + hasCheckpoints: false + }) - expect(metricVsParamData).toStrictEqual(expectedMetricVsParamPlotData) - expect(checkpointsData).toStrictEqual(expectedCheckpointsPlotData) + expect(data).toStrictEqual(expectedOutput) + }) + + it('should return checkpoint plots with values only containing selected experiments data', () => { + const domain = customPlotsFixture.colors?.domain.slice(1) as string[] + + const expectedOutput = customPlotsFixture.plots.map(plot => ({ + ...plot, + values: isCheckpointPlot(plot) + ? plot.values.filter(value => domain.includes(value.group)) + : plot.values + })) + + const data = collectCustomPlots({ + ...defaultFuncArgs, + selectedRevisions: domain + }) + + expect(data).toStrictEqual(expectedOutput) }) }) diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 3493545017..7612e164ae 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -5,7 +5,6 @@ import { getFullValuePath, CHECKPOINTS_PARAM, CustomPlotsOrderValue, - isCheckpointPlot, isCheckpointValue } from './custom' import { getRevisionFirstThreeColumns } from './util' @@ -20,7 +19,6 @@ import { TemplatePlotSection, PlotsType, Revision, - CustomPlot, CustomPlotData, MetricVsParamPlotValues } from '../webview/contract' @@ -126,10 +124,13 @@ const getMetricVsParamValues = ( return values } -const getCustomPlot = ( +const getCustomPlotData = ( orderValue: CustomPlotsOrderValue, - experiments: ExperimentWithCheckpoints[] -): CustomPlot => { + experiments: ExperimentWithCheckpoints[], + selectedRevisions: string[] | undefined = [], + height: number, + nbItemsPerRow: number +): CustomPlotData => { const { metric, param, type } = orderValue const metricPath = getFullValuePath( ColumnType.METRICS, @@ -139,8 +140,12 @@ const getCustomPlot = ( const paramPath = getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) + const selectedExperiments = experiments.filter(({ name, label }) => + selectedRevisions.includes(name || label) + ) + const values = isCheckpointValue(type) - ? getCheckpointValues(experiments, metricPath) + ? getCheckpointValues(selectedExperiments, metricPath) : getMetricVsParamValues(experiments, metricPath, paramPath) return { @@ -148,37 +153,46 @@ const getCustomPlot = ( metric, param, type, - values - } as CustomPlot + values, + yTitle: truncateVerticalTitle(metric, nbItemsPerRow, height) as string + } as CustomPlotData } -export const collectCustomPlots = ( - plotsOrderValues: CustomPlotsOrderValue[], +export const collectCustomPlots = ({ + plotsOrderValues, + experiments, + hasCheckpoints, + selectedRevisions, + height, + nbItemsPerRow +}: { + plotsOrderValues: CustomPlotsOrderValue[] experiments: ExperimentWithCheckpoints[] -): CustomPlot[] => { - return plotsOrderValues.map(plotOrderValue => - getCustomPlot(plotOrderValue, experiments) - ) -} - -export const collectCustomPlotData = ( - plot: CustomPlot, - colors: ColorScale | undefined, - nbItemsPerRow: number, + hasCheckpoints: boolean + selectedRevisions: string[] | undefined height: number -): CustomPlotData => { - const selectedExperiments = colors?.domain - const filteredValues = isCheckpointPlot(plot) - ? plot.values.filter(value => - (selectedExperiments as string[]).includes(value.group) + nbItemsPerRow: number +}): CustomPlotData[] => { + const plots = [] + const shouldSkipCheckpointPlots = !hasCheckpoints || !selectedRevisions + + for (const value of plotsOrderValues) { + if (shouldSkipCheckpointPlots && isCheckpointValue(value.type)) { + continue + } + + plots.push( + getCustomPlotData( + value, + experiments, + selectedRevisions, + height, + nbItemsPerRow ) - : plot.values + ) + } - return { - ...plot, - values: filteredValues, - yTitle: truncateVerticalTitle(plot.metric, nbItemsPerRow, height) as string - } as CustomPlotData + return plots } type RevisionPathData = { [path: string]: Record[] } diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index d4431d0e7d..0f5705b18c 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -11,15 +11,10 @@ import { collectCommitRevisionDetails, collectOverrideRevisionDetails, collectCustomPlots, - getCustomPlotId, - collectCustomPlotData + getCustomPlotId } from './collect' import { getRevisionFirstThreeColumns } from './util' -import { - cleanupOldOrderValue, - CustomPlotsOrderValue, - isCheckpointPlot -} from './custom' +import { cleanupOldOrderValue, CustomPlotsOrderValue } from './custom' import { CheckpointPlot, ComparisonPlots, @@ -31,8 +26,6 @@ import { SectionCollapsed, CustomPlotData, CustomPlotsData, - CustomPlot, - ColorScale, DEFAULT_HEIGHT, DEFAULT_NB_ITEMS_PER_ROW, PlotHeight @@ -78,8 +71,6 @@ export class PlotsModel extends ModelWithPersistence { private multiSourceVariations: MultiSourceVariations = {} private multiSourceEncoding: MultiSourceEncoding = {} - private customPlots?: CustomPlot[] - constructor( dvcRoot: string, experiments: Experiments, @@ -103,8 +94,6 @@ export class PlotsModel extends ModelWithPersistence { } public transformAndSetExperiments() { - this.recreateCustomPlots() - return this.removeStaleData() } @@ -149,7 +138,13 @@ export class PlotsModel extends ModelWithPersistence { } public getCustomPlots(): CustomPlotsData | undefined { - if (!this.customPlots) { + const experimentsWithNoCommitData = this.experiments.hasCheckpoints() + ? this.experiments + .getExperimentsWithCheckpoints() + .filter(({ checkpoints }) => !!checkpoints) + : this.experiments.getExperiments() + + if (experimentsWithNoCommitData.length === 0) { return } @@ -158,32 +153,29 @@ export class PlotsModel extends ModelWithPersistence { .getSelectedExperiments() .map(({ displayColor, id: revision }) => ({ displayColor, revision })) ) + const height = this.getHeight(Section.CUSTOM_PLOTS) + const nbItemsPerRow = this.getNbItemsPerRow(Section.CUSTOM_PLOTS) + const plotsOrderValues = this.getCustomPlotsOrder() + + const plots: CustomPlotData[] = collectCustomPlots({ + experiments: experimentsWithNoCommitData, + hasCheckpoints: this.experiments.hasCheckpoints(), + height, + nbItemsPerRow, + plotsOrderValues, + selectedRevisions: colors?.domain + }) - return { - colors, - height: this.getHeight(Section.CUSTOM_PLOTS), - nbItemsPerRow: this.getNbItemsPerRow(Section.CUSTOM_PLOTS), - plots: this.getCustomPlotsData(this.customPlots, colors) - } - } - - public recreateCustomPlots() { - const experimentsWithNoCommitData = this.experiments.hasCheckpoints() - ? this.experiments - .getExperimentsWithCheckpoints() - .filter(({ checkpoints }) => !!checkpoints) - : this.experiments.getExperiments() - - if (experimentsWithNoCommitData.length === 0) { - this.customPlots = undefined + if (plots.length === 0 && plotsOrderValues.length > 0) { return } - const customPlots: CustomPlot[] = collectCustomPlots( - this.getCustomPlotsOrder(), - experimentsWithNoCommitData - ) - this.customPlots = customPlots + return { + colors, + height, + nbItemsPerRow, + plots + } } public getCustomPlotsOrder() { @@ -194,7 +186,6 @@ export class PlotsModel extends ModelWithPersistence { public updateCustomPlotsOrder(plotsOrder: CustomPlotsOrderValue[]) { this.customPlotsOrder = plotsOrder - this.recreateCustomPlots() } public setCustomPlotsOrder(plotsOrder: CustomPlotsOrderValue[]) { @@ -454,28 +445,6 @@ export class PlotsModel extends ModelWithPersistence { return this.commitRevisions[label] || label } - private getCustomPlotsData( - plots: CustomPlot[], - colors: ColorScale | undefined - ): CustomPlotData[] { - const selectedExperimentsExist = !!colors - const filteredPlots: CustomPlotData[] = [] - for (const plot of plots) { - if (!selectedExperimentsExist && isCheckpointPlot(plot)) { - continue - } - filteredPlots.push( - collectCustomPlotData( - plot, - colors, - this.getNbItemsPerRow(Section.CUSTOM_PLOTS), - this.getHeight(Section.CUSTOM_PLOTS) - ) - ) - } - return filteredPlots - } - private getSelectedComparisonPlots( paths: string[], selectedRevisions: string[] diff --git a/extension/src/plots/webview/messages.ts b/extension/src/plots/webview/messages.ts index 6a9d82441f..a0bd20cc0f 100644 --- a/extension/src/plots/webview/messages.ts +++ b/extension/src/plots/webview/messages.ts @@ -37,6 +37,7 @@ import { doesCustomPlotAlreadyExist, isCheckpointValue } from '../model/custom' +import { getCustomPlotId } from '../model/collect' export class WebviewMessages { private readonly paths: PathsModel @@ -278,20 +279,23 @@ export class WebviewMessages { } private setCustomPlotsOrder(plotIds: string[]) { - const customPlots = this.plots.getCustomPlots()?.plots - if (!customPlots) { - return - } + const customPlotsOrderWithId = this.plots + .getCustomPlotsOrder() + .map(value => ({ + ...value, + id: getCustomPlotId(value.metric, value.param) + })) const newOrder: CustomPlotsOrderValue[] = reorderObjectList( plotIds, - customPlots, + customPlotsOrderWithId, 'id' ).map(({ metric, param, type }) => ({ metric, param, type })) + this.plots.setCustomPlotsOrder(newOrder) this.sendCustomPlotsAndEvent(EventName.VIEWS_REORDER_PLOTS_CUSTOM) }