diff --git a/extension/resources/walkthrough/images/plots-custom.png b/extension/resources/walkthrough/images/plots-custom.png new file mode 100644 index 0000000000..68523646ce Binary files /dev/null and b/extension/resources/walkthrough/images/plots-custom.png differ diff --git a/extension/resources/walkthrough/images/plots-trends.png b/extension/resources/walkthrough/images/plots-trends.png deleted file mode 100644 index a8f83df9f3..0000000000 Binary files a/extension/resources/walkthrough/images/plots-trends.png and /dev/null differ diff --git a/extension/resources/walkthrough/live-plots.md b/extension/resources/walkthrough/live-plots.md index e7a12f39c3..02bbb88dc1 100644 --- a/extension/resources/walkthrough/live-plots.md +++ b/extension/resources/walkthrough/live-plots.md @@ -33,8 +33,6 @@ for epoch in range(NUM_EPOCHS): `DVCLive` is _optional_, and you can just append or modify plot files using any language and any tool. -💡 `Trends` section of the plots dashboard is being updated automatically based -on the data in the table. You don't even have to manage or write any special -plot files, but you need to enable -[checkpoints](https://dvc.org/doc/user-guide/experiment-management/checkpoints) -in the project. +💡 Plots created in the `Custom` section of the plots dashboard are being +updated automatically based on the data in the table. You don't even have to +manage or write any special plot files. diff --git a/extension/resources/walkthrough/plots.md b/extension/resources/walkthrough/plots.md index c2e13e36df..eab44a5b09 100644 --- a/extension/resources/walkthrough/plots.md +++ b/extension/resources/walkthrough/plots.md @@ -66,25 +66,29 @@ templates], which may be predefined (e.g. confusion matrix, linear) or custom alt="Plots: Images" />

-

- Plots: Trends -

- **Images** (e.g. `.jpg` or `.svg` files) can be visualized as well. They will be rendered side by side for the selected experiments.

- Plots View Icon + Plots: Custom

-Automatically generated and updated **Trends** that show scalar [metrics] value -per epoch if [checkpoints] are enabled. +**Custom** plots are generated linear plots comparing metrics and params. A user +can add two types of plots, "Checkpoint Trend" and "Metric Vs Param". + +"Metric Vs Param" plots compare a chosen metric and param across experiments. +"Checkpoint Trend" plots can compare a chosen [metric] value per epoch if +[checkpoints] are enabled. -[metrics]: https://dvc.org/doc/command-reference/metrics +[metric]: https://dvc.org/doc/command-reference/metrics [checkpoints]: https://dvc.org/doc/user-guide/experiment-management/checkpoints +

+ Plots View Icon +

+ The **Plots Dashboard** can be configured and accessed from the _Plots_ and _Experiments_ side panels in the [**DVC View**](command:views.dvc-views). diff --git a/extension/src/experiments/index.ts b/extension/src/experiments/index.ts index 360af40e31..67f421641f 100644 --- a/extension/src/experiments/index.ts +++ b/extension/src/experiments/index.ts @@ -333,6 +333,10 @@ export class Experiments extends BaseRepository { return this.experiments.getExperimentCount() } + public getExperimentsWithCheckpoints() { + return this.experiments.getExperimentsWithCheckpoints() + } + public async selectExperiments() { const experiments = this.experiments.getExperimentsWithCheckpoints() diff --git a/extension/src/experiments/model/index.ts b/extension/src/experiments/model/index.ts index 0a29502689..1a7a386086 100644 --- a/extension/src/experiments/model/index.ts +++ b/extension/src/experiments/model/index.ts @@ -60,6 +60,10 @@ export type ExperimentWithCheckpoints = Experiment & { checkpoints?: Experiment[] } +export type ExperimentWithDefinedCheckpoints = Experiment & { + checkpoints: Experiment[] +} + export enum ExperimentType { WORKSPACE = 'workspace', COMMIT = 'commit', diff --git a/extension/src/plots/index.ts b/extension/src/plots/index.ts index e612fd456a..d61e04875e 100644 --- a/extension/src/plots/index.ts +++ b/extension/src/plots/index.ts @@ -13,7 +13,6 @@ import { Experiments } from '../experiments' import { Resource } from '../resourceLocator' import { InternalCommands } from '../commands/internal' import { definedAndNonEmpty } from '../util/array' -import { ExperimentsOutput } from '../cli/dvc/contract' import { TEMP_PLOTS_DIR } from '../cli/dvc/constants' import { removeDir } from '../fileSystem' import { Toast } from '../vscode/toast' @@ -173,7 +172,7 @@ export class Plots extends BaseRepository { waitForInitialExpData.dispose() this.data.setMetricFiles(data) this.setupExperimentsListener(experiments) - void this.initializeData(data) + void this.initializeData() } }) ) @@ -184,7 +183,7 @@ export class Plots extends BaseRepository { experiments.onDidChangeExperiments(async data => { if (data) { await Promise.all([ - this.plots.transformAndSetExperiments(data), + this.plots.transformAndSetExperiments(), this.data.setMetricFiles(data) ]) } @@ -200,8 +199,8 @@ export class Plots extends BaseRepository { ) } - private async initializeData(data: ExperimentsOutput) { - await this.plots.transformAndSetExperiments(data) + private async initializeData() { + await this.plots.transformAndSetExperiments() void this.data.managedUpdate() await Promise.all([ this.data.isReady(), diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index 0b81f32648..295beb3e65 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -1,26 +1,29 @@ import { join } from 'path' -import omit from 'lodash.omit' import isEmpty from 'lodash.isempty' import { collectData, - collectCheckpointPlotsData, collectTemplates, - collectMetricOrder, collectOverrideRevisionDetails, - collectCustomPlotsData + collectCustomPlots } from './collect' +import { isCheckpointPlot } from './custom' import plotsDiffFixture from '../../test/fixtures/plotsDiff/output' -import expShowFixture from '../../test/fixtures/expShow/base/output' -import modifiedFixture from '../../test/fixtures/expShow/modified/output' -import checkpointPlotsFixture from '../../test/fixtures/expShow/base/checkpointPlots' -import customPlotsFixture from '../../test/fixtures/expShow/base/customPlots' +import customPlotsFixture, { + customPlotsOrderFixture, + experimentsWithCheckpoints +} from '../../test/fixtures/expShow/base/customPlots' import { - ExperimentsOutput, ExperimentStatus, EXPERIMENT_WORKSPACE_ID } from '../../cli/dvc/contract' -import { definedAndNonEmpty, sameContents } from '../../util/array' -import { TemplatePlot } from '../webview/contract' +import { sameContents } from '../../util/array' +import { + CustomPlotData, + CustomPlotType, + DEFAULT_NB_ITEMS_PER_ROW, + DEFAULT_PLOT_HEIGHT, + TemplatePlot +} from '../webview/contract' import { getCLICommitId } from '../../test/fixtures/plotsDiff/util' import { SelectedExperimentWithColor } from '../../experiments/model' import { Experiment } from '../../experiments/webview/contract' @@ -29,230 +32,62 @@ const logsLossPath = join('logs', 'loss.tsv') const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot -describe('collectCustomPlotsData', () => { - it('should return the expected data from the text fixture', () => { - const data = collectCustomPlotsData( - [ - { - metric: 'metrics:summary.json:loss', - param: 'params:params.yaml:dropout' - }, - { - metric: 'metrics:summary.json:accuracy', - param: 'params:params.yaml:epochs' - } - ], - [ - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.4668000042438507, - loss: 2.0205044746398926 - } - }, - name: 'exp-e7a67', - params: { 'params.yaml': { dropout: 0.15, epochs: 16 } } - }, - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.3484833240509033, - loss: 1.9293040037155151 - } - }, - name: 'exp-83425', - params: { 'params.yaml': { dropout: 0.25, epochs: 10 } } - }, - { - id: '12345', - label: '123', - metrics: { - 'summary.json': { - accuracy: 0.6768440509033, - loss: 2.298503875732422 - } - }, - name: 'exp-f13bca', - params: { 'params.yaml': { dropout: 0.32, epochs: 20 } } - } - ] - ) - expect(data).toStrictEqual(customPlotsFixture.plots) - }) -}) +describe('collectCustomPlots', () => { + const defaultFuncArgs = { + experiments: experimentsWithCheckpoints, + hasCheckpoints: true, + height: DEFAULT_PLOT_HEIGHT, + nbItemsPerRow: DEFAULT_NB_ITEMS_PER_ROW, + plotsOrderValues: customPlotsOrderFixture, + selectedRevisions: customPlotsFixture.colors?.domain + } -describe('collectCheckpointPlotsData', () => { it('should return the expected data from the test fixture', () => { - const data = collectCheckpointPlotsData(expShowFixture) - expect(data).toStrictEqual( - checkpointPlotsFixture.plots.map(({ id, values }) => ({ id, values })) - ) - }) - - it('should provide a continuous series for a modified experiment', () => { - const data = collectCheckpointPlotsData(modifiedFixture) - - expect(definedAndNonEmpty(data)).toBeTruthy() - - for (const { values } of data || []) { - const initialExperiment = values.filter( - point => point.group === 'exp-908bd' - ) - const modifiedExperiment = values.find( - point => point.group === 'exp-01b3a' - ) - - const lastIterationInitial = initialExperiment?.slice(-1)[0] - const firstIterationModified = modifiedExperiment - - expect(lastIterationInitial).not.toStrictEqual(firstIterationModified) - expect(omit(lastIterationInitial, 'group')).toStrictEqual( - omit(firstIterationModified, 'group') - ) - - const baseExperiment = values.filter(point => point.group === 'exp-920fc') - const restartedExperiment = values.find( - point => point.group === 'exp-9bc1b' - ) - - const iterationRestartedFrom = baseExperiment?.slice(5)[0] - const firstIterationAfterRestart = restartedExperiment - - expect(iterationRestartedFrom).not.toStrictEqual( - firstIterationAfterRestart - ) - expect(omit(iterationRestartedFrom, 'group')).toStrictEqual( - omit(firstIterationAfterRestart, 'group') - ) - } - }) - - it('should return undefined given no input', () => { - const data = collectCheckpointPlotsData({} as ExperimentsOutput) - expect(data).toBeUndefined() + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots + const data = collectCustomPlots(defaultFuncArgs) + expect(data).toStrictEqual(expectedOutput) }) -}) -describe('collectMetricOrder', () => { - it('should return an empty array if there is no checkpoints data', () => { - const metricOrder = collectMetricOrder( - undefined, - ['metric:A', 'metric:B'], - [] + it('should return only custom plots if there no selected revisions', () => { + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter( + plot => plot.type !== CustomPlotType.CHECKPOINT ) - expect(metricOrder).toStrictEqual([]) - }) + const data = collectCustomPlots({ + ...defaultFuncArgs, + selectedRevisions: undefined + }) - it('should return an empty array if the checkpoints data is an empty array', () => { - const metricOrder = collectMetricOrder([], ['metric:A', 'metric:B'], []) - expect(metricOrder).toStrictEqual([]) + expect(data).toStrictEqual(expectedOutput) }) - it('should maintain the existing order if all metrics are selected', () => { - const expectedOrder = [ - 'metric:F', - 'metric:A', - 'metric:B', - 'metric:E', - 'metric:D', - 'metric:C' - ] - - const metricOrder = collectMetricOrder( - [ - { id: 'metric:A', values: [] }, - { id: 'metric:B', values: [] }, - { id: 'metric:C', values: [] }, - { id: 'metric:D', values: [] }, - { id: 'metric:E', values: [] }, - { id: 'metric:F', values: [] } - ], - expectedOrder, - expectedOrder + it('should return only custom plots if checkpoints are not enabled', () => { + const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter( + plot => plot.type !== CustomPlotType.CHECKPOINT ) - expect(metricOrder).toStrictEqual(expectedOrder) + const data = collectCustomPlots({ + ...defaultFuncArgs, + hasCheckpoints: false + }) + + expect(data).toStrictEqual(expectedOutput) }) - it('should push unselected metrics to the end', () => { - const existingOrder = [ - 'metric:F', - 'metric:A', - 'metric:B', - 'metric:E', - 'metric:D', - 'metric:C' - ] + it('should return checkpoint plots with values only containing selected experiments data', () => { + const domain = customPlotsFixture.colors?.domain.slice(1) as string[] - const metricOrder = collectMetricOrder( - [ - { id: 'metric:A', values: [] }, - { id: 'metric:B', values: [] }, - { id: 'metric:C', values: [] }, - { id: 'metric:D', values: [] }, - { id: 'metric:E', values: [] }, - { id: 'metric:F', values: [] } - ], - existingOrder, - existingOrder.filter(metric => !['metric:A', 'metric:B'].includes(metric)) - ) - expect(metricOrder).toStrictEqual([ - 'metric:F', - 'metric:E', - 'metric:D', - 'metric:C', - 'metric:A', - 'metric:B' - ]) - }) + const expectedOutput = customPlotsFixture.plots.map(plot => ({ + ...plot, + values: isCheckpointPlot(plot) + ? plot.values.filter(value => domain.includes(value.group)) + : plot.values + })) - it('should add new metrics in the given order', () => { - const metricOrder = collectMetricOrder( - [ - { id: 'metric:C', values: [] }, - { id: 'metric:D', values: [] }, - { id: 'metric:A', values: [] }, - { id: 'metric:B', values: [] }, - { id: 'metric:E', values: [] }, - { id: 'metric:F', values: [] } - ], - ['metric:B', 'metric:A'], - ['metric:B', 'metric:A'] - ) - expect(metricOrder).toStrictEqual([ - 'metric:B', - 'metric:A', - 'metric:C', - 'metric:D', - 'metric:E', - 'metric:F' - ]) - }) + const data = collectCustomPlots({ + ...defaultFuncArgs, + selectedRevisions: domain + }) - it('should give selected metrics precedence', () => { - const metricOrder = collectMetricOrder( - [ - { id: 'metric:C', values: [] }, - { id: 'metric:D', values: [] }, - { id: 'metric:A', values: [] }, - { id: 'metric:B', values: [] }, - { id: 'metric:E', values: [] }, - { id: 'metric:F', values: [] } - ], - ['metric:B', 'metric:A'], - ['metric:B', 'metric:A', 'metric:F'] - ) - expect(metricOrder).toStrictEqual([ - 'metric:B', - 'metric:A', - 'metric:F', - 'metric:C', - 'metric:D', - 'metric:E' - ]) + expect(data).toStrictEqual(expectedOutput) }) }) diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index e38f72075b..7612e164ae 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -1,13 +1,16 @@ -import omit from 'lodash.omit' import get from 'lodash.get' import { TopLevelSpec } from 'vega-lite' import { VisualizationSpec } from 'react-vega' -import { CustomPlotsOrderValue } from '.' +import { + getFullValuePath, + CHECKPOINTS_PARAM, + CustomPlotsOrderValue, + isCheckpointValue +} from './custom' import { getRevisionFirstThreeColumns } from './util' import { ColorScale, CheckpointPlotValues, - CheckpointPlot, isImagePlot, ImagePlot, TemplatePlot, @@ -16,38 +19,26 @@ import { TemplatePlotSection, PlotsType, Revision, - CustomPlotData + CustomPlotData, + MetricVsParamPlotValues } from '../webview/contract' +import { EXPERIMENT_WORKSPACE_ID, PlotsOutput } from '../../cli/dvc/contract' import { - EXPERIMENT_WORKSPACE_ID, - ExperimentFieldsOrError, - ExperimentsOutput, - ExperimentStatus, - isValueTree, - PlotsOutput, - Value, - ValueTree -} from '../../cli/dvc/contract' -import { extractColumns } from '../../experiments/columns/extract' -import { - decodeColumn, - appendColumnToPath, - splitColumnPath + splitColumnPath, + FILE_SEPARATOR } from '../../experiments/columns/paths' import { ColumnType, Experiment, - isRunning, - MetricOrParamColumns + isRunning } from '../../experiments/webview/contract' -import { addToMapArray } from '../../util/map' import { TemplateOrder } from '../paths/collect' -import { extendVegaSpec, isMultiViewPlot } from '../vega/util' import { - definedAndNonEmpty, - reorderObjectList, - splitMatchedOrdered -} from '../../util/array' + extendVegaSpec, + isMultiViewPlot, + truncateVerticalTitle +} from '../vega/util' +import { definedAndNonEmpty, reorderObjectList } from '../../util/array' import { shortenForLabel } from '../../util/string' import { getDvcDataVersionInfo, @@ -57,220 +48,72 @@ import { unmergeConcatenatedFields } from '../multiSource/collect' import { StrokeDashEncoding } from '../multiSource/constants' -import { SelectedExperimentWithColor } from '../../experiments/model' +import { + ExperimentWithCheckpoints, + ExperimentWithDefinedCheckpoints, + SelectedExperimentWithColor +} from '../../experiments/model' import { Color } from '../../experiments/model/status/colors' -import { typedValueTreeEntries } from '../../experiments/columns/collect/metricsAndParams' - -type CheckpointPlotAccumulator = { - iterations: Record - plots: Map -} - -const collectFromMetricsFile = ( - acc: CheckpointPlotAccumulator, - name: string, - iteration: number, - key: string | undefined, - value: Value | ValueTree, - ancestors: string[] = [] -) => { - const pathArray = [...ancestors, key].filter(Boolean) as string[] - - if (isValueTree(value)) { - for (const [childKey, childValue] of typedValueTreeEntries(value)) { - collectFromMetricsFile( - acc, - name, - iteration, - childKey, - childValue, - pathArray - ) - } - return - } - - const path = appendColumnToPath(...pathArray) - - addToMapArray(acc.plots, path, { group: name, iteration, y: value }) -} - -type MetricsAndDetailsOrUndefined = - | { - checkpoint_parent: string | undefined - checkpoint_tip: string | undefined - metrics: MetricOrParamColumns | undefined - status: ExperimentStatus | undefined - } - | undefined - -const transformExperimentData = ( - experimentFieldsOrError: ExperimentFieldsOrError -): MetricsAndDetailsOrUndefined => { - const experimentFields = experimentFieldsOrError.data - if (!experimentFields) { - return - } - - const { checkpoint_tip, checkpoint_parent, status } = experimentFields - const { metrics } = extractColumns(experimentFields) - - return { checkpoint_parent, checkpoint_tip, metrics, status } -} -type ValidData = { - checkpoint_parent: string - checkpoint_tip: string - metrics: MetricOrParamColumns - status: ExperimentStatus -} - -const isValid = (data: MetricsAndDetailsOrUndefined): data is ValidData => - !!(data?.checkpoint_tip && data?.checkpoint_parent && data?.metrics) +export const getCustomPlotId = (metric: string, param = CHECKPOINTS_PARAM) => + `custom-${metric}-${param}` -const collectFromMetrics = ( - acc: CheckpointPlotAccumulator, - experimentName: string, - iteration: number, - metrics: MetricOrParamColumns -) => { - for (const file of Object.keys(metrics)) { - collectFromMetricsFile( - acc, - experimentName, - iteration, - undefined, - metrics[file], - [file] - ) - } -} +const getValueFromColumn = ( + path: string, + experiment: ExperimentWithCheckpoints +) => get(experiment, splitColumnPath(path)) as number | undefined -const getLastIteration = ( - acc: CheckpointPlotAccumulator, - checkpointParent: string -): number => acc.iterations[checkpointParent] || 0 - -const collectIteration = ( - acc: CheckpointPlotAccumulator, - sha: string, - checkpointParent: string -): number => { - const iteration = getLastIteration(acc, checkpointParent) + 1 - acc.iterations[sha] = iteration - return iteration -} +const isExperimentWithDefinedCheckpoints = ( + experiment: ExperimentWithCheckpoints +): experiment is ExperimentWithDefinedCheckpoints => !!experiment.checkpoints -const linkModified = ( - acc: CheckpointPlotAccumulator, - experimentName: string, - checkpointTip: string, - checkpointParent: string, - parent: ExperimentFieldsOrError | undefined +const collectCheckpointValuesFromExperiment = ( + values: CheckpointPlotValues, + exp: ExperimentWithDefinedCheckpoints, + metricPath: string ) => { - if (!parent) { - return - } + const group = exp.name || exp.label + const maxEpoch = exp.checkpoints.length + 1 - const parentData = transformExperimentData(parent) - if (!isValid(parentData) || parentData.checkpoint_tip === checkpointTip) { - return + const metricValue = getValueFromColumn(metricPath, exp) + if (metricValue !== undefined) { + values.push({ group, iteration: maxEpoch, y: metricValue }) } - const lastIteration = getLastIteration(acc, checkpointParent) - collectFromMetrics(acc, experimentName, lastIteration, parentData.metrics) -} - -const collectFromExperimentsObject = ( - acc: CheckpointPlotAccumulator, - experimentsObject: { [sha: string]: ExperimentFieldsOrError } -) => { - for (const [sha, experimentData] of Object.entries( - experimentsObject - ).reverse()) { - const data = transformExperimentData(experimentData) - - if (!isValid(data)) { - continue + for (const [ind, checkpoint] of exp.checkpoints.entries()) { + const metricValue = getValueFromColumn(metricPath, checkpoint) + if (metricValue !== undefined) { + values.push({ group, iteration: maxEpoch - ind - 1, y: metricValue }) } - const { - checkpoint_tip: checkpointTip, - checkpoint_parent: checkpointParent, - metrics - } = data - - const experimentName = experimentsObject[checkpointTip].data?.name - if (!experimentName) { - continue - } - - linkModified( - acc, - experimentName, - checkpointTip, - checkpointParent, - experimentsObject[checkpointParent] - ) - - const iteration = collectIteration(acc, sha, checkpointParent) - collectFromMetrics(acc, experimentName, iteration, metrics) } } -export const collectCheckpointPlotsData = ( - data: ExperimentsOutput -): CheckpointPlot[] | undefined => { - const acc = { - iterations: {}, - plots: new Map() - } - - for (const { baseline, ...experimentsObject } of Object.values( - omit(data, EXPERIMENT_WORKSPACE_ID) - )) { - const commit = transformExperimentData(baseline) - - if (commit) { - collectFromExperimentsObject(acc, experimentsObject) +const getCheckpointValues = ( + experiments: ExperimentWithCheckpoints[], + metricPath: string +): CheckpointPlotValues => { + const values: CheckpointPlotValues = [] + for (const experiment of experiments) { + if (isExperimentWithDefinedCheckpoints(experiment)) { + collectCheckpointValuesFromExperiment(values, experiment, metricPath) } } - - if (acc.plots.size === 0) { - return - } - - const plotsData: CheckpointPlot[] = [] - - for (const [key, value] of acc.plots.entries()) { - plotsData.push({ id: decodeColumn(key), values: value }) - } - - return plotsData + return values } -export const getCustomPlotId = (metric: string, param: string) => - `custom-${metric}-${param}` - -const collectCustomPlotData = ( - metric: string, - param: string, - experiments: Experiment[] -): CustomPlotData => { - const splitUpMetricPath = splitColumnPath(metric) - const splitUpParamPath = splitColumnPath(param) - const plotData: CustomPlotData = { - id: getCustomPlotId(metric, param), - metric: metric.slice(ColumnType.METRICS.length + 1), - param: param.slice(ColumnType.PARAMS.length + 1), - values: [] - } +const getMetricVsParamValues = ( + experiments: ExperimentWithCheckpoints[], + metricPath: string, + paramPath: string +): MetricVsParamPlotValues => { + const values: MetricVsParamPlotValues = [] for (const experiment of experiments) { - const metricValue = get(experiment, splitUpMetricPath) as number | undefined - const paramValue = get(experiment, splitUpParamPath) as number | undefined + const metricValue = getValueFromColumn(metricPath, experiment) + const paramValue = getValueFromColumn(paramPath, experiment) if (metricValue !== undefined && paramValue !== undefined) { - plotData.values.push({ + values.push({ expName: experiment.name || experiment.label, metric: metricValue, param: paramValue @@ -278,73 +121,78 @@ const collectCustomPlotData = ( } } - return plotData + return values } -export const collectCustomPlotsData = ( - metricsAndParams: CustomPlotsOrderValue[], - experiments: Experiment[] -): CustomPlotData[] => { - return metricsAndParams.map(({ metric, param }) => - collectCustomPlotData(metric, param, experiments) +const getCustomPlotData = ( + orderValue: CustomPlotsOrderValue, + experiments: ExperimentWithCheckpoints[], + selectedRevisions: string[] | undefined = [], + height: number, + nbItemsPerRow: number +): CustomPlotData => { + const { metric, param, type } = orderValue + const metricPath = getFullValuePath( + ColumnType.METRICS, + metric, + FILE_SEPARATOR ) -} -type MetricOrderAccumulator = { - newOrder: string[] - uncollectedMetrics: string[] - remainingSelectedMetrics: string[] -} + const paramPath = getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) -const collectExistingOrder = ( - acc: MetricOrderAccumulator, - existingMetricOrder: string[] -) => { - for (const metric of existingMetricOrder) { - const uncollectedIndex = acc.uncollectedMetrics.indexOf(metric) - const remainingIndex = acc.remainingSelectedMetrics.indexOf(metric) - if (uncollectedIndex === -1 || remainingIndex === -1) { - continue - } - acc.uncollectedMetrics.splice(uncollectedIndex, 1) - acc.remainingSelectedMetrics.splice(remainingIndex, 1) - acc.newOrder.push(metric) - } -} - -const collectRemainingSelected = (acc: MetricOrderAccumulator) => { - const [newOrder, uncollectedMetrics] = splitMatchedOrdered( - acc.uncollectedMetrics, - acc.remainingSelectedMetrics + const selectedExperiments = experiments.filter(({ name, label }) => + selectedRevisions.includes(name || label) ) - acc.newOrder.push(...newOrder) - acc.uncollectedMetrics = uncollectedMetrics -} + const values = isCheckpointValue(type) + ? getCheckpointValues(selectedExperiments, metricPath) + : getMetricVsParamValues(experiments, metricPath, paramPath) -export const collectMetricOrder = ( - checkpointPlotData: CheckpointPlot[] | undefined, - existingMetricOrder: string[], - selectedMetrics: string[] = [] -): string[] => { - if (!definedAndNonEmpty(checkpointPlotData)) { - return [] - } - - const acc: MetricOrderAccumulator = { - newOrder: [], - remainingSelectedMetrics: [...selectedMetrics], - uncollectedMetrics: checkpointPlotData.map(({ id }) => id) - } + return { + id: getCustomPlotId(metric, param), + metric, + param, + type, + values, + yTitle: truncateVerticalTitle(metric, nbItemsPerRow, height) as string + } as CustomPlotData +} + +export const collectCustomPlots = ({ + plotsOrderValues, + experiments, + hasCheckpoints, + selectedRevisions, + height, + nbItemsPerRow +}: { + plotsOrderValues: CustomPlotsOrderValue[] + experiments: ExperimentWithCheckpoints[] + hasCheckpoints: boolean + selectedRevisions: string[] | undefined + height: number + nbItemsPerRow: number +}): CustomPlotData[] => { + const plots = [] + const shouldSkipCheckpointPlots = !hasCheckpoints || !selectedRevisions + + for (const value of plotsOrderValues) { + if (shouldSkipCheckpointPlots && isCheckpointValue(value.type)) { + continue + } - if (!definedAndNonEmpty(acc.remainingSelectedMetrics)) { - return acc.uncollectedMetrics + plots.push( + getCustomPlotData( + value, + experiments, + selectedRevisions, + height, + nbItemsPerRow + ) + ) } - collectExistingOrder(acc, existingMetricOrder) - collectRemainingSelected(acc) - - return [...acc.newOrder, ...acc.uncollectedMetrics] + return plots } type RevisionPathData = { [path: string]: Record[] } diff --git a/extension/src/plots/model/custom.test.ts b/extension/src/plots/model/custom.test.ts new file mode 100644 index 0000000000..e1aab504c7 --- /dev/null +++ b/extension/src/plots/model/custom.test.ts @@ -0,0 +1,86 @@ +import { + CHECKPOINTS_PARAM, + cleanupOldOrderValue, + doesCustomPlotAlreadyExist +} from './custom' +import { CustomPlotType } from '../webview/contract' +import { FILE_SEPARATOR } from '../../experiments/columns/paths' + +describe('doesCustomPlotAlreadyExist', () => { + it('should return true if plot exists', () => { + const output = doesCustomPlotAlreadyExist( + [ + { + metric: 'summary.json:loss', + param: 'params.yaml:dropout', + type: CustomPlotType.METRIC_VS_PARAM + }, + { + metric: 'summary.json:accuracy', + param: 'params.yaml:epochs', + type: CustomPlotType.METRIC_VS_PARAM + }, + { + metric: 'summary.json:loss', + param: CHECKPOINTS_PARAM, + type: CustomPlotType.CHECKPOINT + } + ], + 'summary.json:accuracy', + 'params.yaml:epochs' + ) + expect(output).toStrictEqual(true) + }) + + it('should return false if plot does not exists', () => { + const output = doesCustomPlotAlreadyExist( + [ + { + metric: 'summary.json:loss', + param: 'params.yaml:dropout', + type: CustomPlotType.METRIC_VS_PARAM + }, + { + metric: 'summary.json:accuracy', + param: 'params.yaml:epochs', + type: CustomPlotType.METRIC_VS_PARAM + }, + { + metric: 'summary.json:loss', + param: CHECKPOINTS_PARAM, + type: CustomPlotType.CHECKPOINT + } + ], + 'summary.json:loss', + 'params.yaml:epochs' + ) + expect(output).toStrictEqual(false) + }) +}) + +describe('cleanupOlderValue', () => { + it('should update value if contents are outdated', () => { + const output = cleanupOldOrderValue( + { + metric: 'metrics:summary.json:loss', + param: 'params:params.yaml:dropout' + }, + FILE_SEPARATOR + ) + expect(output).toStrictEqual({ + metric: 'summary.json:loss', + param: 'params.yaml:dropout', + type: CustomPlotType.METRIC_VS_PARAM + }) + }) + + it('should not update value if contents are not outdated', () => { + const value = { + metric: 'summary.json:loss', + param: 'params.yaml:dropout', + type: CustomPlotType.METRIC_VS_PARAM + } + const output = cleanupOldOrderValue(value, FILE_SEPARATOR) + expect(output).toStrictEqual(value) + }) +}) diff --git a/extension/src/plots/model/custom.ts b/extension/src/plots/model/custom.ts new file mode 100644 index 0000000000..10a0e7e783 --- /dev/null +++ b/extension/src/plots/model/custom.ts @@ -0,0 +1,52 @@ +import { ColumnType } from '../../experiments/webview/contract' +import { CheckpointPlot, CustomPlot, CustomPlotType } from '../webview/contract' + +export const CHECKPOINTS_PARAM = 'epoch' + +export type CustomPlotsOrderValue = { + type: CustomPlotType.METRIC_VS_PARAM | CustomPlotType.CHECKPOINT + metric: string + param: string +} + +export const isCheckpointValue = ( + type: CustomPlotType.CHECKPOINT | CustomPlotType.METRIC_VS_PARAM +) => type === CustomPlotType.CHECKPOINT + +export const isCheckpointPlot = (plot: CustomPlot): plot is CheckpointPlot => + plot.type === CustomPlotType.CHECKPOINT + +export const doesCustomPlotAlreadyExist = ( + order: CustomPlotsOrderValue[], + metric: string, + param = CHECKPOINTS_PARAM +) => + order.some(value => { + return value.param === param && value.metric === metric + }) + +export const removeColumnTypeFromPath = ( + columnPath: string, + type: string, + fileSep: string +) => + columnPath.startsWith(type + fileSep) + ? columnPath.slice(type.length + 1) + : columnPath + +export const getFullValuePath = ( + type: string, + columnPath: string, + fileSep: string +) => type + fileSep + columnPath + +export const cleanupOldOrderValue = ( + value: { metric: string; param: string } | CustomPlotsOrderValue, + fileSep: string +): CustomPlotsOrderValue => ({ + // previous column paths have the "TYPE:" prefix + metric: removeColumnTypeFromPath(value.metric, ColumnType.METRICS, fileSep), + param: removeColumnTypeFromPath(value.param, ColumnType.PARAMS, fileSep), + // previous values didn't have a type + type: (value as CustomPlotsOrderValue).type || CustomPlotType.METRIC_VS_PARAM +}) diff --git a/extension/src/plots/model/index.test.ts b/extension/src/plots/model/index.test.ts index 1e191535e2..7ce55a1141 100644 --- a/extension/src/plots/model/index.test.ts +++ b/extension/src/plots/model/index.test.ts @@ -9,6 +9,7 @@ import { buildMockMemento } from '../../test/util' import { Experiments } from '../../experiments' import { PersistenceKey } from '../../persistence/constants' import { EXPERIMENT_WORKSPACE_ID } from '../../cli/dvc/contract' +import { customPlotsOrderFixture } from '../../test/fixtures/expShow/base/customPlots' const mockedRevisions = [ { displayColor: 'white', label: EXPERIMENT_WORKSPACE_ID }, @@ -21,10 +22,9 @@ const mockedRevisions = [ describe('plotsModel', () => { let model: PlotsModel const exampleDvcRoot = 'test' - const persistedSelectedMetrics = ['loss', 'accuracy'] const memento = buildMockMemento({ - [PersistenceKey.PLOT_SELECTED_METRICS + exampleDvcRoot]: - persistedSelectedMetrics, + [PersistenceKey.PLOTS_CUSTOM_ORDER + exampleDvcRoot]: + customPlotsOrderFixture, [PersistenceKey.PLOT_NB_ITEMS_PER_ROW_OR_WIDTH + exampleDvcRoot]: DEFAULT_SECTION_NB_ITEMS_PER_ROW_OR_WIDTH }) @@ -45,51 +45,29 @@ describe('plotsModel', () => { jest.clearAllMocks() }) - it('should change the selectedMetrics when calling setSelectedMetrics', () => { - expect(model.getSelectedMetrics()).toStrictEqual(persistedSelectedMetrics) - - const newSelectedMetrics = ['one', 'two', 'four', 'hundred'] - model.setSelectedMetrics(newSelectedMetrics) - - expect(model.getSelectedMetrics()).toStrictEqual(newSelectedMetrics) - }) - - it('should update the persisted selected metrics when calling setSelectedMetrics', () => { - const mementoUpdateSpy = jest.spyOn(memento, 'update') - const newSelectedMetrics = ['one', 'two', 'four', 'hundred'] - - model.setSelectedMetrics(newSelectedMetrics) - - expect(mementoUpdateSpy).toHaveBeenCalledTimes(2) - expect(mementoUpdateSpy).toHaveBeenCalledWith( - PersistenceKey.PLOT_SELECTED_METRICS + exampleDvcRoot, - newSelectedMetrics - ) - }) - it('should change the plotSize when calling setPlotSize', () => { expect( - model.getNbItemsPerRowOrWidth(PlotsSection.CHECKPOINT_PLOTS) + model.getNbItemsPerRowOrWidth(PlotsSection.CUSTOM_PLOTS) ).toStrictEqual(DEFAULT_NB_ITEMS_PER_ROW) - model.setNbItemsPerRowOrWidth(PlotsSection.CHECKPOINT_PLOTS, 1) + model.setNbItemsPerRowOrWidth(PlotsSection.CUSTOM_PLOTS, 1) expect( - model.getNbItemsPerRowOrWidth(PlotsSection.CHECKPOINT_PLOTS) + model.getNbItemsPerRowOrWidth(PlotsSection.CUSTOM_PLOTS) ).toStrictEqual(1) }) it('should update the persisted plot size when calling setPlotSize', () => { const mementoUpdateSpy = jest.spyOn(memento, 'update') - model.setNbItemsPerRowOrWidth(PlotsSection.CHECKPOINT_PLOTS, 2) + model.setNbItemsPerRowOrWidth(PlotsSection.CUSTOM_PLOTS, 2) expect(mementoUpdateSpy).toHaveBeenCalledTimes(1) expect(mementoUpdateSpy).toHaveBeenCalledWith( PersistenceKey.PLOT_NB_ITEMS_PER_ROW_OR_WIDTH + exampleDvcRoot, { ...DEFAULT_SECTION_NB_ITEMS_PER_ROW_OR_WIDTH, - [PlotsSection.CHECKPOINT_PLOTS]: 2 + [PlotsSection.CUSTOM_PLOTS]: 2 } ) }) @@ -99,12 +77,11 @@ describe('plotsModel', () => { expect(model.getSectionCollapsed()).toStrictEqual(DEFAULT_SECTION_COLLAPSED) - model.setSectionCollapsed({ [PlotsSection.CHECKPOINT_PLOTS]: true }) + model.setSectionCollapsed({ [PlotsSection.CUSTOM_PLOTS]: true }) const expectedSectionCollapsed = { - [PlotsSection.CHECKPOINT_PLOTS]: true, [PlotsSection.TEMPLATE_PLOTS]: false, - [PlotsSection.CUSTOM_PLOTS]: false, + [PlotsSection.CUSTOM_PLOTS]: true, [PlotsSection.COMPARISON_TABLE]: false } diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index 580b5e9296..118c51de26 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -2,9 +2,7 @@ import { Memento } from 'vscode' import isEmpty from 'lodash.isempty' import isEqual from 'lodash.isequal' import { - collectCheckpointPlotsData, collectData, - collectMetricOrder, collectSelectedTemplatePlots, collectTemplates, ComparisonData, @@ -12,13 +10,13 @@ import { TemplateAccumulator, collectCommitRevisionDetails, collectOverrideRevisionDetails, - collectCustomPlotsData, + collectCustomPlots, getCustomPlotId } from './collect' import { getRevisionFirstThreeColumns } from './util' +import { cleanupOldOrderValue, CustomPlotsOrderValue } from './custom' import { CheckpointPlot, - CheckpointPlotData, ComparisonPlots, Revision, ComparisonRevisionData, @@ -27,17 +25,17 @@ import { PlotsSection, SectionCollapsed, CustomPlotData, + CustomPlotsData, DEFAULT_HEIGHT, DEFAULT_NB_ITEMS_PER_ROW, PlotHeight } from '../webview/contract' import { - ExperimentsOutput, EXPERIMENT_WORKSPACE_ID, PlotsOutputOrError } from '../../cli/dvc/contract' import { Experiments } from '../../experiments' -import { getColorScale, truncateVerticalTitle } from '../vega/util' +import { getColorScale } from '../vega/util' import { definedAndNonEmpty, reorderObjectList } from '../../util/array' import { removeMissingKeysFromObject } from '../../util/object' import { TemplateOrder } from '../paths/collect' @@ -50,8 +48,9 @@ import { MultiSourceVariations } from '../multiSource/collect' import { isDvcError } from '../../cli/dvc/reader' +import { FILE_SEPARATOR } from '../../experiments/columns/paths' -export type CustomPlotsOrderValue = { metric: string; param: string } +export type CustomCheckpointPlots = { [metric: string]: CheckpointPlot } export class PlotsModel extends ModelWithPersistence { private readonly experiments: Experiments @@ -72,11 +71,6 @@ export class PlotsModel extends ModelWithPersistence { private multiSourceVariations: MultiSourceVariations = {} private multiSourceEncoding: MultiSourceEncoding = {} - private checkpointPlots?: CheckpointPlot[] - private customPlots?: CustomPlotData[] - private selectedMetrics?: string[] - private metricOrder: string[] - constructor( dvcRoot: string, experiments: Experiments, @@ -96,28 +90,10 @@ export class PlotsModel extends ModelWithPersistence { DEFAULT_SECTION_COLLAPSED ) this.comparisonOrder = this.revive(PersistenceKey.PLOT_COMPARISON_ORDER, []) - this.selectedMetrics = this.revive( - PersistenceKey.PLOT_SELECTED_METRICS, - undefined - ) - this.metricOrder = this.revive(PersistenceKey.PLOT_METRIC_ORDER, []) - this.customPlotsOrder = this.revive(PersistenceKey.PLOTS_CUSTOM_ORDER, []) } - public transformAndSetExperiments(data: ExperimentsOutput) { - const checkpointPlots = collectCheckpointPlotsData(data) - - if (!this.selectedMetrics && checkpointPlots) { - this.selectedMetrics = checkpointPlots.map(({ id }) => id) - } - - this.checkpointPlots = checkpointPlots - - this.setMetricOrder() - - this.recreateCustomPlots() - + public transformAndSetExperiments() { return this.removeStaleData() } @@ -135,8 +111,6 @@ export class PlotsModel extends ModelWithPersistence { collectMultiSourceVariations(data, this.multiSourceVariations) ]) - this.recreateCustomPlots() - this.comparisonData = { ...this.comparisonData, ...comparisonData @@ -163,8 +137,14 @@ export class PlotsModel extends ModelWithPersistence { this.deferred.resolve() } - public getCheckpointPlots() { - if (!this.checkpointPlots) { + public getCustomPlots(): CustomPlotsData | undefined { + const experimentsWithNoCommitData = this.experiments.hasCheckpoints() + ? this.experiments + .getExperimentsWithCheckpoints() + .filter(({ checkpoints }) => !!checkpoints) + : this.experiments.getExperiments() + + if (experimentsWithNoCommitData.length === 0) { return } @@ -173,70 +153,58 @@ export class PlotsModel extends ModelWithPersistence { .getSelectedExperiments() .map(({ displayColor, id: revision }) => ({ displayColor, revision })) ) + const height = this.getHeight(PlotsSection.CUSTOM_PLOTS) + const nbItemsPerRow = this.getNbItemsPerRowOrWidth( + PlotsSection.CUSTOM_PLOTS + ) + const plotsOrderValues = this.getCustomPlotsOrder() + + const plots: CustomPlotData[] = collectCustomPlots({ + experiments: experimentsWithNoCommitData, + hasCheckpoints: this.experiments.hasCheckpoints(), + height, + nbItemsPerRow, + plotsOrderValues, + selectedRevisions: colors?.domain + }) - if (!colors) { + if (plots.length === 0 && plotsOrderValues.length > 0) { return } - const { domain: selectedExperiments } = colors - return { colors, - height: this.getHeight(PlotsSection.CHECKPOINT_PLOTS), - nbItemsPerRow: this.getNbItemsPerRowOrWidth( - PlotsSection.CHECKPOINT_PLOTS - ), - plots: this.getPlots(this.checkpointPlots, selectedExperiments), - selectedMetrics: this.getSelectedMetrics() - } - } - - public getCustomPlots() { - if (!this.customPlots) { - return - } - return { - height: this.getHeight(PlotsSection.CUSTOM_PLOTS), - nbItemsPerRow: this.getNbItemsPerRowOrWidth(PlotsSection.CUSTOM_PLOTS), - plots: this.customPlots + height, + nbItemsPerRow, + plots } } - public recreateCustomPlots() { - const experiments = this.experiments.getExperiments() - if (experiments.length === 0) { - this.customPlots = undefined - return - } - const customPlots: CustomPlotData[] = collectCustomPlotsData( - this.getCustomPlotsOrder(), - experiments + public getCustomPlotsOrder() { + return this.customPlotsOrder.map(value => + cleanupOldOrderValue(value, FILE_SEPARATOR) ) - this.customPlots = customPlots } - public getCustomPlotsOrder() { - return this.customPlotsOrder + public updateCustomPlotsOrder(plotsOrder: CustomPlotsOrderValue[]) { + this.customPlotsOrder = plotsOrder } public setCustomPlotsOrder(plotsOrder: CustomPlotsOrderValue[]) { - this.customPlotsOrder = plotsOrder + this.updateCustomPlotsOrder(plotsOrder) this.persist(PersistenceKey.PLOTS_CUSTOM_ORDER, this.customPlotsOrder) - this.recreateCustomPlots() } public removeCustomPlots(plotIds: string[]) { const newCustomPlotsOrder = this.getCustomPlotsOrder().filter( - ({ metric, param }) => { - return !plotIds.includes(getCustomPlotId(metric, param)) - } + ({ metric, param }) => !plotIds.includes(getCustomPlotId(metric, param)) ) this.setCustomPlotsOrder(newCustomPlotsOrder) } - public addCustomPlot(metricAndParam: CustomPlotsOrderValue) { - const newCustomPlotsOrder = [...this.getCustomPlotsOrder(), metricAndParam] + public addCustomPlot(value: CustomPlotsOrderValue) { + const newCustomPlotsOrder = [...this.getCustomPlotsOrder(), value] this.setCustomPlotsOrder(newCustomPlotsOrder) } @@ -383,28 +351,6 @@ export class PlotsModel extends ModelWithPersistence { return this.experiments.getSelectedRevisions().map(({ label }) => label) } - public setSelectedMetrics(selectedMetrics: string[]) { - this.selectedMetrics = selectedMetrics - this.setMetricOrder() - this.persist( - PersistenceKey.PLOT_SELECTED_METRICS, - this.getSelectedMetrics() - ) - } - - public getSelectedMetrics() { - return this.selectedMetrics - } - - public setMetricOrder(metricOrder?: string[]) { - this.metricOrder = collectMetricOrder( - this.checkpointPlots, - metricOrder || this.metricOrder, - this.selectedMetrics - ) - this.persist(PersistenceKey.PLOT_METRIC_ORDER, this.metricOrder) - } - public setNbItemsPerRowOrWidth(section: PlotsSection, nbItemsPerRow: number) { this.nbItemsPerRowOrWidth[section] = nbItemsPerRow this.persist( @@ -504,30 +450,6 @@ export class PlotsModel extends ModelWithPersistence { return this.commitRevisions[label] || label } - private getPlots( - checkpointPlots: CheckpointPlot[], - selectedExperiments: string[] - ) { - return reorderObjectList( - this.metricOrder, - checkpointPlots.map(plot => { - const { id, values } = plot - return { - id, - title: truncateVerticalTitle( - id, - this.getNbItemsPerRowOrWidth(PlotsSection.CHECKPOINT_PLOTS), - this.getHeight(PlotsSection.CHECKPOINT_PLOTS) - ) as string, - values: values.filter(value => - selectedExperiments.includes(value.group) - ) - } - }), - 'id' - ) - } - private getSelectedComparisonPlots( paths: string[], selectedRevisions: string[] diff --git a/extension/src/plots/model/quickPick.test.ts b/extension/src/plots/model/quickPick.test.ts index 363b3503a8..c73dca9695 100644 --- a/extension/src/plots/model/quickPick.test.ts +++ b/extension/src/plots/model/quickPick.test.ts @@ -1,9 +1,15 @@ -import { CustomPlotsOrderValue } from '.' -import { pickCustomPlots, pickMetricAndParam } from './quickPick' +import { CustomPlotsOrderValue } from './custom' +import { + pickCustomPlots, + pickCustomPlotType, + pickMetric, + pickMetricAndParam +} from './quickPick' import { quickPickManyValues, quickPickValue } from '../../vscode/quickPick' import { Title } from '../../vscode/title' import { Toast } from '../../vscode/toast' import { ColumnType } from '../../experiments/webview/contract' +import { CustomPlotType } from '../webview/contract' jest.mock('../../vscode/quickPick') jest.mock('../../vscode/toast') @@ -29,27 +35,30 @@ describe('pickCustomPlots', () => { it('should return the selected plots', async () => { const selectedPlots = [ - 'custom-metrics:summary.json:loss-params:params.yaml:dropout', - 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs' + 'custom-summary.json:loss-epoch', + 'custom-summary.json:accuracy-params.yaml:epochs' ] - const mockedExperiments = [ + const mockedPlots = [ { - metric: 'metrics:summary.json:loss', - param: 'params:params.yaml:dropout' + metric: 'summary.json:loss', + param: 'epoch', + type: CustomPlotType.CHECKPOINT }, { - metric: 'metrics:summary.json:accuracy', - param: 'params:params.yaml:epochs' + metric: 'summary.json:accuracy', + param: 'params.yaml:epochs', + type: CustomPlotType.METRIC_VS_PARAM }, { - metric: 'metrics:summary.json:learning_rate', - param: 'param:summary.json:process.threshold' + metric: 'summary.json:learning_rate', + param: 'summary.json:process.threshold', + type: CustomPlotType.METRIC_VS_PARAM } ] as CustomPlotsOrderValue[] mockedQuickPickManyValues.mockResolvedValueOnce(selectedPlots) const picked = await pickCustomPlots( - mockedExperiments, + mockedPlots, 'There are no plots to remove.', { title: Title.SELECT_CUSTOM_PLOTS_TO_REMOVE } ) @@ -59,24 +68,24 @@ describe('pickCustomPlots', () => { expect(mockedQuickPickManyValues).toHaveBeenCalledWith( [ { - description: - 'metrics:summary.json:loss vs params:params.yaml:dropout', - label: 'loss vs dropout', - value: 'custom-metrics:summary.json:loss-params:params.yaml:dropout' + description: 'Checkpoint Trend Plot', + detail: 'metrics:summary.json:loss', + label: 'loss', + value: 'custom-summary.json:loss-epoch' }, { - description: - 'metrics:summary.json:accuracy vs params:params.yaml:epochs', + description: 'Metric Vs Param Plot', + detail: 'metrics:summary.json:accuracy vs params:params.yaml:epochs', label: 'accuracy vs epochs', - value: - 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs' + value: 'custom-summary.json:accuracy-params.yaml:epochs' }, { - description: - 'metrics:summary.json:learning_rate vs param:summary.json:process.threshold', + description: 'Metric Vs Param Plot', + detail: + 'metrics:summary.json:learning_rate vs params:summary.json:process.threshold', label: 'learning_rate vs threshold', value: - 'custom-metrics:summary.json:learning_rate-param:summary.json:process.threshold' + 'custom-summary.json:learning_rate-summary.json:process.threshold' } ], { title: Title.SELECT_CUSTOM_PLOTS_TO_REMOVE } @@ -84,6 +93,37 @@ describe('pickCustomPlots', () => { }) }) +describe('pickCustomPlotType', () => { + it('should return a chosen custom plot type', async () => { + const expectedType = CustomPlotType.CHECKPOINT + mockedQuickPickValue.mockResolvedValueOnce(expectedType) + + const picked = await pickCustomPlotType() + + expect(picked).toStrictEqual(expectedType) + expect(mockedQuickPickValue).toHaveBeenCalledTimes(1) + expect(mockedQuickPickValue).toHaveBeenCalledWith( + [ + { + description: + 'A linear plot that compares a chosen metric and param with current experiments.', + label: 'Metric Vs Param', + value: CustomPlotType.METRIC_VS_PARAM + }, + { + description: + 'A linear plot that shows how a chosen metric changes over selected experiments.', + label: 'Checkpoint Trend', + value: CustomPlotType.CHECKPOINT + } + ], + { + title: Title.SELECT_PLOT_TYPE_CUSTOM_PLOT + } + ) + }) +}) + describe('pickMetricAndParam', () => { it('should end early if there are no metrics or params available', async () => { mockedQuickPickValue.mockResolvedValueOnce(undefined) @@ -92,6 +132,50 @@ describe('pickMetricAndParam', () => { expect(mockedShowError).toHaveBeenCalledTimes(1) }) + it('should end early if user does not select a param or a metric', async () => { + mockedQuickPickValue + .mockResolvedValueOnce({ + hasChildren: false, + label: 'dropout', + path: 'params:params.yaml:dropout', + type: ColumnType.PARAMS + }) + .mockResolvedValueOnce(undefined) + .mockResolvedValue(undefined) + + const noParamSelected = await pickMetricAndParam([ + { + hasChildren: false, + label: 'dropout', + path: 'params:params.yaml:dropout', + type: ColumnType.PARAMS + }, + { + hasChildren: false, + label: 'accuracy', + path: 'metrics:summary.json:accuracy', + type: ColumnType.METRICS + } + ]) + expect(noParamSelected).toBeUndefined() + + const noMetricSelected = await pickMetricAndParam([ + { + hasChildren: false, + label: 'dropout', + path: 'params:params.yaml:dropout', + type: ColumnType.PARAMS + }, + { + hasChildren: false, + label: 'accuracy', + path: 'metrics:summary.json:accuracy', + type: ColumnType.METRICS + } + ]) + expect(noMetricSelected).toBeUndefined() + }) + it('should return a metric and a param if both are selected by the user', async () => { const expectedMetric = { label: 'loss', @@ -99,7 +183,7 @@ describe('pickMetricAndParam', () => { } const expectedParam = { label: 'epochs', - path: 'summary.json:loss-params:params.yaml:epochs' + path: 'params:params.yaml:epochs' } mockedQuickPickValue .mockResolvedValueOnce(expectedMetric) @@ -116,13 +200,89 @@ describe('pickMetricAndParam', () => { { hasChildren: false, label: 'accuracy', - path: 'summary.json:accuracy', + path: 'metrics:summary.json:accuracy', type: ColumnType.METRICS } ]) expect(metricAndParam).toStrictEqual({ - metric: expectedMetric.path, - param: expectedParam.path + metric: 'summary.json:loss', + param: 'params.yaml:epochs' }) }) }) + +describe('pickMetric', () => { + it('should end early if there are no metrics or params available', async () => { + mockedQuickPickValue.mockResolvedValueOnce(undefined) + const undef = await pickMetric([]) + expect(undef).toBeUndefined() + expect(mockedShowError).toHaveBeenCalledTimes(1) + }) + + it('should end early if user does not select a metric', async () => { + mockedQuickPickValue.mockResolvedValue(undefined) + + const noMetricSelected = await pickMetric([ + { + hasChildren: false, + label: 'dropout', + path: 'params:params.yaml:dropout', + type: ColumnType.PARAMS + }, + { + hasChildren: false, + label: 'dropout', + path: 'params:params.yaml:epochs', + type: ColumnType.PARAMS + }, + { + hasChildren: false, + label: 'accuracy', + path: 'metrics:summary.json:loss', + type: ColumnType.METRICS + }, + { + hasChildren: false, + label: 'accuracy', + path: 'metrics:summary.json:accuracy', + type: ColumnType.METRICS + } + ]) + expect(noMetricSelected).toBeUndefined() + }) + + it('should return a metric', async () => { + const expectedMetric = { + label: 'loss', + path: 'metrics:summary.json:loss' + } + mockedQuickPickValue.mockResolvedValueOnce(expectedMetric) + const metric = await pickMetric([ + { ...expectedMetric, hasChildren: false, type: ColumnType.METRICS }, + { + hasChildren: false, + label: 'accuracy', + path: 'metrics:summary.json:accuracy', + type: ColumnType.METRICS + } + ]) + + expect(metric).toStrictEqual('summary.json:loss') + expect(mockedQuickPickValue).toHaveBeenCalledTimes(1) + expect(mockedQuickPickValue).toHaveBeenCalledWith( + [ + { + description: 'metrics:summary.json:loss', + label: 'loss', + value: { label: 'loss', path: 'metrics:summary.json:loss' } + }, + { + description: 'metrics:summary.json:accuracy', + label: 'accuracy', + value: { label: 'accuracy', path: 'metrics:summary.json:accuracy' } + } + ], + { title: Title.SELECT_METRIC_CUSTOM_PLOT } + ) + }) +}) diff --git a/extension/src/plots/model/quickPick.ts b/extension/src/plots/model/quickPick.ts index eee9723a3e..e9cacd4754 100644 --- a/extension/src/plots/model/quickPick.ts +++ b/extension/src/plots/model/quickPick.ts @@ -1,40 +1,101 @@ -import { CustomPlotsOrderValue } from '.' import { getCustomPlotId } from './collect' -import { splitColumnPath } from '../../experiments/columns/paths' +import { + getFullValuePath, + CustomPlotsOrderValue, + isCheckpointValue, + removeColumnTypeFromPath +} from './custom' +import { + FILE_SEPARATOR, + splitColumnPath +} from '../../experiments/columns/paths' import { pickFromColumnLikes } from '../../experiments/columns/quickPick' import { Column, ColumnType } from '../../experiments/webview/contract' import { definedAndNonEmpty } from '../../util/array' import { quickPickManyValues, + quickPickValue, QuickPickOptionsWithTitle } from '../../vscode/quickPick' import { Title } from '../../vscode/title' import { Toast } from '../../vscode/toast' +import { CustomPlotType } from '../webview/contract' + +const getMetricVsParamPlotItem = (metric: string, param: string) => { + const fullMetric = getFullValuePath( + ColumnType.METRICS, + metric, + FILE_SEPARATOR + ) + const fullParam = getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR) + const splitMetric = splitColumnPath(fullMetric) + const splitParam = splitColumnPath(fullParam) + + return { + description: 'Metric Vs Param Plot', + detail: `${fullMetric} vs ${fullParam}`, + label: `${splitMetric[splitMetric.length - 1]} vs ${ + splitParam[splitParam.length - 1] + }`, + value: getCustomPlotId(metric, param) + } +} + +const getCheckpointPlotItem = (metric: string) => { + const fullMetric = getFullValuePath( + ColumnType.METRICS, + metric, + FILE_SEPARATOR + ) + const splitMetric = splitColumnPath(fullMetric) + return { + description: 'Checkpoint Trend Plot', + detail: fullMetric, + label: splitMetric[splitMetric.length - 1], + value: getCustomPlotId(metric) + } +} export const pickCustomPlots = ( - plots: CustomPlotsOrderValue[], + plotsOrderValues: CustomPlotsOrderValue[], noPlotsErrorMessage: string, quickPickOptions: QuickPickOptionsWithTitle ): Thenable => { - if (!definedAndNonEmpty(plots)) { + if (!definedAndNonEmpty(plotsOrderValues)) { return Toast.showError(noPlotsErrorMessage) } - const plotsItems = plots.map(({ metric, param }) => { - const splitMetric = splitColumnPath(metric) - const splitParam = splitColumnPath(param) - return { - description: `${metric} vs ${param}`, - label: `${splitMetric[splitMetric.length - 1]} vs ${ - splitParam[splitParam.length - 1] - }`, - value: getCustomPlotId(metric, param) - } - }) + const plotsItems = plotsOrderValues.map(value => + isCheckpointValue(value.type) + ? getCheckpointPlotItem(value.metric) + : getMetricVsParamPlotItem(value.metric, value.param) + ) return quickPickManyValues(plotsItems, quickPickOptions) } +export const pickCustomPlotType = (): Thenable => { + return quickPickValue( + [ + { + description: + 'A linear plot that compares a chosen metric and param with current experiments.', + label: 'Metric Vs Param', + value: CustomPlotType.METRIC_VS_PARAM + }, + { + description: + 'A linear plot that shows how a chosen metric changes over selected experiments.', + label: 'Checkpoint Trend', + value: CustomPlotType.CHECKPOINT + } + ], + { + title: Title.SELECT_PLOT_TYPE_CUSTOM_PLOT + } + ) +} + const getTypeColumnLikes = (columns: Column[], columnType: ColumnType) => columns .filter(({ type }) => type === columnType) @@ -66,5 +127,39 @@ export const pickMetricAndParam = async (columns: Column[]) => { if (!param) { return } - return { metric: metric.path, param: param.path } + + return { + metric: removeColumnTypeFromPath( + metric.path, + ColumnType.METRICS, + FILE_SEPARATOR + ), + param: removeColumnTypeFromPath( + param.path, + ColumnType.PARAMS, + FILE_SEPARATOR + ) + } +} + +export const pickMetric = async (columns: Column[]) => { + const metricColumnLikes = getTypeColumnLikes(columns, ColumnType.METRICS) + + if (!definedAndNonEmpty(metricColumnLikes)) { + return Toast.showError('There are no metrics to select from.') + } + + const metric = await pickFromColumnLikes(metricColumnLikes, { + title: Title.SELECT_METRIC_CUSTOM_PLOT + }) + + if (!metric) { + return + } + + return removeColumnTypeFromPath( + metric.path, + ColumnType.METRICS, + FILE_SEPARATOR + ) } diff --git a/extension/src/plots/webview/contract.ts b/extension/src/plots/webview/contract.ts index 67c6b82040..6dcadff4d2 100644 --- a/extension/src/plots/webview/contract.ts +++ b/extension/src/plots/webview/contract.ts @@ -17,28 +17,24 @@ export const DEFAULT_PLOT_HEIGHT = PlotHeight.SMALL export const DEFAULT_PLOT_WIDTH = 2 export enum PlotsSection { - CHECKPOINT_PLOTS = 'checkpoint-plots', TEMPLATE_PLOTS = 'template-plots', COMPARISON_TABLE = 'comparison-table', CUSTOM_PLOTS = 'custom-plots' } export const DEFAULT_SECTION_NB_ITEMS_PER_ROW_OR_WIDTH = { - [PlotsSection.CHECKPOINT_PLOTS]: DEFAULT_NB_ITEMS_PER_ROW, [PlotsSection.TEMPLATE_PLOTS]: DEFAULT_NB_ITEMS_PER_ROW, [PlotsSection.COMPARISON_TABLE]: DEFAULT_PLOT_WIDTH, [PlotsSection.CUSTOM_PLOTS]: DEFAULT_NB_ITEMS_PER_ROW } export const DEFAULT_HEIGHT = { - [PlotsSection.CHECKPOINT_PLOTS]: DEFAULT_PLOT_HEIGHT, [PlotsSection.TEMPLATE_PLOTS]: DEFAULT_PLOT_HEIGHT, [PlotsSection.COMPARISON_TABLE]: DEFAULT_PLOT_HEIGHT, [PlotsSection.CUSTOM_PLOTS]: DEFAULT_PLOT_HEIGHT } export const DEFAULT_SECTION_COLLAPSED = { - [PlotsSection.CHECKPOINT_PLOTS]: false, [PlotsSection.TEMPLATE_PLOTS]: false, [PlotsSection.COMPARISON_TABLE]: false, [PlotsSection.CUSTOM_PLOTS]: false @@ -76,6 +72,17 @@ export interface PlotsComparisonData { revisions: Revision[] } +export enum CustomPlotType { + CHECKPOINT = 'checkpoint', + METRIC_VS_PARAM = 'metricVsParam' +} + +export type MetricVsParamPlotValues = { + expName: string + metric: number + param: number +}[] + export type CheckpointPlotValues = { group: string iteration: number @@ -84,40 +91,35 @@ export type CheckpointPlotValues = { export type ColorScale = { domain: string[]; range: Color[] } -export type CheckpointPlot = { +type CustomPlotBase = { id: string - values: CheckpointPlotValues + metric: string + param: string } -export type CustomPlotValues = { - expName: string - metric: number - param: number -} +export type CheckpointPlot = { + values: CheckpointPlotValues + type: CustomPlotType.CHECKPOINT +} & CustomPlotBase -export type CustomPlotData = { - id: string - values: CustomPlotValues[] - metric: string - param: string +export type MetricVsParamPlot = { + values: MetricVsParamPlotValues + type: CustomPlotType.METRIC_VS_PARAM +} & CustomPlotBase + +export type CustomPlot = MetricVsParamPlot | CheckpointPlot + +export type CustomPlotData = CustomPlot & { + yTitle: string } export type CustomPlotsData = { plots: CustomPlotData[] nbItemsPerRow: number + colors: ColorScale | undefined height: PlotHeight } -export type CheckpointPlotData = CheckpointPlot & { title: string } - -export type CheckpointPlotsData = { - plots: CheckpointPlotData[] - colors: ColorScale - nbItemsPerRow: number - height: PlotHeight - selectedMetrics?: string[] -} - export enum PlotsType { VEGA = 'vega', IMAGE = 'image' @@ -173,7 +175,6 @@ export type ComparisonPlot = { export enum PlotsDataKeys { COMPARISON = 'comparison', - CHECKPOINT = 'checkpoint', CUSTOM = 'custom', HAS_UNSELECTED_PLOTS = 'hasUnselectedPlots', HAS_PLOTS = 'hasPlots', @@ -185,7 +186,6 @@ export enum PlotsDataKeys { export type PlotsData = | { [PlotsDataKeys.COMPARISON]?: PlotsComparisonData | null - [PlotsDataKeys.CHECKPOINT]?: CheckpointPlotsData | null [PlotsDataKeys.CUSTOM]?: CustomPlotsData | null [PlotsDataKeys.HAS_PLOTS]?: boolean [PlotsDataKeys.HAS_UNSELECTED_PLOTS]?: boolean diff --git a/extension/src/plots/webview/messages.ts b/extension/src/plots/webview/messages.ts index 03d187f4b1..4c64602ab4 100644 --- a/extension/src/plots/webview/messages.ts +++ b/extension/src/plots/webview/messages.ts @@ -2,6 +2,7 @@ import isEmpty from 'lodash.isempty' import { ComparisonPlot, ComparisonRevisionData, + CustomPlotType, PlotHeight, PlotsData as TPlotsData, Revision, @@ -21,12 +22,22 @@ import { import { PlotsModel } from '../model' import { PathsModel } from '../paths/model' import { BaseWebview } from '../../webview' +import { + pickCustomPlots, + pickCustomPlotType, + pickMetric, + pickMetricAndParam +} from '../model/quickPick' import { getModifiedTime, openImageFileInEditor } from '../../fileSystem' -import { pickCustomPlots, pickMetricAndParam } from '../model/quickPick' import { Title } from '../../vscode/title' -import { ColumnType } from '../../experiments/webview/contract' -import { FILE_SEPARATOR } from '../../experiments/columns/paths' import { reorderObjectList } from '../../util/array' +import { + CHECKPOINTS_PARAM, + CustomPlotsOrderValue, + doesCustomPlotAlreadyExist, + isCheckpointValue +} from '../model/custom' +import { getCustomPlotId } from '../model/collect' export class WebviewMessages { private readonly paths: PathsModel @@ -58,7 +69,6 @@ export class WebviewMessages { this.plots.getOverrideRevisionDetails() void this.getWebview()?.show({ - checkpoint: this.getCheckpointPlots(), comparison: this.getComparisonPlots(overrideComparison), custom: this.getCustomPlots(), hasPlots: !!this.paths.hasPaths(), @@ -69,18 +79,10 @@ export class WebviewMessages { }) } - public sendCheckpointPlotsMessage() { - void this.getWebview()?.show({ - checkpoint: this.getCheckpointPlots() - }) - } - public handleMessageFromWebview(message: MessageFromWebview) { switch (message.type) { case MessageFromWebviewType.ADD_CUSTOM_PLOT: return this.addCustomPlot() - case MessageFromWebviewType.TOGGLE_METRIC: - return this.setSelectedMetrics(message.payload) case MessageFromWebviewType.RESIZE_PLOTS: return this.setPlotSize( message.payload.section, @@ -95,8 +97,6 @@ export class WebviewMessages { return this.setComparisonRowsOrder(message.payload) case MessageFromWebviewType.REORDER_PLOTS_TEMPLATES: return this.setTemplateOrder(message.payload) - case MessageFromWebviewType.REORDER_PLOTS_METRICS: - return this.setMetricOrder(message.payload) case MessageFromWebviewType.REORDER_PLOTS_CUSTOM: return this.setCustomPlotsOrder(message.payload) case MessageFromWebviewType.SELECT_PLOTS: @@ -126,11 +126,6 @@ export class WebviewMessages { } } - private setSelectedMetrics(metrics: string[]) { - this.plots.setSelectedMetrics(metrics) - this.sendCheckpointPlotsAndEvent(EventName.VIEWS_PLOTS_METRICS_SELECTED) - } - private setPlotSize( section: PlotsSection, nbItemsPerRow: number, @@ -145,9 +140,6 @@ export class WebviewMessages { ) switch (section) { - case PlotsSection.CHECKPOINT_PLOTS: - this.sendCheckpointPlotsMessage() - break case PlotsSection.COMPARISON_TABLE: this.sendComparisonPlots() break @@ -201,12 +193,9 @@ export class WebviewMessages { ) } - private setMetricOrder(order: string[]) { - this.plots.setMetricOrder(order) - this.sendCheckpointPlotsAndEvent(EventName.VIEWS_REORDER_PLOTS_METRICS) - } - - private async addCustomPlot() { + private async addMetricVsParamPlot(): Promise< + CustomPlotsOrderValue | undefined + > { const metricAndParam = await pickMetricAndParam( this.experiments.getColumnTerminalNodes() ) @@ -215,24 +204,65 @@ export class WebviewMessages { return } - const plotAlreadyExists = this.plots - .getCustomPlotsOrder() - .some( - ({ param, metric }) => - param === metricAndParam.param && metric === metricAndParam.metric - ) + const plotAlreadyExists = doesCustomPlotAlreadyExist( + this.plots.getCustomPlotsOrder(), + metricAndParam.metric, + metricAndParam.param + ) if (plotAlreadyExists) { return Toast.showError('Custom plot already exists.') } - this.plots.addCustomPlot(metricAndParam) - this.sendCustomPlots() - sendTelemetryEvent( - EventName.VIEWS_PLOTS_CUSTOM_PLOT_ADDED, - undefined, - undefined + const plot = { + ...metricAndParam, + type: CustomPlotType.METRIC_VS_PARAM + } + this.plots.addCustomPlot(plot) + this.sendCustomPlotsAndEvent(EventName.VIEWS_PLOTS_CUSTOM_PLOT_ADDED) + } + + private async addCheckpointPlot(): Promise< + CustomPlotsOrderValue | undefined + > { + const metric = await pickMetric(this.experiments.getColumnTerminalNodes()) + + if (!metric) { + return + } + + const plotAlreadyExists = doesCustomPlotAlreadyExist( + this.plots.getCustomPlotsOrder(), + metric ) + + if (plotAlreadyExists) { + return Toast.showError('Custom plot already exists.') + } + + const plot: CustomPlotsOrderValue = { + metric, + param: CHECKPOINTS_PARAM, + type: CustomPlotType.CHECKPOINT + } + this.plots.addCustomPlot(plot) + this.sendCustomPlotsAndEvent(EventName.VIEWS_PLOTS_CUSTOM_PLOT_ADDED) + } + + private async addCustomPlot() { + if (!this.experiments.hasCheckpoints()) { + return this.addMetricVsParamPlot() + } + + const plotType = await pickCustomPlotType() + + if (!plotType) { + return + } + + return isCheckpointValue(plotType) + ? this.addCheckpointPlot() + : this.addMetricVsParamPlot() } private async removeCustomPlots() { @@ -249,35 +279,39 @@ export class WebviewMessages { } this.plots.removeCustomPlots(selectedPlotsIds) - this.sendCustomPlots() - sendTelemetryEvent( - EventName.VIEWS_PLOTS_CUSTOM_PLOT_REMOVED, - undefined, - undefined - ) + this.sendCustomPlotsAndEvent(EventName.VIEWS_PLOTS_CUSTOM_PLOT_REMOVED) } private setCustomPlotsOrder(plotIds: string[]) { - const customPlots = this.plots.getCustomPlots()?.plots - if (!customPlots) { - return - } + const customPlotsOrderWithId = this.plots + .getCustomPlotsOrder() + .map(value => ({ + ...value, + id: getCustomPlotId(value.metric, value.param) + })) + + const newOrder: CustomPlotsOrderValue[] = reorderObjectList( + plotIds, + customPlotsOrderWithId, + 'id' + ).map(({ metric, param, type }) => ({ + metric, + param, + type + })) - const buildMetricOrParamPath = (type: string, path: string) => - type + FILE_SEPARATOR + path - const newOrder = reorderObjectList(plotIds, customPlots, 'id').map( - ({ metric, param }) => ({ - metric: buildMetricOrParamPath(ColumnType.METRICS, metric), - param: buildMetricOrParamPath(ColumnType.PARAMS, param) - }) - ) this.plots.setCustomPlotsOrder(newOrder) + this.sendCustomPlotsAndEvent(EventName.VIEWS_REORDER_PLOTS_CUSTOM) + } + + private sendCustomPlotsAndEvent( + event: + | typeof EventName.VIEWS_PLOTS_CUSTOM_PLOT_ADDED + | typeof EventName.VIEWS_PLOTS_CUSTOM_PLOT_REMOVED + | typeof EventName.VIEWS_REORDER_PLOTS_CUSTOM + ) { this.sendCustomPlots() - sendTelemetryEvent( - EventName.VIEWS_REORDER_PLOTS_CUSTOM, - undefined, - undefined - ) + sendTelemetryEvent(event, undefined, undefined) } private selectPlotsFromWebview() { @@ -329,15 +363,6 @@ export class WebviewMessages { ) } - private sendCheckpointPlotsAndEvent( - event: - | typeof EventName.VIEWS_REORDER_PLOTS_METRICS - | typeof EventName.VIEWS_PLOTS_METRICS_SELECTED - ) { - this.sendCheckpointPlotsMessage() - sendTelemetryEvent(event, undefined, undefined) - } - private sendSectionCollapsed() { void this.getWebview()?.show({ sectionCollapsed: this.plots.getSectionCollapsed() @@ -433,10 +458,6 @@ export class WebviewMessages { return url } - private getCheckpointPlots() { - return this.plots.getCheckpointPlots() || null - } - private getCustomPlots() { return this.plots.getCustomPlots() || null } diff --git a/extension/src/telemetry/constants.ts b/extension/src/telemetry/constants.ts index 63781387bc..b818522f84 100644 --- a/extension/src/telemetry/constants.ts +++ b/extension/src/telemetry/constants.ts @@ -68,7 +68,6 @@ export const EventName = Object.assign( VIEWS_PLOTS_EXPERIMENT_TOGGLE: 'views.plots.toggleExperimentStatus', VIEWS_PLOTS_FOCUS_CHANGED: 'views.plots.focusChanged', VIEWS_PLOTS_MANUAL_REFRESH: 'views.plots.manualRefresh', - VIEWS_PLOTS_METRICS_SELECTED: 'views.plots.metricsSelected', VIEWS_PLOTS_REVISIONS_REORDERED: 'views.plots.revisionsReordered', VIEWS_PLOTS_SECTION_RESIZED: 'views.plots.sectionResized', VIEWS_PLOTS_SECTION_TOGGLE: 'views.plots.toggleSection', @@ -76,7 +75,6 @@ export const EventName = Object.assign( VIEWS_PLOTS_SELECT_PLOTS: 'view.plots.selectPlots', VIEWS_PLOTS_ZOOM_PLOT: 'views.plots.zoomPlot', VIEWS_REORDER_PLOTS_CUSTOM: 'views.plots.customReordered', - VIEWS_REORDER_PLOTS_METRICS: 'views.plots.metricsReordered', VIEWS_REORDER_PLOTS_TEMPLATES: 'views.plots.templatesReordered', VIEWS_SETUP_CLOSE: 'view.setup.closed', @@ -246,7 +244,6 @@ export interface IEventNamePropertyMapping { [EventName.VIEWS_PLOTS_CUSTOM_PLOT_ADDED]: undefined [EventName.VIEWS_PLOTS_FOCUS_CHANGED]: WebviewFocusChangedProperties [EventName.VIEWS_PLOTS_MANUAL_REFRESH]: { revisions: number } - [EventName.VIEWS_PLOTS_METRICS_SELECTED]: undefined [EventName.VIEWS_PLOTS_REVISIONS_REORDERED]: undefined [EventName.VIEWS_PLOTS_COMPARISON_ROWS_REORDERED]: undefined [EventName.VIEWS_PLOTS_SECTION_RESIZED]: { @@ -259,7 +256,6 @@ export interface IEventNamePropertyMapping { [EventName.VIEWS_PLOTS_SELECT_PLOTS]: undefined [EventName.VIEWS_PLOTS_EXPERIMENT_TOGGLE]: undefined [EventName.VIEWS_PLOTS_ZOOM_PLOT]: { isImage: boolean } - [EventName.VIEWS_REORDER_PLOTS_METRICS]: undefined [EventName.VIEWS_REORDER_PLOTS_CUSTOM]: undefined [EventName.VIEWS_REORDER_PLOTS_TEMPLATES]: undefined diff --git a/extension/src/test/e2e/extension.test.ts b/extension/src/test/e2e/extension.test.ts index 17df6ca615..6572b762d9 100644 --- a/extension/src/test/e2e/extension.test.ts +++ b/extension/src/test/e2e/extension.test.ts @@ -125,7 +125,7 @@ describe('Plots Webview', function () { await browser.waitUntil( async () => { - return (await webview.vegaVisualization$$.length) === 10 + return (await webview.vegaVisualization$$.length) === 5 }, { timeout: 30000 } ) diff --git a/extension/src/test/fixtures/expShow/base/checkpointPlots.ts b/extension/src/test/fixtures/expShow/base/checkpointPlots.ts deleted file mode 100644 index 91fd7a733f..0000000000 --- a/extension/src/test/fixtures/expShow/base/checkpointPlots.ts +++ /dev/null @@ -1,99 +0,0 @@ -import { copyOriginalColors } from '../../../../experiments/model/status/colors' -import { - CheckpointPlotsData, - DEFAULT_NB_ITEMS_PER_ROW, - DEFAULT_PLOT_HEIGHT -} from '../../../../plots/webview/contract' - -const colors = copyOriginalColors() - -const data: CheckpointPlotsData = { - colors: { - domain: ['exp-e7a67', 'test-branch', 'exp-83425'], - range: [colors[2], colors[3], colors[4]] - }, - plots: [ - { - id: 'summary.json:loss', - title: 'summary.json:loss', - values: [ - { group: 'exp-83425', iteration: 1, y: 1.9896177053451538 }, - { group: 'exp-83425', iteration: 2, y: 1.9329891204833984 }, - { group: 'exp-83425', iteration: 3, y: 1.8798457384109497 }, - { group: 'exp-83425', iteration: 4, y: 1.8261293172836304 }, - { group: 'exp-83425', iteration: 5, y: 1.775016188621521 }, - { group: 'exp-83425', iteration: 6, y: 1.775016188621521 }, - { group: 'test-branch', iteration: 1, y: 1.9882521629333496 }, - { group: 'test-branch', iteration: 2, y: 1.9293040037155151 }, - { group: 'test-branch', iteration: 3, y: 1.9293040037155151 }, - { group: 'exp-e7a67', iteration: 1, y: 2.020392894744873 }, - { group: 'exp-e7a67', iteration: 2, y: 2.0205044746398926 }, - { group: 'exp-e7a67', iteration: 3, y: 2.0205044746398926 } - ] - }, - { - id: 'summary.json:accuracy', - title: 'summary.json:accuracy', - values: [ - { group: 'exp-83425', iteration: 1, y: 0.40904998779296875 }, - { group: 'exp-83425', iteration: 2, y: 0.46094998717308044 }, - { group: 'exp-83425', iteration: 3, y: 0.5113166570663452 }, - { group: 'exp-83425', iteration: 4, y: 0.557449996471405 }, - { group: 'exp-83425', iteration: 5, y: 0.5926499962806702 }, - { group: 'exp-83425', iteration: 6, y: 0.5926499962806702 }, - { group: 'test-branch', iteration: 1, y: 0.4083833396434784 }, - { group: 'test-branch', iteration: 2, y: 0.4668000042438507 }, - { group: 'test-branch', iteration: 3, y: 0.4668000042438507 }, - { group: 'exp-e7a67', iteration: 1, y: 0.3723166584968567 }, - { group: 'exp-e7a67', iteration: 2, y: 0.3724166750907898 }, - { group: 'exp-e7a67', iteration: 3, y: 0.3724166750907898 } - ] - }, - { - id: 'summary.json:val_loss', - title: 'summary.json:val_loss', - values: [ - { group: 'exp-83425', iteration: 1, y: 1.9391471147537231 }, - { group: 'exp-83425', iteration: 2, y: 1.8825950622558594 }, - { group: 'exp-83425', iteration: 3, y: 1.827923059463501 }, - { group: 'exp-83425', iteration: 4, y: 1.7749212980270386 }, - { group: 'exp-83425', iteration: 5, y: 1.7233840227127075 }, - { group: 'exp-83425', iteration: 6, y: 1.7233840227127075 }, - { group: 'test-branch', iteration: 1, y: 1.9363881349563599 }, - { group: 'test-branch', iteration: 2, y: 1.8770883083343506 }, - { group: 'test-branch', iteration: 3, y: 1.8770883083343506 }, - { group: 'exp-e7a67', iteration: 1, y: 1.9979370832443237 }, - { group: 'exp-e7a67', iteration: 2, y: 1.9979370832443237 }, - { group: 'exp-e7a67', iteration: 3, y: 1.9979370832443237 } - ] - }, - { - id: 'summary.json:val_accuracy', - title: 'summary.json:val_accuracy', - values: [ - { group: 'exp-83425', iteration: 1, y: 0.49399998784065247 }, - { group: 'exp-83425', iteration: 2, y: 0.5550000071525574 }, - { group: 'exp-83425', iteration: 3, y: 0.6035000085830688 }, - { group: 'exp-83425', iteration: 4, y: 0.6414999961853027 }, - { group: 'exp-83425', iteration: 5, y: 0.6704000234603882 }, - { group: 'exp-83425', iteration: 6, y: 0.6704000234603882 }, - { group: 'test-branch', iteration: 1, y: 0.4970000088214874 }, - { group: 'test-branch', iteration: 2, y: 0.5608000159263611 }, - { group: 'test-branch', iteration: 3, y: 0.5608000159263611 }, - { group: 'exp-e7a67', iteration: 1, y: 0.4277999997138977 }, - { group: 'exp-e7a67', iteration: 2, y: 0.4277999997138977 }, - { group: 'exp-e7a67', iteration: 3, y: 0.4277999997138977 } - ] - } - ], - selectedMetrics: [ - 'summary.json:loss', - 'summary.json:accuracy', - 'summary.json:val_loss', - 'summary.json:val_accuracy' - ], - nbItemsPerRow: DEFAULT_NB_ITEMS_PER_ROW, - height: DEFAULT_PLOT_HEIGHT -} - -export default data diff --git a/extension/src/test/fixtures/expShow/base/customPlots.ts b/extension/src/test/fixtures/expShow/base/customPlots.ts index 9cb5457f16..3eb732d19f 100644 --- a/extension/src/test/fixtures/expShow/base/customPlots.ts +++ b/extension/src/test/fixtures/expShow/base/customPlots.ts @@ -1,15 +1,203 @@ +import { ExperimentWithCheckpoints } from '../../../../experiments/model' +import { copyOriginalColors } from '../../../../experiments/model/status/colors' +import { + CHECKPOINTS_PARAM, + CustomPlotsOrderValue +} from '../../../../plots/model/custom' import { CustomPlotsData, + CustomPlotType, DEFAULT_NB_ITEMS_PER_ROW, DEFAULT_PLOT_HEIGHT } from '../../../../plots/webview/contract' +export const customPlotsOrderFixture: CustomPlotsOrderValue[] = [ + { + metric: 'summary.json:loss', + param: 'params.yaml:dropout', + type: CustomPlotType.METRIC_VS_PARAM + }, + { + metric: 'summary.json:accuracy', + param: 'params.yaml:epochs', + type: CustomPlotType.METRIC_VS_PARAM + }, + { + metric: 'summary.json:loss', + param: CHECKPOINTS_PARAM, + type: CustomPlotType.CHECKPOINT + }, + { + metric: 'summary.json:accuracy', + param: CHECKPOINTS_PARAM, + type: CustomPlotType.CHECKPOINT + } +] + +export const experimentsWithCheckpoints: ExperimentWithCheckpoints[] = [ + { + id: '12345', + metrics: { + 'summary.json': { + accuracy: 0.3724166750907898, + loss: 2.0205044746398926 + } + }, + label: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } }, + checkpoints: [ + { + id: '12345', + metrics: { + 'summary.json': { + accuracy: 0.3724166750907898, + loss: 2.0205044746398926 + } + }, + label: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.3723166584968567, + loss: 2.020392894744873 + } + }, + name: 'exp-e7a67', + params: { 'params.yaml': { dropout: 0.15, epochs: 2 } } + } + ] + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.4668000042438507, + loss: 1.9293040037155151 + } + }, + name: 'test-branch', + params: { 'params.yaml': { dropout: 0.122, epochs: 2 } }, + checkpoints: [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.4668000042438507, + loss: 1.9293040037155151 + } + }, + name: 'test-branch', + params: { 'params.yaml': { dropout: 0.122, epochs: 2 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.4083833396434784, + loss: 1.9882521629333496 + } + }, + name: 'test-branch', + params: { 'params.yaml': { dropout: 0.122, epochs: 2 } } + } + ] + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.5926499962806702, + loss: 1.775016188621521 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } }, + checkpoints: [ + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.5926499962806702, + loss: 1.775016188621521 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.557449996471405, + loss: 1.8261293172836304 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.5113166570663452, + loss: 1.8798457384109497 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.46094998717308044, + loss: 1.9329891204833984 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + }, + { + id: '12345', + label: '123', + metrics: { + 'summary.json': { + accuracy: 0.40904998779296875, + loss: 1.9896177053451538 + } + }, + name: 'exp-83425', + params: { 'params.yaml': { dropout: 0.124, epochs: 5 } } + } + ] + } +] + +const colors = copyOriginalColors() + const data: CustomPlotsData = { + colors: { + domain: ['exp-e7a67', 'test-branch', 'exp-83425'], + range: [colors[2], colors[3], colors[4]] + }, plots: [ { - id: 'custom-metrics:summary.json:loss-params:params.yaml:dropout', + id: 'custom-summary.json:loss-params.yaml:dropout', metric: 'summary.json:loss', param: 'params.yaml:dropout', + type: CustomPlotType.METRIC_VS_PARAM, values: [ { expName: 'exp-e7a67', @@ -17,38 +205,83 @@ const data: CustomPlotsData = { param: 0.15 }, { - expName: 'exp-83425', + expName: 'test-branch', metric: 1.9293040037155151, - param: 0.25 + param: 0.122 }, { - expName: 'exp-f13bca', - metric: 2.298503875732422, - param: 0.32 + expName: 'exp-83425', + metric: 1.775016188621521, + param: 0.124 } - ] + ], + yTitle: 'summary.json:loss' }, { - id: 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs', + id: 'custom-summary.json:accuracy-params.yaml:epochs', metric: 'summary.json:accuracy', param: 'params.yaml:epochs', + type: CustomPlotType.METRIC_VS_PARAM, values: [ { expName: 'exp-e7a67', - metric: 0.4668000042438507, - param: 16 + metric: 0.3724166750907898, + param: 2 }, { - expName: 'exp-83425', - metric: 0.3484833240509033, - param: 10 + expName: 'test-branch', + metric: 0.4668000042438507, + param: 2 }, { - expName: 'exp-f13bca', - metric: 0.6768440509033, - param: 20 + expName: 'exp-83425', + metric: 0.5926499962806702, + param: 5 } - ] + ], + yTitle: 'summary.json:accuracy' + }, + { + id: 'custom-summary.json:loss-epoch', + metric: 'summary.json:loss', + param: CHECKPOINTS_PARAM, + values: [ + { group: 'exp-e7a67', iteration: 3, y: 2.0205044746398926 }, + { group: 'exp-e7a67', iteration: 2, y: 2.0205044746398926 }, + { group: 'exp-e7a67', iteration: 1, y: 2.020392894744873 }, + { group: 'test-branch', iteration: 3, y: 1.9293040037155151 }, + { group: 'test-branch', iteration: 2, y: 1.9293040037155151 }, + { group: 'test-branch', iteration: 1, y: 1.9882521629333496 }, + { group: 'exp-83425', iteration: 6, y: 1.775016188621521 }, + { group: 'exp-83425', iteration: 5, y: 1.775016188621521 }, + { group: 'exp-83425', iteration: 4, y: 1.8261293172836304 }, + { group: 'exp-83425', iteration: 3, y: 1.8798457384109497 }, + { group: 'exp-83425', iteration: 2, y: 1.9329891204833984 }, + { group: 'exp-83425', iteration: 1, y: 1.9896177053451538 } + ], + type: CustomPlotType.CHECKPOINT, + yTitle: 'summary.json:loss' + }, + { + id: 'custom-summary.json:accuracy-epoch', + metric: 'summary.json:accuracy', + param: CHECKPOINTS_PARAM, + values: [ + { group: 'exp-e7a67', iteration: 3, y: 0.3724166750907898 }, + { group: 'exp-e7a67', iteration: 2, y: 0.3724166750907898 }, + { group: 'exp-e7a67', iteration: 1, y: 0.3723166584968567 }, + { group: 'test-branch', iteration: 3, y: 0.4668000042438507 }, + { group: 'test-branch', iteration: 2, y: 0.4668000042438507 }, + { group: 'test-branch', iteration: 1, y: 0.4083833396434784 }, + { group: 'exp-83425', iteration: 6, y: 0.5926499962806702 }, + { group: 'exp-83425', iteration: 5, y: 0.5926499962806702 }, + { group: 'exp-83425', iteration: 4, y: 0.557449996471405 }, + { group: 'exp-83425', iteration: 3, y: 0.5113166570663452 }, + { group: 'exp-83425', iteration: 2, y: 0.46094998717308044 }, + { group: 'exp-83425', iteration: 1, y: 0.40904998779296875 } + ], + type: CustomPlotType.CHECKPOINT, + yTitle: 'summary.json:accuracy' } ], nbItemsPerRow: DEFAULT_NB_ITEMS_PER_ROW, diff --git a/extension/src/test/suite/experiments/model/tree.test.ts b/extension/src/test/suite/experiments/model/tree.test.ts index a88c27d0c9..d19e6e4346 100644 --- a/extension/src/test/suite/experiments/model/tree.test.ts +++ b/extension/src/test/suite/experiments/model/tree.test.ts @@ -23,9 +23,10 @@ import { RegisteredCliCommands, RegisteredCommands } from '../../../../commands/external' -import { buildPlots, getExpectedCheckpointPlotsData } from '../../plots/util' -import checkpointPlotsFixture from '../../../fixtures/expShow/base/checkpointPlots' +import { buildPlots, getExpectedCustomPlotsData } from '../../plots/util' +import customPlotsFixture from '../../../fixtures/expShow/base/customPlots' import expShowFixture from '../../../fixtures/expShow/base/output' +import plotsRevisionsFixture from '../../../fixtures/plotsDiff/revisions' import { ExperimentsTree } from '../../../../experiments/model/tree' import { buildExperiments, @@ -45,6 +46,11 @@ import { WorkspaceExperiments } from '../../../../experiments/workspace' import { ExperimentItem } from '../../../../experiments/model/collect' import { EXPERIMENT_WORKSPACE_ID } from '../../../../cli/dvc/contract' import { DvcReader } from '../../../../cli/dvc/reader' +import { + ColorScale, + CustomPlotType, + DEFAULT_SECTION_COLLAPSED +} from '../../../../plots/webview/contract' suite('Experiments Tree Test Suite', () => { const disposable = getTimeSafeDisposer() @@ -59,8 +65,8 @@ suite('Experiments Tree Test Suite', () => { // eslint-disable-next-line sonarjs/cognitive-complexity describe('ExperimentsTree', () => { - const { colors } = checkpointPlotsFixture - const { domain, range } = colors + const { colors } = customPlotsFixture + const { domain, range } = colors as ColorScale it('should appear in the UI', async () => { await expect( @@ -75,18 +81,19 @@ suite('Experiments Tree Test Suite', () => { const expectedRange = [...range] const webview = await plots.showWebview() + await webview.isReady() while (expectedDomain.length > 0) { - const expectedData = getExpectedCheckpointPlotsData( + const expectedData = getExpectedCustomPlotsData( expectedDomain, expectedRange ) - const { checkpoint } = getFirstArgOfLastCall(messageSpy) + const { custom } = getFirstArgOfLastCall(messageSpy) expect( - { checkpoint }, + { custom }, 'a message is sent with colors for the currently selected experiments' ).to.deep.equal(expectedData) messageSpy.resetHistory() @@ -107,9 +114,21 @@ suite('Experiments Tree Test Suite', () => { expect( messageSpy, - 'when there are no experiments selected we send null (show empty state)' + 'when there are no experiments selected we dont send checkpoint type plots' ).to.be.calledWithMatch({ - checkpoint: null + comparison: null, + custom: { + ...customPlotsFixture, + colors: undefined, + plots: customPlotsFixture.plots.filter( + plot => plot.type !== CustomPlotType.CHECKPOINT + ) + }, + hasPlots: false, + hasUnselectedPlots: false, + sectionCollapsed: DEFAULT_SECTION_COLLAPSED, + selectedRevisions: plotsRevisionsFixture.slice(0, 2), + template: null }) messageSpy.resetHistory() @@ -127,7 +146,7 @@ suite('Experiments Tree Test Suite', () => { expect(selected, 'the experiment is now selected').to.equal(range[0]) expect(messageSpy, 'we no longer send null').to.be.calledWithMatch( - getExpectedCheckpointPlotsData(expectedDomain, expectedRange) + getExpectedCustomPlotsData(expectedDomain, expectedRange) ) }).timeout(WEBVIEW_TEST_TIMEOUT) @@ -263,7 +282,7 @@ suite('Experiments Tree Test Suite', () => { messageSpy, 'a message is sent with colors for the currently selected experiments' ).to.be.calledWithMatch( - getExpectedCheckpointPlotsData([selectedDisplayName], [selectedColor]) + getExpectedCustomPlotsData([selectedDisplayName], [selectedColor]) ) }).timeout(WEBVIEW_TEST_TIMEOUT) diff --git a/extension/src/test/suite/plots/index.test.ts b/extension/src/test/suite/plots/index.test.ts index ec50f25d09..8ba89d0b63 100644 --- a/extension/src/test/suite/plots/index.test.ts +++ b/extension/src/test/suite/plots/index.test.ts @@ -8,8 +8,9 @@ import { commands, Uri } from 'vscode' import { buildPlots } from '../plots/util' import { Disposable } from '../../../extension' import expShowFixtureWithoutErrors from '../../fixtures/expShow/base/noErrors' -import checkpointPlotsFixture from '../../fixtures/expShow/base/checkpointPlots' -import customPlotsFixture from '../../fixtures/expShow/base/customPlots' +import customPlotsFixture, { + customPlotsOrderFixture +} from '../../fixtures/expShow/base/customPlots' import plotsDiffFixture from '../../fixtures/plotsDiff/output' import multiSourcePlotsDiffFixture from '../../fixtures/plotsDiff/multiSource' import templatePlotsFixture from '../../fixtures/plotsDiff/template' @@ -29,7 +30,8 @@ import { PlotsData as TPlotsData, PlotsSection, TemplatePlotGroup, - TemplatePlotsData + TemplatePlotsData, + CustomPlotType } from '../../../plots/webview/contract' import { TEMP_PLOTS_DIR } from '../../../cli/dvc/constants' import { WEBVIEW_TEST_TIMEOUT } from '../timeouts' @@ -43,6 +45,7 @@ import { } from '../../../cli/dvc/contract' import { SelectedExperimentWithColor } from '../../../experiments/model' import * as customPlotQuickPickUtil from '../../../plots/model/quickPick' +import { CHECKPOINTS_PARAM } from '../../../plots/model/custom' suite('Plots Test Suite', () => { const disposable = Disposable.fn() @@ -192,48 +195,6 @@ suite('Plots Test Suite', () => { ) }) - it('should handle a set selected metrics message from the webview', async () => { - const { plots, plotsModel, messageSpy } = await buildPlots( - disposable, - plotsDiffFixture - ) - - const webview = await plots.showWebview() - - const mockSendTelemetryEvent = stub(Telemetry, 'sendTelemetryEvent') - const mockMessageReceived = getMessageReceivedEmitter(webview) - - const mockSetSelectedMetrics = spy(plotsModel, 'setSelectedMetrics') - const mockSelectedMetrics = ['summary.json:loss'] - - messageSpy.resetHistory() - mockMessageReceived.fire({ - payload: mockSelectedMetrics, - type: MessageFromWebviewType.TOGGLE_METRIC - }) - - expect(mockSetSelectedMetrics).to.be.calledOnce - expect(mockSetSelectedMetrics).to.be.calledWithExactly( - mockSelectedMetrics - ) - expect(messageSpy).to.be.calledOnce - expect( - messageSpy, - "should update the webview's checkpoint plot state" - ).to.be.calledWithExactly({ - checkpoint: { - ...checkpointPlotsFixture, - selectedMetrics: mockSelectedMetrics - } - }) - expect(mockSendTelemetryEvent).to.be.calledOnce - expect(mockSendTelemetryEvent).to.be.calledWithExactly( - EventName.VIEWS_PLOTS_METRICS_SELECTED, - undefined, - undefined - ) - }).timeout(WEBVIEW_TEST_TIMEOUT) - it('should handle a section resized message from the webview', async () => { const { plots, plotsModel } = await buildPlots(disposable) @@ -282,7 +243,7 @@ suite('Plots Test Suite', () => { const mockMessageReceived = getMessageReceivedEmitter(webview) const mockSetSectionCollapsed = spy(plotsModel, 'setSectionCollapsed') - const mockSectionCollapsed = { [PlotsSection.CHECKPOINT_PLOTS]: true } + const mockSectionCollapsed = { [PlotsSection.CUSTOM_PLOTS]: true } messageSpy.resetHistory() mockMessageReceived.fire({ @@ -467,55 +428,6 @@ suite('Plots Test Suite', () => { ) }).timeout(WEBVIEW_TEST_TIMEOUT) - it('should handle a metric reordered message from the webview', async () => { - const { plots, plotsModel, messageSpy } = await buildPlots( - disposable, - plotsDiffFixture - ) - - const webview = await plots.showWebview() - - const mockSendTelemetryEvent = stub(Telemetry, 'sendTelemetryEvent') - const mockMessageReceived = getMessageReceivedEmitter(webview) - - const mockSetMetricOrder = spy(plotsModel, 'setMetricOrder') - const mockMetricOrder = [ - 'summary.json:loss', - 'summary.json:accuracy', - 'summary.json:val_loss', - 'summary.json:val_accuracy' - ] - - messageSpy.resetHistory() - mockMessageReceived.fire({ - payload: mockMetricOrder, - type: MessageFromWebviewType.REORDER_PLOTS_METRICS - }) - - expect(mockSetMetricOrder).to.be.calledOnce - expect(mockSetMetricOrder).to.be.calledWithExactly(mockMetricOrder) - expect(messageSpy).to.be.calledOnce - expect( - messageSpy, - "should update the webview's checkpoint plot order state" - ).to.be.calledWithExactly({ - checkpoint: { - ...checkpointPlotsFixture, - plots: reorderObjectList( - mockMetricOrder, - checkpointPlotsFixture.plots, - 'title' - ) - } - }) - expect(mockSendTelemetryEvent).to.be.calledOnce - expect(mockSendTelemetryEvent).to.be.calledWithExactly( - EventName.VIEWS_REORDER_PLOTS_METRICS, - undefined, - undefined - ) - }).timeout(WEBVIEW_TEST_TIMEOUT) - it('should handle a plot zoomed message from the webview', async () => { const { plots } = await buildPlots(disposable, plotsDiffFixture) @@ -587,8 +499,8 @@ suite('Plots Test Suite', () => { const webview = await plots.showWebview() const mockNewCustomPlotsOrder = [ - 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs', - 'custom-metrics:summary.json:loss-params:params.yaml:dropout' + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:loss-params.yaml:dropout' ] stub(plotsModel, 'getCustomPlots') @@ -610,12 +522,14 @@ suite('Plots Test Suite', () => { expect(mockSetCustomPlotsOrder).to.be.calledOnce expect(mockSetCustomPlotsOrder).to.be.calledWithExactly([ { - metric: 'metrics:summary.json:accuracy', - param: 'params:params.yaml:epochs' + metric: 'summary.json:accuracy', + param: 'params.yaml:epochs', + type: CustomPlotType.METRIC_VS_PARAM }, { - metric: 'metrics:summary.json:loss', - param: 'params:params.yaml:dropout' + metric: 'summary.json:loss', + param: 'params.yaml:dropout', + type: CustomPlotType.METRIC_VS_PARAM } ]) expect(messageSpy).to.be.calledOnce @@ -779,21 +693,18 @@ suite('Plots Test Suite', () => { expect(mockPlotsDiff).to.be.called const { - checkpoint: checkpointData, comparison: comparisonData, sectionCollapsed, template: templateData } = getFirstArgOfLastCall(messageSpy) - expect(checkpointData).to.deep.equal(checkpointPlotsFixture) expect(comparisonData).to.deep.equal(comparisonPlotsFixture) expect(sectionCollapsed).to.deep.equal(DEFAULT_SECTION_COLLAPSED) expect(templateData).to.deep.equal(templatePlotsFixture) const expectedPlotsData: TPlotsData = { - checkpoint: checkpointPlotsFixture, comparison: comparisonPlotsFixture, - custom: { height: DEFAULT_PLOT_HEIGHT, nbItemsPerRow: 2, plots: [] }, + custom: customPlotsFixture, hasPlots: true, hasUnselectedPlots: false, sectionCollapsed: DEFAULT_SECTION_COLLAPSED, @@ -927,17 +838,36 @@ suite('Plots Test Suite', () => { const webview = await plots.showWebview() + const mockPickCustomPlotType = stub( + customPlotQuickPickUtil, + 'pickCustomPlotType' + ) const mockGetMetricAndParam = stub( customPlotQuickPickUtil, 'pickMetricAndParam' ) + const mockGetMetric = stub(customPlotQuickPickUtil, 'pickMetric') - const quickPickEvent = new Promise(resolve => - mockGetMetricAndParam.callsFake(() => { + const mockMetricVsParamOrderValue = { + metric: 'summary.json:accuracy', + param: 'params.yaml:dropout', + type: CustomPlotType.METRIC_VS_PARAM + } + + const pickMetricVsParamType = new Promise(resolve => + mockPickCustomPlotType.onFirstCall().callsFake(() => { + resolve(undefined) + + return Promise.resolve(CustomPlotType.METRIC_VS_PARAM) + }) + ) + + const pickMetricVsParamOptions = new Promise(resolve => + mockGetMetricAndParam.onFirstCall().callsFake(() => { resolve(undefined) return Promise.resolve({ - metric: 'metrics:summary.json:loss', - param: 'params:params.yaml:dropout' + metric: mockMetricVsParamOrderValue.metric, + param: mockMetricVsParamOrderValue.param }) }) ) @@ -950,13 +880,47 @@ suite('Plots Test Suite', () => { mockMessageReceived.fire({ type: MessageFromWebviewType.ADD_CUSTOM_PLOT }) - await quickPickEvent + await pickMetricVsParamType + await pickMetricVsParamOptions expect(mockSetCustomPlotsOrder).to.be.calledWith([ - { - metric: 'metrics:summary.json:loss', - param: 'params:params.yaml:dropout' - } + ...customPlotsOrderFixture, + mockMetricVsParamOrderValue + ]) + expect(mockSendTelemetryEvent).to.be.calledWith( + EventName.VIEWS_PLOTS_CUSTOM_PLOT_ADDED, + undefined + ) + + const mockCheckpointsOrderValue = { + metric: 'summary.json:val_loss', + param: CHECKPOINTS_PARAM, + type: CustomPlotType.CHECKPOINT + } + + const pickCheckpointsType = new Promise(resolve => + mockPickCustomPlotType.onSecondCall().callsFake(() => { + resolve(undefined) + + return Promise.resolve(CustomPlotType.CHECKPOINT) + }) + ) + + const pickCheckpointOption = new Promise(resolve => + mockGetMetric.onFirstCall().callsFake(() => { + resolve(undefined) + return Promise.resolve(mockCheckpointsOrderValue.metric) + }) + ) + + mockMessageReceived.fire({ type: MessageFromWebviewType.ADD_CUSTOM_PLOT }) + + await pickCheckpointsType + await pickCheckpointOption + + expect(mockSetCustomPlotsOrder).to.be.calledWith([ + ...customPlotsOrderFixture, + mockCheckpointsOrderValue ]) expect(mockSendTelemetryEvent).to.be.calledWith( EventName.VIEWS_PLOTS_CUSTOM_PLOT_ADDED, @@ -964,6 +928,92 @@ suite('Plots Test Suite', () => { ) }) + it('should handle a add custom plot message when user ends early', async () => { + const { plots, plotsModel } = await buildPlots( + disposable, + plotsDiffFixture + ) + + const webview = await plots.showWebview() + + const mockPickCustomPlotType = stub( + customPlotQuickPickUtil, + 'pickCustomPlotType' + ) + + const mockGetMetricAndParam = stub( + customPlotQuickPickUtil, + 'pickMetricAndParam' + ) + const mockGetMetric = stub(customPlotQuickPickUtil, 'pickMetric') + + const pickUndefinedType = new Promise(resolve => + mockPickCustomPlotType.onFirstCall().callsFake(() => { + resolve(undefined) + + return Promise.resolve(undefined) + }) + ) + + const mockSetCustomPlotsOrder = stub(plotsModel, 'setCustomPlotsOrder') + mockSetCustomPlotsOrder.returns(undefined) + + const mockSendTelemetryEvent = stub(Telemetry, 'sendTelemetryEvent') + const mockMessageReceived = getMessageReceivedEmitter(webview) + + mockMessageReceived.fire({ type: MessageFromWebviewType.ADD_CUSTOM_PLOT }) + + await pickUndefinedType + + expect(mockSetCustomPlotsOrder).to.not.be.called + expect(mockSendTelemetryEvent).to.not.be.called + + const pickMetricVsParamType = new Promise(resolve => + mockPickCustomPlotType.onSecondCall().callsFake(() => { + resolve(undefined) + + return Promise.resolve(CustomPlotType.METRIC_VS_PARAM) + }) + ) + + const pickMetricVsParamUndefOptions = new Promise(resolve => + mockGetMetricAndParam.onFirstCall().callsFake(() => { + resolve(undefined) + return Promise.resolve(undefined) + }) + ) + + mockMessageReceived.fire({ type: MessageFromWebviewType.ADD_CUSTOM_PLOT }) + + await pickMetricVsParamType + await pickMetricVsParamUndefOptions + + expect(mockSetCustomPlotsOrder).to.not.be.called + expect(mockSendTelemetryEvent).to.not.be.called + + const pickCheckpointType = new Promise(resolve => + mockPickCustomPlotType.onThirdCall().callsFake(() => { + resolve(undefined) + + return Promise.resolve(CustomPlotType.CHECKPOINT) + }) + ) + const pickCheckpointUndefOptions = new Promise(resolve => + mockGetMetric.onFirstCall().callsFake(() => { + resolve(undefined) + return Promise.resolve(undefined) + }) + ) + + mockMessageReceived.fire({ type: MessageFromWebviewType.ADD_CUSTOM_PLOT }) + + await pickCheckpointType + await pickCheckpointUndefOptions + + expect(mockSetCustomPlotsOrder).to.not.be.called + expect(mockSendTelemetryEvent).to.not.be.called + }) + it('should handle a remove custom plot message from the webview', async () => { const { plots, plotsModel } = await buildPlots( disposable, @@ -981,15 +1031,16 @@ suite('Plots Test Suite', () => { mockSelectCustomPlots.callsFake(() => { resolve(undefined) return Promise.resolve([ - 'custom-metrics:summary.json:loss-params:params.yaml:dropout' + 'custom-summary.json:loss-params.yaml:dropout' ]) }) ) stub(plotsModel, 'getCustomPlotsOrder').returns([ { - metric: 'metrics:summary.json:loss', - param: 'params:params.yaml:dropout' + metric: 'summary.json:loss', + param: 'params.yaml:dropout', + type: CustomPlotType.METRIC_VS_PARAM } ]) @@ -1011,5 +1062,49 @@ suite('Plots Test Suite', () => { undefined ) }) + + it('should handle a remove custom plot message from the webview when user ends early', async () => { + const { plots, plotsModel } = await buildPlots( + disposable, + plotsDiffFixture + ) + + const webview = await plots.showWebview() + + const mockSelectCustomPlots = stub( + customPlotQuickPickUtil, + 'pickCustomPlots' + ) + + const quickPickEvent = new Promise(resolve => + mockSelectCustomPlots.callsFake(() => { + resolve(undefined) + return Promise.resolve(undefined) + }) + ) + + stub(plotsModel, 'getCustomPlotsOrder').returns([ + { + metric: 'summary.json:loss', + param: 'params.yaml:dropout', + type: CustomPlotType.METRIC_VS_PARAM + } + ]) + + const mockSetCustomPlotsOrder = stub(plotsModel, 'setCustomPlotsOrder') + mockSetCustomPlotsOrder.returns(undefined) + + const mockSendTelemetryEvent = stub(Telemetry, 'sendTelemetryEvent') + const mockMessageReceived = getMessageReceivedEmitter(webview) + + mockMessageReceived.fire({ + type: MessageFromWebviewType.REMOVE_CUSTOM_PLOTS + }) + + await quickPickEvent + + expect(mockSetCustomPlotsOrder).to.not.be.called + expect(mockSendTelemetryEvent).to.not.be.called + }) }) }) diff --git a/extension/src/test/suite/plots/util.ts b/extension/src/test/suite/plots/util.ts index 007783f886..3ac226b408 100644 --- a/extension/src/test/suite/plots/util.ts +++ b/extension/src/test/suite/plots/util.ts @@ -2,7 +2,9 @@ import { Disposer } from '@hediet/std/disposable' import { stub } from 'sinon' import * as FileSystem from '../../../fileSystem' import expShowFixtureWithoutErrors from '../../fixtures/expShow/base/noErrors' -import checkpointPlotsFixture from '../../fixtures/expShow/base/checkpointPlots' +import customPlotsFixture, { + customPlotsOrderFixture +} from '../../fixtures/expShow/base/customPlots' import { Plots } from '../../../plots' import { buildMockMemento, dvcDemoPath } from '../../util' import { WorkspacePlots } from '../../../plots/workspace' @@ -22,6 +24,7 @@ import { WebviewMessages } from '../../../plots/webview/messages' import { ExperimentsModel } from '../../../experiments/model' import { Experiment } from '../../../experiments/webview/contract' import { EXPERIMENT_WORKSPACE_ID } from '../../../cli/dvc/contract' +import { isCheckpointPlot } from '../../../plots/model/custom' export const buildPlots = async ( disposer: Disposer, @@ -89,6 +92,7 @@ export const buildPlots = async ( // eslint-disable-next-line @typescript-eslint/no-explicit-any const plotsModel: PlotsModel = (plots as any).plots + plotsModel.updateCustomPlotsOrder(customPlotsOrderFixture) // eslint-disable-next-line @typescript-eslint/no-explicit-any const pathsModel: PathsModel = (plots as any).paths @@ -127,14 +131,13 @@ export const buildWorkspacePlots = (disposer: Disposer) => { } } -export const getExpectedCheckpointPlotsData = ( +export const getExpectedCustomPlotsData = ( domain: string[], range: Color[] ) => { - const { plots, selectedMetrics, nbItemsPerRow, height } = - checkpointPlotsFixture + const { plots, nbItemsPerRow, height } = customPlotsFixture return { - checkpoint: { + custom: { colors: { domain, range @@ -142,11 +145,11 @@ export const getExpectedCheckpointPlotsData = ( height, nbItemsPerRow, plots: plots.map(plot => ({ - id: plot.id, - title: plot.title, - values: plot.values.filter(values => domain.includes(values.group)) - })), - selectedMetrics + ...plot, + values: isCheckpointPlot(plot) + ? plot.values.filter(value => domain.includes(value.group)) + : plot.values + })) } } } diff --git a/extension/src/vscode/title.ts b/extension/src/vscode/title.ts index 39be6d0c11..eecb9ff650 100644 --- a/extension/src/vscode/title.ts +++ b/extension/src/vscode/title.ts @@ -25,6 +25,7 @@ export enum Title { SELECT_PARAM_OR_METRIC_SORT = 'Select a Param or Metric to Sort by', SELECT_METRIC_CUSTOM_PLOT = 'Select a Metric to Create a Custom Plot', SELECT_PARAM_CUSTOM_PLOT = 'Select a Param to Create a Custom Plot', + SELECT_PLOT_TYPE_CUSTOM_PLOT = 'Select a Custom Plot Type', SELECT_CUSTOM_PLOTS_TO_REMOVE = 'Select Custom Plot(s) to Remove', SELECT_PARAM_TO_MODIFY = 'Select Param(s) to Modify', SELECT_PLOTS = 'Select Plots to Display', diff --git a/extension/src/webview/contract.ts b/extension/src/webview/contract.ts index ed12376331..43c53a998e 100644 --- a/extension/src/webview/contract.ts +++ b/extension/src/webview/contract.ts @@ -30,7 +30,6 @@ export enum MessageFromWebviewType { REORDER_COLUMNS = 'reorder-columns', REORDER_PLOTS_COMPARISON = 'reorder-plots-comparison', REORDER_PLOTS_COMPARISON_ROWS = 'reorder-plots-comparison-rows', - REORDER_PLOTS_METRICS = 'reorder-plots-metrics', REORDER_PLOTS_CUSTOM = 'reorder-plots-custom', REORDER_PLOTS_TEMPLATES = 'reorder-plots-templates', REFRESH_REVISION = 'refresh-revision', @@ -54,7 +53,6 @@ export enum MessageFromWebviewType { SET_STUDIO_SHARE_EXPERIMENTS_LIVE = 'set-studio-share-experiments-live', SHARE_EXPERIMENT_AS_BRANCH = 'share-experiment-as-branch', SHARE_EXPERIMENT_AS_COMMIT = 'share-experiment-as-commit', - TOGGLE_METRIC = 'toggle-metric', TOGGLE_PLOTS_SECTION = 'toggle-plots-section', REMOVE_CUSTOM_PLOTS = 'remove-custom-plots', REMOVE_STUDIO_TOKEN = 'remove-studio-token', @@ -161,10 +159,6 @@ export type MessageFromWebview = type: MessageFromWebviewType.REMOVE_COLUMN_SORT payload: string } - | { - type: MessageFromWebviewType.TOGGLE_METRIC - payload: string[] - } | { type: MessageFromWebviewType.REMOVE_CUSTOM_PLOTS } @@ -177,10 +171,6 @@ export type MessageFromWebview = type: MessageFromWebviewType.REORDER_PLOTS_COMPARISON_ROWS payload: string[] } - | { - type: MessageFromWebviewType.REORDER_PLOTS_METRICS - payload: string[] - } | { type: MessageFromWebviewType.REORDER_PLOTS_CUSTOM payload: string[] diff --git a/webview/src/plots/components/App.test.tsx b/webview/src/plots/components/App.test.tsx index b18cfb7fa7..261a98bc9a 100644 --- a/webview/src/plots/components/App.test.tsx +++ b/webview/src/plots/components/App.test.tsx @@ -13,14 +13,12 @@ import { } from '@testing-library/react' import '@testing-library/jest-dom/extend-expect' import comparisonTableFixture from 'dvc/src/test/fixtures/plotsDiff/comparison' -import checkpointPlotsFixture from 'dvc/src/test/fixtures/expShow/base/checkpointPlots' import customPlotsFixture from 'dvc/src/test/fixtures/expShow/base/customPlots' import plotsRevisionsFixture from 'dvc/src/test/fixtures/plotsDiff/revisions' import templatePlotsFixture from 'dvc/src/test/fixtures/plotsDiff/template/webview' import smoothTemplatePlotContent from 'dvc/src/test/fixtures/plotsDiff/template/smoothTemplatePlot' import manyTemplatePlots from 'dvc/src/test/fixtures/plotsDiff/template/virtualization' import { - CheckpointPlotsData, DEFAULT_SECTION_COLLAPSED, PlotsData, PlotsType, @@ -28,6 +26,8 @@ import { PlotsSection, TemplatePlotGroup, TemplatePlotsData, + CustomPlotType, + CustomPlotsData, DEFAULT_PLOT_HEIGHT, DEFAULT_NB_ITEMS_PER_ROW } from 'dvc/src/plots/webview/contract' @@ -35,14 +35,13 @@ import { MessageFromWebviewType, MessageToWebviewType } from 'dvc/src/webview/contract' -import { reorderObjectList } from 'dvc/src/util/array' import { act } from 'react-dom/test-utils' import { EXPERIMENT_WORKSPACE_ID } from 'dvc/src/cli/dvc/contract' import { VisualizationSpec } from 'react-vega' import { App } from './App' import { NewSectionBlock } from './templatePlots/TemplatePlots' import { - CheckpointPlotsById, + CustomPlotsById, plotDataStore, TemplatePlotsById } from './plotDataStore' @@ -73,18 +72,16 @@ jest.mock('../../shared/components/dragDrop/currentTarget', () => { jest.mock('../../shared/api') -jest.mock('./checkpointPlots/util', () => ({ - createSpec: () => ({ +jest.mock('./customPlots/util', () => ({ + createCheckpointSpec: () => ({ $schema: 'https://vega.github.io/schema/vega-lite/v5.json', encoding: {}, height: 100, layer: [], transform: [], width: 100 - }) -})) -jest.mock('./customPlots/util', () => ({ - createSpec: () => ({ + }), + createMetricVsParamSpec: () => ({ $schema: 'https://vega.github.io/schema/vega-lite/v5.json', encoding: {}, height: 100, @@ -109,13 +106,6 @@ const originalOffsetWidth = Object.getOwnPropertyDescriptor( )?.value describe('App', () => { - const sectionPosition = { - [PlotsSection.CHECKPOINT_PLOTS]: 2, - [PlotsSection.TEMPLATE_PLOTS]: 0, - [PlotsSection.COMPARISON_TABLE]: 1, - [PlotsSection.CUSTOM_PLOTS]: 3 - } - const sendSetDataMessage = (data: PlotsData) => { const message = new MessageEvent('message', { data: { @@ -162,13 +152,6 @@ describe('App', () => { ] } as TemplatePlotsData - const getCheckpointMenuItem = (position: number) => - within( - screen.getAllByTestId('section-container')[ - sectionPosition[PlotsSection.CHECKPOINT_PLOTS] - ] - ).getAllByTestId('icon-menu-item')[position] - const renderAppAndChangeSize = async ( data: PlotsData, nbItemsPerRow: number, @@ -181,11 +164,11 @@ describe('App', () => { ...data, sectionCollapsed: DEFAULT_SECTION_COLLAPSED } - if (section === PlotsSection.CHECKPOINT_PLOTS) { - plotsData.checkpoint = { - ...data?.checkpoint, + if (section === PlotsSection.CUSTOM_PLOTS) { + plotsData.custom = { + ...data?.custom, ...withSize - } as CheckpointPlotsData + } as CustomPlotsData } if (section === PlotsSection.TEMPLATE_PLOTS) { plotsData.template = { @@ -216,7 +199,7 @@ describe('App', () => { jest .spyOn(HTMLElement.prototype, 'clientHeight', 'get') .mockImplementation(() => heightToSuppressVegaError) - plotDataStore[PlotsSection.CHECKPOINT_PLOTS] = {} as CheckpointPlotsById + plotDataStore[PlotsSection.CUSTOM_PLOTS] = {} as CustomPlotsById plotDataStore[PlotsSection.TEMPLATE_PLOTS] = {} as TemplatePlotsById }) @@ -251,9 +234,7 @@ describe('App', () => { }) it('should render the empty state when given data with no plots', async () => { - renderAppWithOptionalData({ - checkpoint: null - }) + renderAppWithOptionalData({ custom: null }) const emptyState = await screen.findByText('No Plots Detected.') expect(emptyState).toBeInTheDocument() @@ -261,7 +242,6 @@ describe('App', () => { it('should render loading section states when given a single revision which has not been fetched', async () => { renderAppWithOptionalData({ - checkpoint: null, comparison: { height: DEFAULT_PLOT_HEIGHT, plots: [ @@ -299,12 +279,11 @@ describe('App', () => { }) const loading = await screen.findAllByText('Loading...') - expect(loading).toHaveLength(3) + expect(loading).toHaveLength(2) }) it('should render the Add Plots and Add Experiments get started button when there are experiments which have plots that are all unselected', async () => { renderAppWithOptionalData({ - checkpoint: null, hasPlots: true, hasUnselectedPlots: true, selectedRevisions: [{} as Revision] @@ -335,7 +314,7 @@ describe('App', () => { it('should render only the Add Experiments get started button when no experiments are selected', async () => { renderAppWithOptionalData({ - checkpoint: null, + custom: null, hasPlots: true, hasUnselectedPlots: false, selectedRevisions: undefined @@ -354,30 +333,6 @@ describe('App', () => { }) }) - it('should render other sections given a message with only checkpoint plots data', () => { - renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture - }) - - expect(screen.queryByText('Loading Plots...')).not.toBeInTheDocument() - expect(screen.getByText('Trends')).toBeInTheDocument() - expect(screen.getByText('Data Series')).toBeInTheDocument() - expect(screen.getByText('Images')).toBeInTheDocument() - expect(screen.getByText('Custom')).toBeInTheDocument() - expect(screen.getAllByText('No Plots to Display')).toHaveLength(2) - expect(screen.getByText('No Images to Compare')).toBeInTheDocument() - }) - - it('should render checkpoint even when there is no checkpoint plots data', () => { - renderAppWithOptionalData({ - template: templatePlotsFixture - }) - - expect(screen.queryByText('Loading Plots...')).not.toBeInTheDocument() - expect(screen.getByText('Trends')).toBeInTheDocument() - expect(screen.getAllByText('No Plots to Display')).toHaveLength(2) - }) - it('should render an empty state given a message with only custom plots data', () => { renderAppWithOptionalData({ custom: customPlotsFixture @@ -391,12 +346,13 @@ describe('App', () => { it('should render custom with "No Plots to Display" message when there is no custom plots data', () => { renderAppWithOptionalData({ - comparison: comparisonTableFixture + comparison: comparisonTableFixture, + template: templatePlotsFixture }) expect(screen.queryByText('Loading Plots...')).not.toBeInTheDocument() expect(screen.getByText('Custom')).toBeInTheDocument() - expect(screen.getAllByText('No Plots to Display')).toHaveLength(3) + expect(screen.getByText('No Plots to Display')).toBeInTheDocument() }) it('should render custom with "No Plots Added" message when there are no plots added', () => { @@ -410,7 +366,7 @@ describe('App', () => { expect(screen.queryByText('Loading Plots...')).not.toBeInTheDocument() expect(screen.getByText('Custom')).toBeInTheDocument() - expect(screen.getAllByText('No Plots to Display')).toHaveLength(2) + expect(screen.getByText('No Plots to Display')).toBeInTheDocument() expect(screen.getByText('No Plots Added')).toBeInTheDocument() }) @@ -418,7 +374,7 @@ describe('App', () => { const expectedSectionName = 'Images' renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + custom: customPlotsFixture }) sendSetDataMessage({ @@ -432,8 +388,8 @@ describe('App', () => { const emptyStateText = 'No Images to Compare' renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture, - comparison: comparisonTableFixture + comparison: comparisonTableFixture, + template: templatePlotsFixture }) expect(screen.queryByText(emptyStateText)).not.toBeInTheDocument() @@ -447,11 +403,10 @@ describe('App', () => { expect(emptyState).toBeInTheDocument() }) - it('should remove checkpoint plots given a message showing checkpoint plots as null', async () => { + it('should remove custom plots given a message showing custom plots as null', async () => { const emptyStateText = 'No Plots to Display' renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture, comparison: comparisonTableFixture, custom: customPlotsFixture, template: templatePlotsFixture @@ -460,7 +415,7 @@ describe('App', () => { expect(screen.queryByText(emptyStateText)).not.toBeInTheDocument() sendSetDataMessage({ - checkpoint: null + custom: null }) const emptyState = await screen.findByText(emptyStateText) @@ -470,24 +425,25 @@ describe('App', () => { it('should remove all sections from the document if there is no data provided', () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture }) - expect(screen.getByText('Trends')).toBeInTheDocument() + expect(screen.getByText('Images')).toBeInTheDocument() sendSetDataMessage({ - checkpoint: null + comparison: null }) - expect(screen.queryByText('Trends')).not.toBeInTheDocument() + expect(screen.queryByText('Images')).not.toBeInTheDocument() }) - it('should toggle the checkpoint plots section in state when its header is clicked', async () => { + it('should toggle the custom plots section in state when its header is clicked', async () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture, + custom: customPlotsFixture }) - const summaryElement = await screen.findByText('Trends') + const summaryElement = await screen.findByText('Custom') const visiblePlots = await screen.findAllByLabelText('Vega visualization') for (const visiblePlot of visiblePlots) { expect(visiblePlot).toBeInTheDocument() @@ -500,14 +456,14 @@ describe('App', () => { }) expect(mockPostMessage).toHaveBeenCalledWith({ - payload: { [PlotsSection.CHECKPOINT_PLOTS]: true }, + payload: { [PlotsSection.CUSTOM_PLOTS]: true }, type: MessageFromWebviewType.TOGGLE_PLOTS_SECTION }) sendSetDataMessage({ sectionCollapsed: { ...DEFAULT_SECTION_COLLAPSED, - [PlotsSection.CHECKPOINT_PLOTS]: true + [PlotsSection.CUSTOM_PLOTS]: true } }) @@ -516,21 +472,22 @@ describe('App', () => { ).not.toBeInTheDocument() }) - it('should not toggle the checkpoint plots section when its header is clicked and its title is selected', async () => { + it('should not toggle the custom plots section when its header is clicked and its title is selected', async () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture, + custom: customPlotsFixture }) - const summaryElement = await screen.findByText('Trends') + const summaryElement = await screen.findByText('Custom') - createWindowTextSelection('Trends', 2) + createWindowTextSelection('Custom', 2) fireEvent.click(summaryElement, { bubbles: true, cancelable: true }) expect(mockPostMessage).not.toHaveBeenCalledWith({ - payload: { [PlotsSection.CHECKPOINT_PLOTS]: true }, + payload: { [PlotsSection.CUSTOM_PLOTS]: true }, type: MessageFromWebviewType.TOGGLE_PLOTS_SECTION }) @@ -541,25 +498,25 @@ describe('App', () => { }) expect(mockPostMessage).toHaveBeenCalledWith({ - payload: { [PlotsSection.CHECKPOINT_PLOTS]: true }, + payload: { [PlotsSection.CUSTOM_PLOTS]: true }, type: MessageFromWebviewType.TOGGLE_PLOTS_SECTION }) }) - it('should not toggle the checkpoint plots section if the tooltip is clicked', () => { + it('should not toggle the comparison plots section if the tooltip is clicked', () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture }) - const checkpointsTooltipToggle = screen.getAllByTestId( + const comparisonTooltipToggle = screen.getAllByTestId( 'info-tooltip-toggle' - )[2] - fireEvent.mouseEnter(checkpointsTooltipToggle, { + )[1] + fireEvent.mouseEnter(comparisonTooltipToggle, { bubbles: true, cancelable: true }) - const tooltip = screen.getByTestId('tooltip-checkpoint-plots') + const tooltip = screen.getByTestId('tooltip-comparison-plots') const tooltipLink = within(tooltip).getByRole('link') fireEvent.click(tooltipLink, { bubbles: true, @@ -567,30 +524,31 @@ describe('App', () => { }) expect(mockPostMessage).not.toHaveBeenCalledWith({ - payload: { [PlotsSection.CHECKPOINT_PLOTS]: true }, + payload: { [PlotsSection.CUSTOM_PLOTS]: true }, type: MessageFromWebviewType.TOGGLE_PLOTS_SECTION }) - fireEvent.click(checkpointsTooltipToggle, { + fireEvent.click(comparisonTooltipToggle, { bubbles: true, cancelable: true }) expect(mockPostMessage).not.toHaveBeenCalledWith({ - payload: { [PlotsSection.CHECKPOINT_PLOTS]: true }, + payload: { [PlotsSection.CUSTOM_PLOTS]: true }, type: MessageFromWebviewType.TOGGLE_PLOTS_SECTION }) }) - it('should not toggle the checkpoint plots section when its header is clicked and the content of its tooltip is selected', async () => { + it('should not toggle the custom plots section when its header is clicked and the content of its tooltip is selected', async () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture, + custom: customPlotsFixture }) - const summaryElement = await screen.findByText('Trends') + const summaryElement = await screen.findByText('Custom') createWindowTextSelection( // eslint-disable-next-line testing-library/no-node-access - SectionDescription['checkpoint-plots'].props.children, + SectionDescription['custom-plots'].props.children, 2 ) fireEvent.click(summaryElement, { @@ -599,7 +557,7 @@ describe('App', () => { }) expect(mockPostMessage).not.toHaveBeenCalledWith({ - payload: { [PlotsSection.CHECKPOINT_PLOTS]: true }, + payload: { [PlotsSection.CUSTOM_PLOTS]: true }, type: MessageFromWebviewType.TOGGLE_PLOTS_SECTION }) @@ -610,100 +568,19 @@ describe('App', () => { }) expect(mockPostMessage).toHaveBeenCalledWith({ - payload: { [PlotsSection.CHECKPOINT_PLOTS]: true }, + payload: { [PlotsSection.CUSTOM_PLOTS]: true }, type: MessageFromWebviewType.TOGGLE_PLOTS_SECTION }) }) - it('should toggle the visibility of plots when clicking the metrics in the metrics picker', async () => { - renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture - }) - - const summaryElement = await screen.findByText('Trends') - fireEvent.click(summaryElement, { - bubbles: true, - cancelable: true - }) - - expect(screen.getByTestId('plot-summary.json:loss')).toBeInTheDocument() - - const pickerButton = getCheckpointMenuItem(0) - fireEvent.mouseEnter(pickerButton) - fireEvent.click(pickerButton) - - const lossItem = await screen.findByText('summary.json:loss', { - ignore: 'text' - }) - - fireEvent.click(lossItem, { - bubbles: true, - cancelable: true - }) - - expect( - screen.queryByTestId('plot-summary.json:loss') - ).not.toBeInTheDocument() - - fireEvent.mouseEnter(pickerButton) - fireEvent.click(pickerButton) - - fireEvent.click(lossItem, { - bubbles: true, - cancelable: true - }) - - expect(screen.getByTestId('plot-summary.json:loss')).toBeInTheDocument() - }) - - it('should send a message to the extension with the selected metrics when toggling the visibility of a plot', async () => { - renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture - }) - - const pickerButton = getCheckpointMenuItem(0) - fireEvent.mouseEnter(pickerButton) - fireEvent.click(pickerButton) - - const lossItem = await screen.findByText('summary.json:loss') - - fireEvent.click(lossItem, { - bubbles: true, - cancelable: true - }) - - expect(mockPostMessage).toHaveBeenCalledWith({ - payload: [ - 'summary.json:accuracy', - 'summary.json:val_accuracy', - 'summary.json:val_loss' - ], - type: MessageFromWebviewType.TOGGLE_METRIC - }) - - fireEvent.click(lossItem, { - bubbles: true, - cancelable: true - }) - - expect(mockPostMessage).toHaveBeenCalledWith({ - payload: [ - 'summary.json:accuracy', - 'summary.json:loss', - 'summary.json:val_accuracy', - 'summary.json:val_loss' - ], - type: MessageFromWebviewType.TOGGLE_METRIC - }) - }) - it('should display a slider to pick the number of items per row if there are items and the action is available', () => { const store = renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture, + custom: customPlotsFixture }) setWrapperSize(store) - expect(screen.getByTestId('size-sliders')).toBeInTheDocument() + expect(screen.getAllByTestId('size-sliders')[1]).toBeInTheDocument() }) it('should not display a slider to pick the number of items per row if there are no items', () => { @@ -715,26 +592,14 @@ describe('App', () => { it('should not display a slider to pick the number of items per row if the only width available for one item per row or less', () => { const store = renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture, + custom: customPlotsFixture }) setWrapperSize(store, 400) expect(screen.queryByTestId('size-sliders')).not.toBeInTheDocument() }) - it('should display both size sliders for checkpoint plots', () => { - const store = renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture - }) - setWrapperSize(store) - - const plotResizers = within( - screen.getByTestId('size-sliders') - ).getAllByRole('slider') - - expect(plotResizers.length).toBe(2) - }) - it('should display both size sliders for template plots', () => { const store = renderAppWithOptionalData({ template: templatePlotsFixture @@ -777,20 +642,21 @@ describe('App', () => { it('should send a message to the extension with the selected size when changing the width of plots', () => { const store = renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture, + custom: customPlotsFixture }) setWrapperSize(store) - const plotResizer = within(screen.getByTestId('size-sliders')).getAllByRole( - 'slider' - )[0] + const plotResizer = within( + screen.getAllByTestId('size-sliders')[1] + ).getAllByRole('slider')[0] fireEvent.change(plotResizer, { target: { value: -3 } }) expect(mockPostMessage).toHaveBeenCalledWith({ payload: { height: 1, nbItemsPerRow: 3, - section: PlotsSection.CHECKPOINT_PLOTS + section: PlotsSection.CUSTOM_PLOTS }, type: MessageFromWebviewType.RESIZE_PLOTS }) @@ -798,37 +664,39 @@ describe('App', () => { it('should send a message to the extension with the selected size when changing the height of plots', () => { const store = renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture, + custom: customPlotsFixture }) setWrapperSize(store) - const plotResizer = within(screen.getByTestId('size-sliders')).getAllByRole( - 'slider' - )[1] + const plotResizer = within( + screen.getAllByTestId('size-sliders')[1] + ).getAllByRole('slider')[1] fireEvent.change(plotResizer, { target: { value: 3 } }) expect(mockPostMessage).toHaveBeenCalledWith({ payload: { height: 3, nbItemsPerRow: 2, - section: PlotsSection.CHECKPOINT_PLOTS + section: PlotsSection.CUSTOM_PLOTS }, type: MessageFromWebviewType.RESIZE_PLOTS }) }) - it('should display the checkpoint plots in the order stored', () => { + it('should display the custom plots in the order stored', () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture, + custom: customPlotsFixture }) let plots = screen.getAllByTestId(/summary\.json/) expect(plots.map(plot => plot.id)).toStrictEqual([ - 'summary.json:loss', - 'summary.json:accuracy', - 'summary.json:val_loss', - 'summary.json:val_accuracy' + 'custom-summary.json:loss-params.yaml:dropout', + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:loss-epoch', + 'custom-summary.json:accuracy-epoch' ]) dragAndDrop(plots[1], plots[0]) @@ -836,24 +704,25 @@ describe('App', () => { plots = screen.getAllByTestId(/summary\.json/) expect(plots.map(plot => plot.id)).toStrictEqual([ - 'summary.json:accuracy', - 'summary.json:loss', - 'summary.json:val_loss', - 'summary.json:val_accuracy' + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:loss-params.yaml:dropout', + 'custom-summary.json:loss-epoch', + 'custom-summary.json:accuracy-epoch' ]) }) - it('should send a message to the extension when the checkpoint plots are reordered', () => { + it('should send a message to the extension when the custom plots are reordered', () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture, + custom: customPlotsFixture }) const plots = screen.getAllByTestId(/summary\.json/) expect(plots.map(plot => plot.id)).toStrictEqual([ - 'summary.json:loss', - 'summary.json:accuracy', - 'summary.json:val_loss', - 'summary.json:val_accuracy' + 'custom-summary.json:loss-params.yaml:dropout', + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:loss-epoch', + 'custom-summary.json:accuracy-epoch' ]) mockPostMessage.mockClear() @@ -861,57 +730,37 @@ describe('App', () => { dragAndDrop(plots[2], plots[0]) const expectedOrder = [ - 'summary.json:val_loss', - 'summary.json:loss', - 'summary.json:accuracy', - 'summary.json:val_accuracy' + 'custom-summary.json:loss-epoch', + 'custom-summary.json:loss-params.yaml:dropout', + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:accuracy-epoch' ] expect(mockPostMessage).toHaveBeenCalledTimes(1) expect(mockPostMessage).toHaveBeenCalledWith({ payload: expectedOrder, - type: MessageFromWebviewType.REORDER_PLOTS_METRICS + type: MessageFromWebviewType.REORDER_PLOTS_CUSTOM }) expect( screen.getAllByTestId(/summary\.json/).map(plot => plot.id) ).toStrictEqual(expectedOrder) }) - it('should remove the checkpoint plot from the order if it is removed from the plots', () => { - renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture - }) - - let plots = screen.getAllByTestId(/summary\.json/) - dragAndDrop(plots[1], plots[0]) - - sendSetDataMessage({ - checkpoint: { - ...checkpointPlotsFixture, - plots: checkpointPlotsFixture.plots.slice(1) - } - }) - plots = screen.getAllByTestId(/summary\.json/) - expect(plots.map(plot => plot.id)).toStrictEqual([ - 'summary.json:accuracy', - 'summary.json:val_loss', - 'summary.json:val_accuracy' - ]) - }) - it('should add a custom plot if a user creates a custom plot', () => { renderAppWithOptionalData({ comparison: comparisonTableFixture, custom: { ...customPlotsFixture, - plots: customPlotsFixture.plots.slice(1) + plots: customPlotsFixture.plots.slice(0, 3) } }) expect( screen.getAllByTestId(/summary\.json/).map(plot => plot.id) ).toStrictEqual([ - 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs' + 'custom-summary.json:loss-params.yaml:dropout', + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:loss-epoch' ]) sendSetDataMessage({ @@ -921,8 +770,10 @@ describe('App', () => { expect( screen.getAllByTestId(/summary\.json/).map(plot => plot.id) ).toStrictEqual([ - 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs', - 'custom-metrics:summary.json:loss-params:params.yaml:dropout' + 'custom-summary.json:loss-params.yaml:dropout', + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:loss-epoch', + 'custom-summary.json:accuracy-epoch' ]) }) @@ -935,8 +786,10 @@ describe('App', () => { expect( screen.getAllByTestId(/summary\.json/).map(plot => plot.id) ).toStrictEqual([ - 'custom-metrics:summary.json:loss-params:params.yaml:dropout', - 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs' + 'custom-summary.json:loss-params.yaml:dropout', + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:loss-epoch', + 'custom-summary.json:accuracy-epoch' ]) sendSetDataMessage({ @@ -949,87 +802,28 @@ describe('App', () => { expect( screen.getAllByTestId(/summary\.json/).map(plot => plot.id) ).toStrictEqual([ - 'custom-metrics:summary.json:accuracy-params:params.yaml:epochs' + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:loss-epoch', + 'custom-summary.json:accuracy-epoch' ]) }) - it('should not change the metric order in the hover menu by reordering the plots', () => { - renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture - }) - - const [pickerButton] = within( - screen.getAllByTestId('section-container')[ - sectionPosition[PlotsSection.CHECKPOINT_PLOTS] - ] - ).queryAllByTestId('icon-menu-item') - - fireEvent.mouseEnter(pickerButton) - fireEvent.click(pickerButton) - - let options = screen.getAllByTestId('select-menu-option-label') - const optionsOrder = [ - 'summary.json:accuracy', - 'summary.json:loss', - 'summary.json:val_accuracy', - 'summary.json:val_loss' - ] - expect(options.map(({ textContent }) => textContent)).toStrictEqual( - optionsOrder - ) - - fireEvent.click(pickerButton) - - let plots = screen.getAllByTestId(/summary\.json/) - const newPlotOrder = [ - 'summary.json:val_accuracy', - 'summary.json:loss', - 'summary.json:accuracy', - 'summary.json:val_loss' - ] - expect(plots.map(plot => plot.id)).not.toStrictEqual(newPlotOrder) - - dragAndDrop(plots[3], plots[0]) - sendSetDataMessage({ - checkpoint: { - ...checkpointPlotsFixture, - plots: reorderObjectList( - newPlotOrder, - checkpointPlotsFixture.plots, - 'title' - ) - } - }) - - plots = screen.getAllByTestId(/summary\.json/) - - expect(plots.map(plot => plot.id)).toStrictEqual(newPlotOrder) - - fireEvent.mouseEnter(pickerButton) - fireEvent.click(pickerButton) - - options = screen.getAllByTestId('select-menu-option-label') - expect(options.map(({ textContent }) => textContent)).toStrictEqual( - optionsOrder - ) - }) - it('should not be possible to drag a plot from a section to another', () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture, + custom: customPlotsFixture, template: templatePlotsFixture }) - const checkpointPlots = screen.getAllByTestId(/summary\.json/) + const customPlots = screen.getAllByTestId(/summary\.json/) const templatePlots = screen.getAllByTestId(/^plot_/) - dragAndDrop(templatePlots[0], checkpointPlots[2]) + dragAndDrop(templatePlots[0], customPlots[2]) - expect(checkpointPlots.map(plot => plot.id)).toStrictEqual([ - 'summary.json:loss', - 'summary.json:accuracy', - 'summary.json:val_loss', - 'summary.json:val_accuracy' + expect(customPlots.map(plot => plot.id)).toStrictEqual([ + 'custom-summary.json:loss-params.yaml:dropout', + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:loss-epoch', + 'custom-summary.json:accuracy-epoch' ]) }) @@ -1246,32 +1040,34 @@ describe('App', () => { ]) }) - it('should show a drop target at the end of the checkpoint plots when moving a plot inside the section but not over any other plot', () => { + it('should show a drop target at the end of the custom plots when moving a plot inside the section but not over any other plot', () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + custom: customPlotsFixture, + template: templatePlotsFixture }) const plots = screen.getAllByTestId(/summary\.json/) - dragEnter(plots[0], 'checkpoint-plots', DragEnterDirection.LEFT) + dragEnter(plots[0], 'custom-plots', DragEnterDirection.LEFT) expect(screen.getByTestId('plot_drop-target')).toBeInTheDocument() }) - it('should show a drop a plot at the end of the checkpoint plots when moving a plot inside the section but not over any other plot', () => { + it('should show a drop a plot at the end of the custom plots when moving a plot inside the section but not over any other plot', () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + custom: customPlotsFixture, + template: templatePlotsFixture }) const plots = screen.getAllByTestId(/summary\.json/) - dragAndDrop(plots[0], screen.getByTestId('checkpoint-plots')) + dragAndDrop(plots[0], screen.getByTestId('custom-plots')) const expectedOrder = [ - 'summary.json:accuracy', - 'summary.json:val_loss', - 'summary.json:val_accuracy', - 'summary.json:loss' + 'custom-summary.json:accuracy-params.yaml:epochs', + 'custom-summary.json:loss-epoch', + 'custom-summary.json:accuracy-epoch', + 'custom-summary.json:loss-params.yaml:dropout' ] expect( @@ -1495,9 +1291,10 @@ describe('App', () => { }) }) - it('should open a modal with the plot zoomed in when clicking a checkpoint plot', () => { + it('should open a modal with the plot zoomed in when clicking a custom plot', () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture + comparison: comparisonTableFixture, + custom: customPlotsFixture }) expect(screen.queryByTestId('modal')).not.toBeInTheDocument() @@ -1559,14 +1356,14 @@ describe('App', () => { it('should show a tooltip with the meaning of each plot section', () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture, comparison: comparisonTableFixture, custom: customPlotsFixture, template: complexTemplatePlotsFixture }) - const [templateInfo, comparisonInfo, checkpointInfo, customInfo] = - screen.getAllByTestId('info-tooltip-toggle') + const [templateInfo, comparisonInfo, customInfo] = screen.getAllByTestId( + 'info-tooltip-toggle' + ) fireEvent.mouseEnter(templateInfo, { bubbles: true }) expect(screen.getByTestId('tooltip-template-plots')).toBeInTheDocument() @@ -1574,16 +1371,12 @@ describe('App', () => { fireEvent.mouseEnter(comparisonInfo, { bubbles: true }) expect(screen.getByTestId('tooltip-comparison-plots')).toBeInTheDocument() - fireEvent.mouseEnter(checkpointInfo, { bubbles: true }) - expect(screen.getByTestId('tooltip-checkpoint-plots')).toBeInTheDocument() - fireEvent.mouseEnter(customInfo, { bubbles: true }) expect(screen.getByTestId('tooltip-custom-plots')).toBeInTheDocument() }) it('should dismiss a tooltip by pressing esc', () => { renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture, comparison: comparisonTableFixture, custom: customPlotsFixture, template: complexTemplatePlotsFixture @@ -1604,21 +1397,24 @@ describe('App', () => { }) describe('Virtualization', () => { - const createCheckpointPlots = (nbOfPlots: number) => { + const createCustomPlots = (nbOfPlots: number): CustomPlotsData => { const plots = [] for (let i = 0; i < nbOfPlots; i++) { const id = `plot-${i}` plots.push({ id, - title: id, - values: [] + metric: '', + param: '', + type: CustomPlotType.CHECKPOINT, + values: [], + yTitle: id }) } return { - ...checkpointPlotsFixture, + ...customPlotsFixture, plots, selectedMetrics: plots.map(plot => plot.id) - } + } as CustomPlotsData } const resizeScreen = (width: number, store: typeof plotsStore) => { @@ -1632,17 +1428,17 @@ describe('App', () => { } describe('Large plots', () => { - it('should wrap the checkpoint plots in a big grid (virtualize them) when there are more than eight large plots', async () => { + it('should wrap the custom plots in a big grid (virtualize them) when there are more than eight large plots', async () => { await renderAppAndChangeSize( - { checkpoint: createCheckpointPlots(9) }, + { comparison: comparisonTableFixture, custom: createCustomPlots(9) }, 1, - PlotsSection.CHECKPOINT_PLOTS + PlotsSection.CUSTOM_PLOTS ) expect(screen.getByRole('grid')).toBeInTheDocument() sendSetDataMessage({ - checkpoint: createCheckpointPlots(50) + custom: createCustomPlots(50) }) await screen.findAllByTestId('plots-wrapper') @@ -1650,17 +1446,17 @@ describe('App', () => { expect(screen.getByRole('grid')).toBeInTheDocument() }) - it('should not wrap the checkpoint plots in a big grid (virtualize them) when there are eight or fewer large plots', async () => { + it('should not wrap the custom plots in a big grid (virtualize them) when there are eight or fewer large plots', async () => { await renderAppAndChangeSize( - { checkpoint: createCheckpointPlots(8) }, + { comparison: comparisonTableFixture, custom: createCustomPlots(8) }, 1, - PlotsSection.CHECKPOINT_PLOTS + PlotsSection.CUSTOM_PLOTS ) expect(screen.queryByRole('grid')).not.toBeInTheDocument() sendSetDataMessage({ - checkpoint: createCheckpointPlots(1) + custom: createCustomPlots(1) }) await screen.findAllByTestId('plots-wrapper') @@ -1705,30 +1501,28 @@ describe('App', () => { }) describe('Sizing', () => { - const checkpoint = createCheckpointPlots(25) + const custom = createCustomPlots(25) let store: typeof plotsStore beforeEach(async () => { store = await renderAppAndChangeSize( - { checkpoint }, + { comparison: comparisonTableFixture, custom }, 1, - PlotsSection.CHECKPOINT_PLOTS + PlotsSection.CUSTOM_PLOTS ) }) it('should render the plots correctly when the screen is larger than 2000px', () => { - resizeScreen(3000, store) - let plots = screen.getAllByTestId(/^plot-/) - expect(plots[4].id).toBe(checkpoint.plots[4].title) + expect(plots[4].id).toBe(custom.plots[4].yTitle) expect(plots.length).toBe(OVERSCAN_ROW_COUNT + 1) resizeScreen(5453, store) plots = screen.getAllByTestId(/^plot-/) - expect(plots[3].id).toBe(checkpoint.plots[3].title) + expect(plots[3].id).toBe(custom.plots[3].yTitle) expect(plots.length).toBe(OVERSCAN_ROW_COUNT + 1) }) @@ -1737,7 +1531,7 @@ describe('App', () => { const plots = screen.getAllByTestId(/^plot-/) - expect(plots[12].id).toBe(checkpoint.plots[12].title) + expect(plots[12].id).toBe(custom.plots[12].yTitle) expect(plots.length).toBe(OVERSCAN_ROW_COUNT + 1) }) @@ -1746,7 +1540,7 @@ describe('App', () => { const plots = screen.getAllByTestId(/^plot-/) - expect(plots[14].id).toBe(checkpoint.plots[14].title) + expect(plots[14].id).toBe(custom.plots[14].yTitle) expect(plots.length).toBe(1 + OVERSCAN_ROW_COUNT) // Only the first and the next lines defined by the overscan row count will be rendered }) @@ -1755,27 +1549,27 @@ describe('App', () => { const plots = screen.getAllByTestId(/^plot-/) - expect(plots[4].id).toBe(checkpoint.plots[4].title) + expect(plots[4].id).toBe(custom.plots[4].yTitle) }) }) }) describe('Regular plots', () => { - it('should wrap the checkpoint plots in a big grid (virtualize them) when there are more than fourteen regular plots', async () => { + it('should wrap the custom plots in a big grid (virtualize them) when there are more than fourteen regular plots', async () => { await renderAppAndChangeSize( - { checkpoint: createCheckpointPlots(15) }, + { comparison: comparisonTableFixture, custom: createCustomPlots(15) }, DEFAULT_NB_ITEMS_PER_ROW, - PlotsSection.CHECKPOINT_PLOTS + PlotsSection.CUSTOM_PLOTS ) expect(screen.getByRole('grid')).toBeInTheDocument() }) - it('should not wrap the checkpoint plots in a big grid (virtualize them) when there are fourteen regular plots', async () => { + it('should not wrap the custom plots in a big grid (virtualize them) when there are fourteen regular plots', async () => { await renderAppAndChangeSize( - { checkpoint: createCheckpointPlots(14) }, + { comparison: comparisonTableFixture, custom: createCustomPlots(14) }, DEFAULT_NB_ITEMS_PER_ROW, - PlotsSection.CHECKPOINT_PLOTS + PlotsSection.CUSTOM_PLOTS ) expect(screen.queryByRole('grid')).not.toBeInTheDocument() @@ -1802,14 +1596,14 @@ describe('App', () => { }) describe('Sizing', () => { - const checkpoint = createCheckpointPlots(25) + const custom = createCustomPlots(25) let store: typeof plotsStore beforeEach(async () => { store = await renderAppAndChangeSize( - { checkpoint }, + { comparison: comparisonTableFixture, custom }, DEFAULT_NB_ITEMS_PER_ROW, - PlotsSection.CHECKPOINT_PLOTS + PlotsSection.CUSTOM_PLOTS ) }) @@ -1818,15 +1612,15 @@ describe('App', () => { let plots = screen.getAllByTestId(/^plot-/) - expect(plots[20].id).toBe(checkpoint.plots[20].title) - expect(plots.length).toBe(checkpoint.plots.length) + expect(plots[20].id).toBe(custom.plots[20].yTitle) + expect(plots.length).toBe(custom.plots.length) resizeScreen(6453, store) plots = screen.getAllByTestId(/^plot-/) - expect(plots[19].id).toBe(checkpoint.plots[19].title) - expect(plots.length).toBe(checkpoint.plots.length) + expect(plots[19].id).toBe(custom.plots[19].yTitle) + expect(plots.length).toBe(custom.plots.length) }) it('should render the plots correctly when the screen is larger than 1600px (but less than 2000px)', () => { @@ -1834,8 +1628,8 @@ describe('App', () => { const plots = screen.getAllByTestId(/^plot-/) - expect(plots[7].id).toBe(checkpoint.plots[7].title) - expect(plots.length).toBe(checkpoint.plots.length) + expect(plots[7].id).toBe(custom.plots[7].yTitle) + expect(plots.length).toBe(custom.plots.length) }) it('should render the plots correctly when the screen is larger than 800px (but less than 1600px)', () => { @@ -1843,8 +1637,8 @@ describe('App', () => { const plots = screen.getAllByTestId(/^plot-/) - expect(plots[7].id).toBe(checkpoint.plots[7].title) - expect(plots.length).toBe(checkpoint.plots.length) + expect(plots[7].id).toBe(custom.plots[7].yTitle) + expect(plots.length).toBe(custom.plots.length) }) it('should render the plots correctly when the screen is smaller than 800px', () => { @@ -1852,27 +1646,27 @@ describe('App', () => { const plots = screen.getAllByTestId(/^plot-/) - expect(plots[4].id).toBe(checkpoint.plots[4].title) + expect(plots[4].id).toBe(custom.plots[4].yTitle) }) }) }) describe('Smaller plots', () => { - it('should wrap the checkpoint plots in a big grid (virtualize them) when there are more than twenty small plots', async () => { + it('should wrap the custom plots in a big grid (virtualize them) when there are more than twenty small plots', async () => { await renderAppAndChangeSize( - { checkpoint: createCheckpointPlots(21) }, + { comparison: comparisonTableFixture, custom: createCustomPlots(21) }, 4, - PlotsSection.CHECKPOINT_PLOTS + PlotsSection.CUSTOM_PLOTS ) expect(screen.getByRole('grid')).toBeInTheDocument() }) - it('should not wrap the checkpoint plots in a big grid (virtualize them) when there are twenty or fewer small plots', async () => { + it('should not wrap the custom plots in a big grid (virtualize them) when there are twenty or fewer small plots', async () => { await renderAppAndChangeSize( - { checkpoint: createCheckpointPlots(20) }, + { comparison: comparisonTableFixture, custom: createCustomPlots(20) }, 4, - PlotsSection.CHECKPOINT_PLOTS + PlotsSection.CUSTOM_PLOTS ) expect(screen.queryByRole('grid')).not.toBeInTheDocument() @@ -1899,14 +1693,14 @@ describe('App', () => { }) describe('Sizing', () => { - const checkpoint = createCheckpointPlots(25) + const custom = createCustomPlots(25) let store: typeof plotsStore beforeEach(async () => { store = await renderAppAndChangeSize( - { checkpoint }, + { comparison: comparisonTableFixture, custom }, 4, - PlotsSection.CHECKPOINT_PLOTS + PlotsSection.CUSTOM_PLOTS ) }) @@ -1915,15 +1709,15 @@ describe('App', () => { let plots = screen.getAllByTestId(/^plot-/) - expect(plots[7].id).toBe(checkpoint.plots[7].title) - expect(plots.length).toBe(checkpoint.plots.length) + expect(plots[7].id).toBe(custom.plots[7].yTitle) + expect(plots.length).toBe(custom.plots.length) resizeScreen(5473, store) plots = screen.getAllByTestId(/^plot-/) - expect(plots[9].id).toBe(checkpoint.plots[9].title) - expect(plots.length).toBe(checkpoint.plots.length) + expect(plots[9].id).toBe(custom.plots[9].yTitle) + expect(plots.length).toBe(custom.plots.length) }) it('should render the plots correctly when the screen is larger than 1600px (but less than 2000px)', () => { @@ -1931,8 +1725,8 @@ describe('App', () => { const plots = screen.getAllByTestId(/^plot-/) - expect(plots[24].id).toBe(checkpoint.plots[24].title) - expect(plots.length).toBe(checkpoint.plots.length) + expect(plots[24].id).toBe(custom.plots[24].yTitle) + expect(plots.length).toBe(custom.plots.length) }) it('should render the plots correctly when the screen is larger than 800px (but less than 1600px)', () => { @@ -1940,8 +1734,8 @@ describe('App', () => { const plots = screen.getAllByTestId(/^plot-/) - expect(plots[9].id).toBe(checkpoint.plots[9].title) - expect(plots.length).toBe(checkpoint.plots.length) + expect(plots[9].id).toBe(custom.plots[9].yTitle) + expect(plots.length).toBe(custom.plots.length) }) it('should render the plots correctly when the screen is smaller than 800px but larger than 600px', () => { @@ -1949,8 +1743,8 @@ describe('App', () => { const plots = screen.getAllByTestId(/^plot-/) - expect(plots[9].id).toBe(checkpoint.plots[9].title) - expect(plots.length).toBe(checkpoint.plots.length) + expect(plots[9].id).toBe(custom.plots[9].yTitle) + expect(plots.length).toBe(custom.plots.length) }) it('should render the plots correctly when the screen is smaller than 600px', () => { @@ -1958,33 +1752,12 @@ describe('App', () => { const plots = screen.getAllByTestId(/^plot-/) - expect(plots[4].id).toBe(checkpoint.plots[4].title) + expect(plots[4].id).toBe(custom.plots[4].yTitle) }) }) }) }) - describe('Context Menu Suppression', () => { - it('Suppresses the context menu with no plots data', () => { - renderAppWithOptionalData() - const target = screen.getByText('Loading Plots...') - const contextMenuEvent = createEvent.contextMenu(target) - fireEvent(target, contextMenuEvent) - expect(contextMenuEvent.defaultPrevented).toBe(true) - }) - - it('Suppresses the context menu with plots data', () => { - renderAppWithOptionalData({ - checkpoint: checkpointPlotsFixture, - sectionCollapsed: DEFAULT_SECTION_COLLAPSED - }) - const target = screen.getByText('Trends') - const contextMenuEvent = createEvent.contextMenu(target) - fireEvent(target, contextMenuEvent) - expect(contextMenuEvent.defaultPrevented).toBe(true) - }) - }) - // eslint-disable-next-line sonarjs/cognitive-complexity describe('Ribbon', () => { const getDisplayedRevisionOrder = () => { diff --git a/webview/src/plots/components/App.tsx b/webview/src/plots/components/App.tsx index f10543ecea..155e0a0fb2 100644 --- a/webview/src/plots/components/App.tsx +++ b/webview/src/plots/components/App.tsx @@ -1,7 +1,6 @@ import React, { useCallback } from 'react' import { useDispatch } from 'react-redux' import { - CheckpointPlotsData, CustomPlotsData, PlotsComparisonData, PlotsData, @@ -13,10 +12,6 @@ import { } from 'dvc/src/plots/webview/contract' import { MessageToWebview } from 'dvc/src/webview/contract' import { Plots } from './Plots' -import { - setCollapsed as setCheckpointPlotsCollapsed, - update as updateCheckpointPlots -} from './checkpointPlots/checkpointPlotsSlice' import { setCollapsed as setCustomPlotsCollapsed, update as updateCustomPlots @@ -43,9 +38,6 @@ const dispatchCollapsedSections = ( dispatch: PlotsDispatch ) => { if (sections) { - dispatch( - setCheckpointPlotsCollapsed(sections[PlotsSection.CHECKPOINT_PLOTS]) - ) dispatch(setCustomPlotsCollapsed(sections[PlotsSection.CUSTOM_PLOTS])) dispatch( setComparisonTableCollapsed(sections[PlotsSection.COMPARISON_TABLE]) @@ -63,9 +55,6 @@ export const feedStore = ( dispatch(initialize()) for (const key of Object.keys(data.data)) { switch (key) { - case PlotsDataKeys.CHECKPOINT: - dispatch(updateCheckpointPlots(data.data[key] as CheckpointPlotsData)) - continue case PlotsDataKeys.CUSTOM: dispatch(updateCustomPlots(data.data[key] as CustomPlotsData)) continue diff --git a/webview/src/plots/components/Plots.tsx b/webview/src/plots/components/Plots.tsx index 8d893258dd..994d3c9a87 100644 --- a/webview/src/plots/components/Plots.tsx +++ b/webview/src/plots/components/Plots.tsx @@ -2,7 +2,6 @@ import React, { createRef, useLayoutEffect } from 'react' import { useSelector, useDispatch } from 'react-redux' import { AddPlots, Welcome } from './GetStarted' import { ZoomedInPlot } from './ZoomedInPlot' -import { CheckpointPlotsWrapper } from './checkpointPlots/CheckpointPlotsWrapper' import { CustomPlotsWrapper } from './customPlots/CustomPlotsWrapper' import { TemplatePlotsWrapper } from './templatePlots/TemplatePlotsWrapper' import { ComparisonTableWrapper } from './comparisonTable/ComparisonTableWrapper' @@ -19,9 +18,6 @@ const PlotsContent = () => { const { hasData, hasPlots, hasUnselectedPlots, zoomedInPlot } = useSelector( (state: PlotsState) => state.webview ) - const hasCheckpointData = useSelector( - (state: PlotsState) => state.checkpoint.hasData - ) const hasComparisonData = useSelector( (state: PlotsState) => state.comparison.hasData ) @@ -51,7 +47,7 @@ const PlotsContent = () => { return Loading Plots... } - if (!hasCheckpointData && !hasComparisonData && !hasTemplateData) { + if (!hasComparisonData && !hasTemplateData) { return ( } @@ -66,7 +62,6 @@ const PlotsContent = () => { - {zoomedInPlot?.plot && ( diff --git a/webview/src/plots/components/checkpointPlots/CheckpointPlot.tsx b/webview/src/plots/components/checkpointPlots/CheckpointPlot.tsx deleted file mode 100644 index f837dc1815..0000000000 --- a/webview/src/plots/components/checkpointPlots/CheckpointPlot.tsx +++ /dev/null @@ -1,60 +0,0 @@ -import { ColorScale, PlotsSection } from 'dvc/src/plots/webview/contract' -import React, { useMemo, useEffect, useState } from 'react' -import { useSelector } from 'react-redux' -import { createSpec } from './util' -import { changeDisabledDragIds } from './checkpointPlotsSlice' -import { ZoomablePlot } from '../ZoomablePlot' -import styles from '../styles.module.scss' -import { withScale } from '../../../util/styles' -import { plotDataStore } from '../plotDataStore' -import { PlotsState } from '../../store' - -interface CheckpointPlotProps { - id: string - colors: ColorScale -} - -export const CheckpointPlot: React.FC = ({ - id, - colors -}) => { - const plotSnapshot = useSelector( - (state: PlotsState) => state.checkpoint.plotsSnapshots[id] - ) - const [plot, setPlot] = useState( - plotDataStore[PlotsSection.CHECKPOINT_PLOTS][id] - ) - const nbItemsPerRow = useSelector( - (state: PlotsState) => state.checkpoint.nbItemsPerRow - ) - - const spec = useMemo(() => { - const title = plot?.title - if (!title) { - return {} - } - return createSpec(title, colors) - }, [plot?.title, colors]) - - useEffect(() => { - setPlot(plotDataStore[PlotsSection.CHECKPOINT_PLOTS][id]) - }, [plotSnapshot, id]) - - if (!plot) { - return null - } - - const key = `plot-${id}` - - return ( -
- -
- ) -} diff --git a/webview/src/plots/components/checkpointPlots/CheckpointPlots.tsx b/webview/src/plots/components/checkpointPlots/CheckpointPlots.tsx deleted file mode 100644 index 5b6048b64d..0000000000 --- a/webview/src/plots/components/checkpointPlots/CheckpointPlots.tsx +++ /dev/null @@ -1,117 +0,0 @@ -import React, { DragEvent, useEffect, useState } from 'react' -import { useSelector } from 'react-redux' -import cx from 'classnames' -import { ColorScale } from 'dvc/src/plots/webview/contract' -import { MessageFromWebviewType } from 'dvc/src/webview/contract' -import { performSimpleOrderedUpdate } from 'dvc/src/util/array' -import { CheckpointPlot } from './CheckpointPlot' -import styles from '../styles.module.scss' -import { EmptyState } from '../../../shared/components/emptyState/EmptyState' -import { - DragDropContainer, - WrapperProps -} from '../../../shared/components/dragDrop/DragDropContainer' -import { sendMessage } from '../../../shared/vscode' -import { DropTarget } from '../DropTarget' -import { VirtualizedGrid } from '../../../shared/components/virtualizedGrid/VirtualizedGrid' -import { shouldUseVirtualizedGrid } from '../util' -import { PlotsState } from '../../store' -import { changeOrderWithDraggedInfo } from '../../../util/array' -import { LoadingSection, sectionIsLoading } from '../LoadingSection' - -interface CheckpointPlotsProps { - plotsIds: string[] - colors: ColorScale -} - -export const CheckpointPlots: React.FC = ({ - plotsIds, - colors -}) => { - const [order, setOrder] = useState(plotsIds) - const { nbItemsPerRow, hasData, disabledDragPlotIds } = useSelector( - (state: PlotsState) => state.checkpoint - ) - const [onSection, setOnSection] = useState(false) - const draggedRef = useSelector( - (state: PlotsState) => state.dragAndDrop.draggedRef - ) - - const selectedRevisions = useSelector( - (state: PlotsState) => state.webview.selectedRevisions - ) - - useEffect(() => { - setOrder(pastOrder => performSimpleOrderedUpdate(pastOrder, plotsIds)) - }, [plotsIds]) - - const setMetricOrder = (order: string[]): void => { - setOrder(order) - sendMessage({ - payload: order, - type: MessageFromWebviewType.REORDER_PLOTS_METRICS - }) - } - - if (sectionIsLoading(selectedRevisions)) { - return - } - - if (!hasData) { - return No Plots to Display - } - - const items = order.map(plot => ( -
- -
- )) - - const useVirtualizedGrid = shouldUseVirtualizedGrid( - items.length, - nbItemsPerRow - ) - - const handleDropAtTheEnd = () => { - setMetricOrder(changeOrderWithDraggedInfo(order, draggedRef)) - } - - const handleDragOver = (e: DragEvent) => { - e.preventDefault() - setOnSection(true) - } - - return items.length > 0 ? ( -
setOnSection(true)} - onDragLeave={() => setOnSection(false)} - onDragOver={handleDragOver} - onDrop={handleDropAtTheEnd} - > - } - wrapperComponent={ - useVirtualizedGrid - ? { - component: VirtualizedGrid as React.FC, - props: { nbItemsPerRow } - } - : undefined - } - parentDraggedOver={onSection} - /> -
- ) : ( - No Metrics Selected - ) -} diff --git a/webview/src/plots/components/checkpointPlots/CheckpointPlotsWrapper.tsx b/webview/src/plots/components/checkpointPlots/CheckpointPlotsWrapper.tsx deleted file mode 100644 index 140c4900ba..0000000000 --- a/webview/src/plots/components/checkpointPlots/CheckpointPlotsWrapper.tsx +++ /dev/null @@ -1,60 +0,0 @@ -import { PlotsSection } from 'dvc/src/plots/webview/contract' -import { MessageFromWebviewType } from 'dvc/src/webview/contract' -import React, { useEffect, useState } from 'react' -import { useSelector } from 'react-redux' -import { CheckpointPlots } from './CheckpointPlots' -import { changeSize } from './checkpointPlotsSlice' -import { PlotsContainer } from '../PlotsContainer' -import { sendMessage } from '../../../shared/vscode' -import { PlotsState } from '../../store' - -export const CheckpointPlotsWrapper: React.FC = () => { - const { - plotsIds, - nbItemsPerRow, - selectedMetrics, - isCollapsed, - colors, - height - } = useSelector((state: PlotsState) => state.checkpoint) - const [metrics, setMetrics] = useState([]) - const [selectedPlots, setSelectedPlots] = useState([]) - - useEffect(() => { - setMetrics([...plotsIds].sort()) - setSelectedPlots(selectedMetrics || []) - }, [plotsIds, selectedMetrics, setSelectedPlots, setMetrics]) - - const setSelectedMetrics = (metrics: string[]) => { - setSelectedPlots(metrics) - sendMessage({ - payload: metrics, - type: MessageFromWebviewType.TOGGLE_METRIC - }) - } - - const hasItems = plotsIds.length > 0 - - const menu = hasItems - ? { - plots: metrics, - selectedPlots, - setSelectedPlots: setSelectedMetrics - } - : undefined - - return ( - - - - ) -} diff --git a/webview/src/plots/components/checkpointPlots/checkpointPlotsSlice.ts b/webview/src/plots/components/checkpointPlots/checkpointPlotsSlice.ts deleted file mode 100644 index d6f0c9d809..0000000000 --- a/webview/src/plots/components/checkpointPlots/checkpointPlotsSlice.ts +++ /dev/null @@ -1,79 +0,0 @@ -import { createSlice, PayloadAction } from '@reduxjs/toolkit' -import { - CheckpointPlotsData, - DEFAULT_HEIGHT, - DEFAULT_SECTION_COLLAPSED, - DEFAULT_SECTION_NB_ITEMS_PER_ROW_OR_WIDTH, - PlotHeight, - PlotsSection -} from 'dvc/src/plots/webview/contract' -import { addPlotsWithSnapshots, removePlots } from '../plotDataStore' - -export interface CheckpointPlotsState - extends Omit { - isCollapsed: boolean - hasData: boolean - plotsIds: string[] - plotsSnapshots: { [key: string]: string } - disabledDragPlotIds: string[] -} - -export const checkpointPlotsInitialState: CheckpointPlotsState = { - colors: { domain: [], range: [] }, - disabledDragPlotIds: [], - hasData: false, - height: DEFAULT_HEIGHT[PlotsSection.CHECKPOINT_PLOTS], - isCollapsed: DEFAULT_SECTION_COLLAPSED[PlotsSection.CHECKPOINT_PLOTS], - nbItemsPerRow: - DEFAULT_SECTION_NB_ITEMS_PER_ROW_OR_WIDTH[PlotsSection.CHECKPOINT_PLOTS], - plotsIds: [], - plotsSnapshots: {}, - selectedMetrics: [] -} - -export const checkpointPlotsSlice = createSlice({ - initialState: checkpointPlotsInitialState, - name: 'checkpoint', - reducers: { - changeDisabledDragIds: (state, action: PayloadAction) => { - state.disabledDragPlotIds = action.payload - }, - changeSize: ( - state, - action: PayloadAction<{ - nbItemsPerRowOrWidth: number - height: PlotHeight - }> - ) => { - state.nbItemsPerRow = action.payload.nbItemsPerRowOrWidth - state.height = action.payload.height - }, - setCollapsed: (state, action: PayloadAction) => { - state.isCollapsed = action.payload - }, - update: (state, action: PayloadAction) => { - if (!action.payload) { - return checkpointPlotsInitialState - } - const { plots, ...statePayload } = action.payload - const plotsIds = plots?.map(plot => plot.id) || [] - const snapShots = addPlotsWithSnapshots( - plots, - PlotsSection.CHECKPOINT_PLOTS - ) - removePlots(plotsIds, PlotsSection.CHECKPOINT_PLOTS) - return { - ...state, - ...statePayload, - hasData: !!action.payload, - plotsIds: plots?.map(plot => plot.id) || [], - plotsSnapshots: snapShots - } - } - } -}) - -export const { update, setCollapsed, changeSize, changeDisabledDragIds } = - checkpointPlotsSlice.actions - -export default checkpointPlotsSlice.reducer diff --git a/webview/src/plots/components/checkpointPlots/util.ts b/webview/src/plots/components/checkpointPlots/util.ts deleted file mode 100644 index 94e099555b..0000000000 --- a/webview/src/plots/components/checkpointPlots/util.ts +++ /dev/null @@ -1,93 +0,0 @@ -import { VisualizationSpec } from 'react-vega' -import { ColorScale } from 'dvc/src/plots/webview/contract' - -export const createSpec = ( - title: string, - scale?: ColorScale -): VisualizationSpec => - ({ - $schema: 'https://vega.github.io/schema/vega-lite/v5.json', - data: { name: 'values' }, - encoding: { - color: { - field: 'group', - legend: { disable: true }, - scale, - title: 'rev', - type: 'nominal' - }, - x: { - axis: { format: '0d', tickMinStep: 1 }, - field: 'iteration', - title: 'iteration', - type: 'quantitative' - }, - y: { - field: 'y', - scale: { zero: false }, - title, - type: 'quantitative' - } - }, - height: 'container', - layer: [ - { - layer: [ - { mark: { type: 'line' } }, - { - mark: { type: 'point' }, - transform: [ - { - filter: { empty: false, param: 'hover' } - } - ] - } - ] - }, - { - encoding: { - opacity: { value: 0 }, - tooltip: [ - { field: 'group', title: 'name' }, - { - field: 'y', - title: title.slice(Math.max(0, title.indexOf(':') + 1)), - type: 'quantitative' - } - ] - }, - mark: { type: 'rule' }, - params: [ - { - name: 'hover', - select: { - clear: 'mouseout', - fields: ['iteration', 'y'], - nearest: true, - on: 'mouseover', - type: 'point' - } - } - ] - }, - { - encoding: { - color: { field: 'group', scale }, - x: { aggregate: 'max', field: 'iteration', type: 'quantitative' }, - y: { - aggregate: { argmax: 'iteration' }, - field: 'y', - type: 'quantitative' - } - }, - mark: { stroke: null, type: 'circle' } - } - ], - transform: [ - { - as: 'y', - calculate: "format(datum['y'],'.5f')" - } - ], - width: 'container' - } as VisualizationSpec) diff --git a/webview/src/plots/components/customPlots/CustomPlot.tsx b/webview/src/plots/components/customPlots/CustomPlot.tsx index edd8175425..667c6b12f2 100644 --- a/webview/src/plots/components/customPlots/CustomPlot.tsx +++ b/webview/src/plots/components/customPlots/CustomPlot.tsx @@ -1,7 +1,12 @@ -import { PlotsSection } from 'dvc/src/plots/webview/contract' +import { + ColorScale, + CustomPlotData, + PlotsSection +} from 'dvc/src/plots/webview/contract' +import { isCheckpointPlot } from 'dvc/src/plots/model/custom' import React, { useMemo, useEffect, useState } from 'react' import { useSelector } from 'react-redux' -import { createSpec } from './util' +import { createMetricVsParamSpec, createCheckpointSpec } from './util' import { changeDisabledDragIds } from './customPlotsSlice' import { ZoomablePlot } from '../ZoomablePlot' import styles from '../styles.module.scss' @@ -13,26 +18,38 @@ interface CustomPlotProps { id: string } +const createCustomPlotSpec = ( + plot: CustomPlotData | undefined, + colors: ColorScale | undefined +) => { + if (!plot) { + return {} + } + + if (isCheckpointPlot(plot)) { + return createCheckpointSpec(plot.yTitle, plot.metric, plot.param, colors) + } + return createMetricVsParamSpec(plot.yTitle, plot.param) +} + export const CustomPlot: React.FC = ({ id }) => { const plotSnapshot = useSelector( (state: PlotsState) => state.custom.plotsSnapshots[id] ) + const [plot, setPlot] = useState(plotDataStore[PlotsSection.CUSTOM_PLOTS][id]) - const nbItemsPerRow = useSelector( - (state: PlotsState) => state.custom.nbItemsPerRow + const { nbItemsPerRow, colors } = useSelector( + (state: PlotsState) => state.custom ) - const spec = useMemo(() => { - if (plot) { - return createSpec(plot.metric, plot.param) - } - }, [plot]) + return createCustomPlotSpec(plot, colors) + }, [plot, colors]) useEffect(() => { setPlot(plotDataStore[PlotsSection.CUSTOM_PLOTS][id]) }, [plotSnapshot, id]) - if (!plot || !spec) { + if (!plot) { return null } diff --git a/webview/src/plots/components/customPlots/CustomPlots.tsx b/webview/src/plots/components/customPlots/CustomPlots.tsx index 6bb54ef9a4..1c07e4a3f0 100644 --- a/webview/src/plots/components/customPlots/CustomPlots.tsx +++ b/webview/src/plots/components/customPlots/CustomPlots.tsx @@ -2,7 +2,6 @@ import React, { DragEvent, useEffect, useState } from 'react' import { useSelector } from 'react-redux' import cx from 'classnames' import { MessageFromWebviewType } from 'dvc/src/webview/contract' -import { performSimpleOrderedUpdate } from 'dvc/src/util/array' import { CustomPlot } from './CustomPlot' import styles from '../styles.module.scss' import { EmptyState } from '../../../shared/components/emptyState/EmptyState' @@ -32,7 +31,7 @@ export const CustomPlots: React.FC = ({ plotsIds }) => { ) useEffect(() => { - setOrder(pastOrder => performSimpleOrderedUpdate(pastOrder, plotsIds)) + setOrder(plotsIds) }, [plotsIds]) const setPlotsIdsOrder = (order: string[]): void => { diff --git a/webview/src/plots/components/customPlots/customPlotsSlice.ts b/webview/src/plots/components/customPlots/customPlotsSlice.ts index 148600416c..a81eefda36 100644 --- a/webview/src/plots/components/customPlots/customPlotsSlice.ts +++ b/webview/src/plots/components/customPlots/customPlotsSlice.ts @@ -17,7 +17,10 @@ export interface CustomPlotsState extends Omit { disabledDragPlotIds: string[] } +const initialColorsState = { domain: [], range: [] } + export const customPlotsInitialState: CustomPlotsState = { + colors: initialColorsState, disabledDragPlotIds: [], hasData: false, height: DEFAULT_HEIGHT[PlotsSection.CUSTOM_PLOTS], @@ -52,13 +55,14 @@ export const customPlotsSlice = createSlice({ if (!action.payload) { return customPlotsInitialState } - const { plots, ...statePayload } = action.payload + const { plots, colors, ...statePayload } = action.payload const plotsIds = plots?.map(plot => plot.id) || [] const snapShots = addPlotsWithSnapshots(plots, PlotsSection.CUSTOM_PLOTS) removePlots(plotsIds, PlotsSection.CUSTOM_PLOTS) return { ...state, ...statePayload, + colors: colors || initialColorsState, hasData: !!action.payload, plotsIds: plots?.map(plot => plot.id) || [], plotsSnapshots: snapShots diff --git a/webview/src/plots/components/customPlots/util.ts b/webview/src/plots/components/customPlots/util.ts index bbb7cfbd49..3e6a3336a4 100644 --- a/webview/src/plots/components/customPlots/util.ts +++ b/webview/src/plots/components/customPlots/util.ts @@ -1,6 +1,100 @@ import { VisualizationSpec } from 'react-vega' +import { ColorScale } from 'dvc/src/plots/webview/contract' -export const createSpec = (metric: string, param: string) => +export const createCheckpointSpec = ( + title: string, + fullTitle: string, + param: string, + scale?: ColorScale +): VisualizationSpec => + ({ + $schema: 'https://vega.github.io/schema/vega-lite/v5.json', + data: { name: 'values' }, + encoding: { + color: { + field: 'group', + legend: { disable: true }, + scale, + title: 'rev', + type: 'nominal' + }, + x: { + axis: { format: '0d', tickMinStep: 1 }, + field: 'iteration', + title: param, + type: 'quantitative' + }, + y: { + field: 'y', + scale: { zero: false }, + title, + type: 'quantitative' + } + }, + height: 'container', + layer: [ + { + layer: [ + { mark: { type: 'line' } }, + { + mark: { type: 'point' }, + transform: [ + { + filter: { empty: false, param: 'hover' } + } + ] + } + ] + }, + { + encoding: { + opacity: { value: 0 }, + tooltip: [ + { field: 'group', title: 'name' }, + { + field: 'y', + title: fullTitle.slice(Math.max(0, fullTitle.indexOf(':') + 1)), + type: 'quantitative' + } + ] + }, + mark: { type: 'rule' }, + params: [ + { + name: 'hover', + select: { + clear: 'mouseout', + fields: ['iteration', 'y'], + nearest: true, + on: 'mouseover', + type: 'point' + } + } + ] + }, + { + encoding: { + color: { field: 'group', scale }, + x: { aggregate: 'max', field: 'iteration', type: 'quantitative' }, + y: { + aggregate: { argmax: 'iteration' }, + field: 'y', + type: 'quantitative' + } + }, + mark: { stroke: null, type: 'circle' } + } + ], + transform: [ + { + as: 'y', + calculate: "format(datum['y'],'.5f')" + } + ], + width: 'container' + } as VisualizationSpec) + +export const createMetricVsParamSpec = (metric: string, param: string) => ({ $schema: 'https://vega.github.io/schema/vega-lite/v5.json', data: { name: 'values' }, diff --git a/webview/src/plots/components/plotDataStore.ts b/webview/src/plots/components/plotDataStore.ts index 93355c6021..b474a517d8 100644 --- a/webview/src/plots/components/plotDataStore.ts +++ b/webview/src/plots/components/plotDataStore.ts @@ -1,23 +1,20 @@ import { - CheckpointPlotData, CustomPlotData, PlotsSection, TemplatePlotEntry } from 'dvc/src/plots/webview/contract' -export type CheckpointPlotsById = { [key: string]: CheckpointPlotData } export type CustomPlotsById = { [key: string]: CustomPlotData } export type TemplatePlotsById = { [key: string]: TemplatePlotEntry } export const plotDataStore = { - [PlotsSection.CHECKPOINT_PLOTS]: {} as CheckpointPlotsById, [PlotsSection.TEMPLATE_PLOTS]: {} as TemplatePlotsById, - [PlotsSection.COMPARISON_TABLE]: {} as CheckpointPlotsById, // This category is unused but exists only to make typings easier, + [PlotsSection.COMPARISON_TABLE]: {} as CustomPlotsById, // This category is unused but exists only to make typings easier, [PlotsSection.CUSTOM_PLOTS]: {} as CustomPlotsById } export const addPlotsWithSnapshots = ( - plots: (CheckpointPlotData | TemplatePlotEntry | CustomPlotData)[], + plots: (TemplatePlotEntry | CustomPlotData)[], section: PlotsSection ) => { const snapShots: { [key: string]: string } = {} diff --git a/webview/src/plots/hooks/useGetPlot.ts b/webview/src/plots/hooks/useGetPlot.ts index dbe0825eb5..0d64b46d09 100644 --- a/webview/src/plots/hooks/useGetPlot.ts +++ b/webview/src/plots/hooks/useGetPlot.ts @@ -1,5 +1,4 @@ import { - CheckpointPlotData, CustomPlotData, PlotsSection, TemplatePlotEntry @@ -10,26 +9,13 @@ import { PlainObject, VisualizationSpec } from 'react-vega' import { plotDataStore } from '../components/plotDataStore' import { PlotsState } from '../store' -const getStoreSection = (section: PlotsSection) => { - switch (section) { - case PlotsSection.CHECKPOINT_PLOTS: - return 'checkpoint' - case PlotsSection.TEMPLATE_PLOTS: - return 'template' - default: - return 'custom' - } -} - export const useGetPlot = ( section: PlotsSection, id: string, spec?: VisualizationSpec ) => { - const isPlotWithSpec = - section === PlotsSection.CHECKPOINT_PLOTS || - section === PlotsSection.CUSTOM_PLOTS - const storeSection = getStoreSection(section) + const isCustomPlot = section === PlotsSection.CUSTOM_PLOTS + const storeSection = isCustomPlot ? 'custom' : 'template' const snapshot = useSelector( (state: PlotsState) => state[storeSection].plotsSnapshots ) @@ -42,8 +28,8 @@ export const useGetPlot = ( return } - if (isPlotWithSpec) { - setData({ values: (plot as CheckpointPlotData | CustomPlotData).values }) + if (isCustomPlot) { + setData({ values: (plot as CustomPlotData).values }) setContent(spec) return } @@ -54,7 +40,7 @@ export const useGetPlot = ( height: 'container', width: 'container' } as VisualizationSpec) - }, [id, isPlotWithSpec, setData, setContent, section, spec]) + }, [id, isCustomPlot, setData, setContent, section, spec]) useEffect(() => { setPlotData() diff --git a/webview/src/plots/store.ts b/webview/src/plots/store.ts index 9686b1fda8..f5deecb327 100644 --- a/webview/src/plots/store.ts +++ b/webview/src/plots/store.ts @@ -1,5 +1,4 @@ import { configureStore } from '@reduxjs/toolkit' -import checkpointPlotsReducer from './components/checkpointPlots/checkpointPlotsSlice' import comparisonTableReducer from './components/comparisonTable/comparisonTableSlice' import templatePlotsReducer from './components/templatePlots/templatePlotsSlice' import customPlotsReducer from './components/customPlots/customPlotsSlice' @@ -8,7 +7,6 @@ import ribbonReducer from './components/ribbon/ribbonSlice' import dragAndDropReducer from '../shared/components/dragDrop/dragDropSlice' export const plotsReducers = { - checkpoint: checkpointPlotsReducer, comparison: comparisonTableReducer, custom: customPlotsReducer, dragAndDrop: dragAndDropReducer, diff --git a/webview/src/shared/components/sectionContainer/SectionContainer.tsx b/webview/src/shared/components/sectionContainer/SectionContainer.tsx index b1bf8f408c..e89869a57d 100644 --- a/webview/src/shared/components/sectionContainer/SectionContainer.tsx +++ b/webview/src/shared/components/sectionContainer/SectionContainer.tsx @@ -12,17 +12,6 @@ import { IconMenu } from '../iconMenu/IconMenu' import { IconMenuItemProps } from '../iconMenu/IconMenuItem' export const SectionDescription = { - // "Trends" - [PlotsSection.CHECKPOINT_PLOTS]: ( - - Automatically generated and updated linear plots that show metric value - per epoch if{' '} - - checkpoints - {' '} - are enabled. - - ), // "Custom" [PlotsSection.CUSTOM_PLOTS]: ( diff --git a/webview/src/stories/Plots.stories.tsx b/webview/src/stories/Plots.stories.tsx index 117bb962c0..b54aad1e93 100644 --- a/webview/src/stories/Plots.stories.tsx +++ b/webview/src/stories/Plots.stories.tsx @@ -12,7 +12,6 @@ import { DEFAULT_NB_ITEMS_PER_ROW } from 'dvc/src/plots/webview/contract' import { MessageToWebviewType } from 'dvc/src/webview/contract' -import checkpointPlotsFixture from 'dvc/src/test/fixtures/expShow/base/checkpointPlots' import customPlotsFixture from 'dvc/src/test/fixtures/expShow/base/customPlots' import templatePlotsFixture from 'dvc/src/test/fixtures/plotsDiff/template' import manyTemplatePlots from 'dvc/src/test/fixtures/plotsDiff/template/virtualization' @@ -32,36 +31,34 @@ import '../plots/components/styles.module.scss' import { feedStore } from '../plots/components/App' import { plotsReducers } from '../plots/store' -const smallCheckpointPlotsFixture = { - ...checkpointPlotsFixture, +const smallCustomPlotsFixture = { + ...customPlotsFixture, nbItemsPerRow: 3, - plots: checkpointPlotsFixture.plots.map(plot => ({ + plots: customPlotsFixture.plots.map(plot => ({ ...plot, - title: truncateVerticalTitle( - plot.title, + yTitle: truncateVerticalTitle( + plot.yTitle, DEFAULT_NB_ITEMS_PER_ROW, DEFAULT_PLOT_HEIGHT ) as string })) } -const manyCheckpointPlots = (length: number) => - Array.from({ length }, () => checkpointPlotsFixture.plots[0]).map( - (plot, i) => { - const id = plot.id + i.toString() - return { - ...plot, +const manyCustomPlots = (length: number) => + Array.from({ length }, () => customPlotsFixture.plots[2]).map((plot, i) => { + const id = plot.id + i.toString() + return { + ...plot, + id, + yTitle: truncateVerticalTitle( id, - title: truncateVerticalTitle( - id, - DEFAULT_NB_ITEMS_PER_ROW, - DEFAULT_PLOT_HEIGHT - ) as string - } + DEFAULT_NB_ITEMS_PER_ROW, + DEFAULT_PLOT_HEIGHT + ) as string } - ) + }) -const manyCheckpointPlotsFixture = manyCheckpointPlots(15) +const manyCustomPlotsFixture = manyCustomPlots(15) const MockedState: React.FC<{ data: PlotsData; children: React.ReactNode }> = ({ children, @@ -77,7 +74,6 @@ const MockedState: React.FC<{ data: PlotsData; children: React.ReactNode }> = ({ export default { args: { data: { - checkpoint: checkpointPlotsFixture, comparison: comparisonPlotsFixture, custom: customPlotsFixture, hasPlots: true, @@ -107,28 +103,6 @@ const Template: Story<{ export const WithData = Template.bind({}) WithData.parameters = CHROMATIC_VIEWPORTS_WITH_DELAY -export const WithEmptyCheckpoints = Template.bind({}) -WithEmptyCheckpoints.args = { - data: { - checkpoint: { ...checkpointPlotsFixture, selectedMetrics: [] }, - comparison: comparisonPlotsFixture, - sectionCollapsed: DEFAULT_SECTION_COLLAPSED, - selectedRevisions: plotsRevisionsFixture, - template: templatePlotsFixture - } -} -WithEmptyCheckpoints.parameters = DISABLE_CHROMATIC_SNAPSHOTS - -export const WithCheckpointOnly = Template.bind({}) -WithCheckpointOnly.args = { - data: { - checkpoint: checkpointPlotsFixture, - sectionCollapsed: DEFAULT_SECTION_COLLAPSED, - selectedRevisions: plotsRevisionsFixture - } -} -WithCheckpointOnly.parameters = DISABLE_CHROMATIC_SNAPSHOTS - export const WithCustomOnly = Template.bind({}) WithCustomOnly.args = { data: { @@ -199,10 +173,6 @@ WithoutData.args = { export const AllLarge = Template.bind({}) AllLarge.args = { data: { - checkpoint: { - ...checkpointPlotsFixture, - nbItemsPerRow: 1 - }, comparison: { ...comparisonPlotsFixture, width: 1 @@ -224,15 +194,11 @@ AllLarge.parameters = CHROMATIC_VIEWPORTS_WITH_DELAY export const AllSmall = Template.bind({}) AllSmall.args = { data: { - checkpoint: smallCheckpointPlotsFixture, comparison: { ...comparisonPlotsFixture, width: 3 }, - custom: { - ...customPlotsFixture, - nbItemsPerRow: 3 - }, + custom: smallCustomPlotsFixture, sectionCollapsed: DEFAULT_SECTION_COLLAPSED, selectedRevisions: plotsRevisionsFixture, template: { @@ -246,13 +212,11 @@ AllSmall.parameters = CHROMATIC_VIEWPORTS_WITH_DELAY export const VirtualizedPlots = Template.bind({}) VirtualizedPlots.args = { data: { - checkpoint: { - ...checkpointPlotsFixture, - plots: manyCheckpointPlotsFixture, - selectedMetrics: manyCheckpointPlotsFixture.map(plot => plot.id) - }, comparison: undefined, - custom: customPlotsFixture, + custom: { + ...customPlotsFixture, + plots: manyCustomPlotsFixture + }, sectionCollapsed: DEFAULT_SECTION_COLLAPSED, selectedRevisions: plotsRevisionsFixture, template: manyTemplatePlots(125) @@ -321,7 +285,6 @@ ScrolledHeaders.parameters = { export const ScrolledWithManyRevisions = Template.bind({}) ScrolledWithManyRevisions.args = { data: { - checkpoint: checkpointPlotsFixture, comparison: comparisonPlotsFixture, custom: customPlotsFixture, hasPlots: true,