From c83e76367f7ea3159acd9adfc5474b33b7269a01 Mon Sep 17 00:00:00 2001 From: mattseddon <37993418+mattseddon@users.noreply.github.com> Date: Tue, 28 Jun 2022 05:33:26 +1000 Subject: [PATCH] Patch plots for branches containing path separators (#1949) * use short sha to fetch HEAD plots data (workaround branch names containing path separators) * refactor --- extension/src/plots/data/index.ts | 16 ++++- extension/src/plots/model/collect.test.ts | 23 ++++++- extension/src/plots/model/collect.ts | 56 ++++++++++++------ extension/src/plots/model/index.ts | 37 +++++++++--- .../src/test/fixtures/plotsDiff/index.ts | 29 +++++---- ...in_plots_acc.png => 53c3851_plots_acc.png} | Bin ..._heatmap.png => 53c3851_plots_heatmap.png} | Bin ..._plots_loss.png => 53c3851_plots_loss.png} | Bin extension/src/test/fixtures/plotsDiff/util.ts | 13 ++++ extension/src/test/fixtures/plotsDiff/vega.ts | 8 +-- .../src/test/suite/plots/data/index.test.ts | 39 +++++++++--- extension/src/test/suite/plots/index.test.ts | 18 +++--- .../src/test/suite/plots/paths/tree.test.ts | 2 +- 13 files changed, 173 insertions(+), 68 deletions(-) rename extension/src/test/fixtures/plotsDiff/staticImages/{main_plots_acc.png => 53c3851_plots_acc.png} (100%) rename extension/src/test/fixtures/plotsDiff/staticImages/{main_plots_heatmap.png => 53c3851_plots_heatmap.png} (100%) rename extension/src/test/fixtures/plotsDiff/staticImages/{main_plots_loss.png => 53c3851_plots_loss.png} (100%) create mode 100644 extension/src/test/fixtures/plotsDiff/util.ts diff --git a/extension/src/plots/data/index.ts b/extension/src/plots/data/index.ts index 05f65662b4..3d2e32ada4 100644 --- a/extension/src/plots/data/index.ts +++ b/extension/src/plots/data/index.ts @@ -46,8 +46,7 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> { return } - const args = sameContents(revs, ['workspace']) ? [] : revs - + const args = this.getArgs(revs) const data = await this.internalCommands.executeCommand( AvailableCommands.PLOTS_DIFF, this.dvcRoot, @@ -58,7 +57,7 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> { this.compareFiles(files) - return this.notifyChanged({ data, revs }) + return this.notifyChanged({ data, revs: args }) } public managedUpdate() { @@ -72,4 +71,15 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> { public setModel(model: PlotsModel) { this.model = model } + + private getArgs(revs: string[]) { + if ( + this.model && + (sameContents(revs, ['workspace']) || sameContents(revs, [])) + ) { + return this.model.getDefaultRevs() + } + + return revs + } } diff --git a/extension/src/plots/model/collect.test.ts b/extension/src/plots/model/collect.test.ts index 9c2f710b9f..046010490e 100644 --- a/extension/src/plots/model/collect.test.ts +++ b/extension/src/plots/model/collect.test.ts @@ -22,6 +22,7 @@ import { uniqueValues } from '../../util/array' import { TemplatePlot } from '../webview/contract' +import { getCLIBranchId } from '../../test/fixtures/plotsDiff/util' const logsLossPath = join('logs', 'loss.tsv') @@ -226,7 +227,17 @@ describe('collectMetricOrder', () => { describe('collectData', () => { it('should return the expected output from the test fixture', () => { - const { revisionData, comparisonData } = collectData(plotsDiffFixture) + const mapping = { + '1ba7bcd': '1ba7bcd', + '42b8736': '42b8736', + '4fb124a': '4fb124a', + '53c3851': 'main', + workspace: 'workspace' + } + const { revisionData, comparisonData } = collectData( + plotsDiffFixture, + mapping + ) const revisions = ['workspace', 'main', '42b8736', '1ba7bcd', '4fb124a'] const values = @@ -237,7 +248,7 @@ describe('collectData', () => { expect(isEmpty(values)).toBeFalsy() for (const revision of revisions) { - const expectedValues = values[revision].map(value => ({ + const expectedValues = values[getCLIBranchId(revision)].map(value => ({ ...value, rev: revision })) @@ -287,7 +298,13 @@ describe('collectTemplates', () => { }) describe('collectWorkspaceRaceConditionData', () => { - const { comparisonData, revisionData } = collectData(plotsDiffFixture) + const { comparisonData, revisionData } = collectData(plotsDiffFixture, { + '1ba7bcd': '1ba7bcd', + '42b8736': '42b8736', + '4fb124a': '4fb124a', + '53c3851': 'main', + workspace: 'workspace' + }) it('should return no overwrite data if there is no selected checkpoint experiment running in the workspace', () => { const { overwriteComparisonData, overwriteRevisionData } = diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index a66f555029..0a31a507d6 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -312,31 +312,39 @@ export const collectMetricOrder = ( type RevisionPathData = { [path: string]: Record[] } export type RevisionData = { - [revision: string]: RevisionPathData + [label: string]: RevisionPathData } export type ComparisonData = { - [revision: string]: { + [label: string]: { [path: string]: ImagePlot } } +export type CLIRevisionIdToLabel = { [shortSha: string]: string } + const collectImageData = ( acc: ComparisonData, path: string, - plot: ImagePlot + plot: ImagePlot, + cliIdToLabel: CLIRevisionIdToLabel ) => { const rev = plot.revisions?.[0] - if (!rev) { return } - if (!acc[rev]) { - acc[rev] = {} + const label = cliIdToLabel[rev] + + if (!label) { + return + } + + if (!acc[label]) { + acc[label] = {} } - acc[rev][path] = plot + acc[label][path] = plot } const collectDatapoints = ( @@ -353,15 +361,17 @@ const collectDatapoints = ( const collectPlotData = ( acc: RevisionData, path: string, - plot: TemplatePlot + plot: TemplatePlot, + cliIdToLabel: CLIRevisionIdToLabel ) => { - for (const rev of plot.revisions || []) { - if (!acc[rev]) { - acc[rev] = {} + for (const id of plot.revisions || []) { + const label = cliIdToLabel[id] + if (!acc[label]) { + acc[label] = {} } - acc[rev][path] = [] + acc[label][path] = [] - collectDatapoints(acc, path, rev, plot.datapoints?.[rev]) + collectDatapoints(acc, path, label, plot.datapoints?.[id]) } } @@ -370,25 +380,33 @@ type DataAccumulator = { comparisonData: ComparisonData } -const collectPathData = (acc: DataAccumulator, path: string, plots: Plot[]) => { +const collectPathData = ( + acc: DataAccumulator, + path: string, + plots: Plot[], + cliIdToLabel: CLIRevisionIdToLabel +) => { for (const plot of plots) { if (isImagePlot(plot)) { - collectImageData(acc.comparisonData, path, plot) + collectImageData(acc.comparisonData, path, plot, cliIdToLabel) continue } - collectPlotData(acc.revisionData, path, plot) + collectPlotData(acc.revisionData, path, plot, cliIdToLabel) } } -export const collectData = (data: PlotsOutput): DataAccumulator => { +export const collectData = ( + data: PlotsOutput, + cliIdToLabel: CLIRevisionIdToLabel +): DataAccumulator => { const acc = { comparisonData: {}, revisionData: {} } as DataAccumulator for (const [path, plots] of Object.entries(data)) { - collectPathData(acc, path, plots) + collectPathData(acc, path, plots, cliIdToLabel) } return acc @@ -543,7 +561,7 @@ export const collectBranchRevisionDetails = ( const branchRevisions: Record = {} for (const { id, sha } of branchShas) { if (sha) { - branchRevisions[id] = sha + branchRevisions[id] = shortenForLabel(sha) } } return branchRevisions diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index a340ab16c6..65460fc106 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -96,10 +96,15 @@ export class PlotsModel extends ModelWithPersistence { } public async transformAndSetPlots(data: PlotsOutput, revs: string[]) { - this.fetchedRevs = new Set([...this.fetchedRevs, ...revs]) + const cliIdToLabel = this.getCLIIdToLabel() + + this.fetchedRevs = new Set([ + ...this.fetchedRevs, + ...revs.map(rev => cliIdToLabel[rev]) + ]) const [{ comparisonData, revisionData }, templates] = await Promise.all([ - collectData(data), + collectData(data, cliIdToLabel), collectTemplates(data) ]) @@ -168,9 +173,9 @@ export class PlotsModel extends ModelWithPersistence { ...Object.keys(this.revisionData) ]) - return this.getSelectedRevisions().filter( - revision => !cachedRevisions.has(revision) - ) + return this.getSelectedRevisions() + .filter(label => !cachedRevisions.has(label)) + .map(label => this.getCLIId(label)) } public getMutableRevisions() { @@ -186,16 +191,20 @@ export class PlotsModel extends ModelWithPersistence { this.comparisonOrder, this.experiments .getSelectedRevisions() - .map(({ label: revision, displayColor, logicalGroupName, id }) => ({ + .map(({ label, displayColor, logicalGroupName, id }) => ({ displayColor, group: logicalGroupName, id, - revision + revision: label })), 'revision' ) } + public getDefaultRevs() { + return ['workspace', ...Object.values(this.branchRevisions)] + } + public getTemplatePlots(order: TemplateOrder | undefined) { if (!definedAndNonEmpty(order)) { return @@ -330,6 +339,20 @@ export class PlotsModel extends ModelWithPersistence { this.fetchedRevs.delete(id) } + private getCLIIdToLabel() { + const mapping: { [shortSha: string]: string } = {} + + for (const rev of this.getSelectedRevisions()) { + mapping[this.getCLIId(rev)] = rev + } + + return mapping + } + + private getCLIId(label: string) { + return this.branchRevisions[label] || label + } + private getSelectedRevisions() { return this.experiments.getSelectedRevisions().map(({ label }) => label) } diff --git a/extension/src/test/fixtures/plotsDiff/index.ts b/extension/src/test/fixtures/plotsDiff/index.ts index 91c1833d58..28f93f6297 100644 --- a/extension/src/test/fixtures/plotsDiff/index.ts +++ b/extension/src/test/fixtures/plotsDiff/index.ts @@ -14,12 +14,13 @@ import { } from '../../../plots/webview/contract' import { join } from '../../util/path' import { copyOriginalColors } from '../../../experiments/model/status/colors' +import { getCLIBranchId, replaceBranchCLIId } from './util' const basicVega = { [join('logs', 'loss.tsv')]: [ { type: PlotsType.VEGA, - revisions: ['workspace', 'main', '42b8736', '1ba7bcd', '4fb124a'], + revisions: ['workspace', '53c3851', '42b8736', '1ba7bcd', '4fb124a'], datapoints: { workspace: [ { @@ -68,7 +69,7 @@ const basicVega = { timestamp: '1641966351758' } ], - main: [ + '53c3851': [ { loss: '2.298783302307129', step: '0', @@ -361,8 +362,8 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({ }, { type: PlotsType.IMAGE, - revisions: ['main'], - url: joinFunc(baseUrl, 'main_plots_acc.png') + revisions: ['53c3851'], + url: joinFunc(baseUrl, '53c3851_plots_acc.png') }, { type: PlotsType.IMAGE, @@ -388,8 +389,8 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({ }, { type: PlotsType.IMAGE, - revisions: ['main'], - url: joinFunc(baseUrl, 'main_plots_heatmap.png') + revisions: ['53c3851'], + url: joinFunc(baseUrl, '53c3851_plots_heatmap.png') }, { type: PlotsType.IMAGE, @@ -415,8 +416,8 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({ }, { type: PlotsType.IMAGE, - revisions: ['main'], - url: joinFunc(baseUrl, 'main_plots_loss.png') + revisions: ['53c3851'], + url: joinFunc(baseUrl, '53c3851_plots_loss.png') }, { type: PlotsType.IMAGE, @@ -468,10 +469,12 @@ const extendedSpecs = (plotsOutput: TemplatePlots): TemplatePlotSection[] => { data: { values: expectedRevisions.flatMap(revision => - originalPlot.datapoints?.[revision].map(values => ({ - ...values, - rev: revision - })) + originalPlot.datapoints?.[getCLIBranchId(revision)].map( + values => ({ + ...values, + rev: revision + }) + ) ) || [] } } as TopLevelSpec, @@ -557,7 +560,7 @@ export const getComparisonWebviewMessage = ( for (const [path, plots] of Object.entries(getImageData(baseUrl, joinFunc))) { const revisionsAcc: ComparisonRevisionData = {} for (const { url, revisions } of plots) { - const revision = revisions?.[0] + const revision = replaceBranchCLIId(revisions?.[0]) if (!revision) { continue } diff --git a/extension/src/test/fixtures/plotsDiff/staticImages/main_plots_acc.png b/extension/src/test/fixtures/plotsDiff/staticImages/53c3851_plots_acc.png similarity index 100% rename from extension/src/test/fixtures/plotsDiff/staticImages/main_plots_acc.png rename to extension/src/test/fixtures/plotsDiff/staticImages/53c3851_plots_acc.png diff --git a/extension/src/test/fixtures/plotsDiff/staticImages/main_plots_heatmap.png b/extension/src/test/fixtures/plotsDiff/staticImages/53c3851_plots_heatmap.png similarity index 100% rename from extension/src/test/fixtures/plotsDiff/staticImages/main_plots_heatmap.png rename to extension/src/test/fixtures/plotsDiff/staticImages/53c3851_plots_heatmap.png diff --git a/extension/src/test/fixtures/plotsDiff/staticImages/main_plots_loss.png b/extension/src/test/fixtures/plotsDiff/staticImages/53c3851_plots_loss.png similarity index 100% rename from extension/src/test/fixtures/plotsDiff/staticImages/main_plots_loss.png rename to extension/src/test/fixtures/plotsDiff/staticImages/53c3851_plots_loss.png diff --git a/extension/src/test/fixtures/plotsDiff/util.ts b/extension/src/test/fixtures/plotsDiff/util.ts new file mode 100644 index 0000000000..f6b5198360 --- /dev/null +++ b/extension/src/test/fixtures/plotsDiff/util.ts @@ -0,0 +1,13 @@ +export const replaceBranchCLIId = (revision: string): string => { + if (revision === '53c3851') { + return 'main' + } + return revision +} + +export const getCLIBranchId = (revision: string): string => { + if (revision === 'main') { + return '53c3851' + } + return revision +} diff --git a/extension/src/test/fixtures/plotsDiff/vega.ts b/extension/src/test/fixtures/plotsDiff/vega.ts index f22490676d..9ef0b08fa6 100644 --- a/extension/src/test/fixtures/plotsDiff/vega.ts +++ b/extension/src/test/fixtures/plotsDiff/vega.ts @@ -5,7 +5,7 @@ const data = { { multiView: false, type: 'vega', - revisions: ['workspace', 'main', '42b8736', '1ba7bcd', '4fb124a'], + revisions: ['workspace', '53c3851', '42b8736', '1ba7bcd', '4fb124a'], datapoints: { workspace: [ { @@ -54,7 +54,7 @@ const data = { timestamp: '1641966351759' } ], - main: [ + '53c3851': [ { acc: '0.123', step: '0', @@ -323,7 +323,7 @@ const data = { { multiView: true, type: 'vega', - revisions: ['workspace', 'main', '42b8736', '1ba7bcd', '4fb124a'], + revisions: ['workspace', '53c3851', '42b8736', '1ba7bcd', '4fb124a'], datapoints: { workspace: [ { actual: 7, predicted: 7 }, @@ -10327,7 +10327,7 @@ const data = { { actual: 5, predicted: 0 }, { actual: 6, predicted: 0 } ], - main: [ + '53c3851': [ { actual: 7, predicted: 7 }, { actual: 2, predicted: 0 }, { actual: 1, predicted: 1 }, diff --git a/extension/src/test/suite/plots/data/index.test.ts b/extension/src/test/suite/plots/data/index.test.ts index 583bd66dcd..d8e59fe60e 100644 --- a/extension/src/test/suite/plots/data/index.test.ts +++ b/extension/src/test/suite/plots/data/index.test.ts @@ -22,7 +22,8 @@ suite('Plots Data Test Suite', () => { const buildPlotsData = ( experimentIsRunning: boolean, missingRevisions: string[] = [], - mutableRevisions: string[] = [] + mutableRevisions: string[] = [], + defaultRevs: string[] = [] ) => { const { internalCommands, updatesPaused, mockPlotsDiff, cliRunner } = buildDependencies(disposable) @@ -35,8 +36,10 @@ suite('Plots Data Test Suite', () => { const mockGetMissingRevisions = stub().returns(missingRevisions) const mockGetMutableRevisions = stub().returns(mutableRevisions) + const mockGetDefaultRevs = stub().returns(defaultRevs) const mockPlotsModel = { + getDefaultRevs: mockGetDefaultRevs, getMissingRevisions: mockGetMissingRevisions, getMutableRevisions: mockGetMutableRevisions } as unknown as PlotsModel @@ -59,20 +62,38 @@ suite('Plots Data Test Suite', () => { }) it('should call plots diff when there are no revisions to fetch and no experiment is running (workspace updates)', async () => { - const { data, mockPlotsDiff } = buildPlotsData(false) + const defaultRevisions = ['workspace', '4d78b9e'] + const { data, mockPlotsDiff } = buildPlotsData( + false, + [], + [], + defaultRevisions + ) await data.update() expect(mockPlotsDiff).to.be.calledOnce - expect(mockPlotsDiff).to.be.calledWithExactly(dvcDemoPath) + expect(mockPlotsDiff).to.be.calledWithExactly( + dvcDemoPath, + ...defaultRevisions + ) }) it('should call plots diff when an experiment is running in the workspace (live updates)', async () => { - const { data, mockPlotsDiff } = buildPlotsData(true, [], ['workspace']) + const defaultRevisions = ['workspace', '4d78b9e'] + const { data, mockPlotsDiff } = buildPlotsData( + true, + [], + ['workspace'], + defaultRevisions + ) await data.update() - expect(mockPlotsDiff).to.be.calledWithExactly(dvcDemoPath) + expect(mockPlotsDiff).to.be.calledWithExactly( + dvcDemoPath, + ...defaultRevisions + ) }) it('should call plots diff when an experiment is running in a temporary directory (live updates)', async () => { @@ -87,7 +108,7 @@ suite('Plots Data Test Suite', () => { it('should call plots diff when an experiment is running and there are missing revisions (checkpoints)', async () => { const { data, mockPlotsDiff } = buildPlotsData( true, - ['main', '4fb124a', '42b8736', '1ba7bcd'], + ['53c3851', '4fb124a', '42b8736', '1ba7bcd'], [] ) @@ -99,14 +120,14 @@ suite('Plots Data Test Suite', () => { '1ba7bcd', '42b8736', '4fb124a', - 'main' + '53c3851' ) }) it('should call plots diff when an experiment is running and there are missing revisions and one of them is mutable', async () => { const { data, mockPlotsDiff } = buildPlotsData( true, - ['main', '4fb124a', '42b8736', '1ba7bcd'], + ['53c3851', '4fb124a', '42b8736', '1ba7bcd'], ['1ba7bcd'] ) @@ -118,7 +139,7 @@ suite('Plots Data Test Suite', () => { '1ba7bcd', '42b8736', '4fb124a', - 'main' + '53c3851' ) }) }) diff --git a/extension/src/test/suite/plots/index.test.ts b/extension/src/test/suite/plots/index.test.ts index 515591cc99..d57fc6ac1f 100644 --- a/extension/src/test/suite/plots/index.test.ts +++ b/extension/src/test/suite/plots/index.test.ts @@ -64,7 +64,7 @@ suite('Plots Test Suite', () => { '1ba7bcd', '42b8736', '4fb124a', - 'main', + '53c3851', 'workspace' ) mockPlotsDiff.resetHistory() @@ -160,7 +160,7 @@ suite('Plots Test Suite', () => { expect(mockPlotsDiff).to.be.calledOnce expect(mockPlotsDiff).to.be.calledWithExactly( dvcDemoPath, - 'main', + '9235a02', 'workspace' ) }) @@ -184,9 +184,9 @@ suite('Plots Test Suite', () => { url: join(basePlotsUrl, 'workspace_plots_acc.png') }, { - revisions: ['another-branch'], + revisions: ['9235a028'], type: PlotsType.IMAGE, - url: join(basePlotsUrl, 'another-branch_plots_acc.png') + url: join(basePlotsUrl, '9235a028_plots_acc.png') } ] }) @@ -220,7 +220,7 @@ suite('Plots Test Suite', () => { ) const plotsSentEvent = new Promise(resolve => mockSendPlots.callsFake(() => { - if (isEqual(plotsModel.getMissingRevisions(), [])) { + if (isEqual(plotsModel.getMissingRevisions(), ['9235a02'])) { resolve(undefined) } }) @@ -235,7 +235,7 @@ suite('Plots Test Suite', () => { expect(mockPlotsDiff).to.be.calledOnce expect(mockPlotsDiff).to.be.calledWithExactly( dvcDemoPath, - 'another-branch', + '9235a02', 'workspace' ) @@ -253,7 +253,7 @@ suite('Plots Test Suite', () => { '1ba7bcd', '42b8736', '4fb124a', - 'main', + '53c3851', 'workspace' ) }).timeout(WEBVIEW_TEST_TIMEOUT) @@ -620,7 +620,7 @@ suite('Plots Test Suite', () => { undefined ) expect(mockPlotsDiff).to.be.calledOnce - expect(mockPlotsDiff).to.be.calledWithExactly(dvcDemoPath, 'main') + expect(mockPlotsDiff).to.be.calledWithExactly(dvcDemoPath, '53c3851') }).timeout(WEBVIEW_TEST_TIMEOUT) it('should handle a message to manually refresh all visible plots from the webview', async () => { @@ -658,7 +658,7 @@ suite('Plots Test Suite', () => { '1ba7bcd', '42b8736', '4fb124a', - 'main', + '53c3851', 'workspace' ) }).timeout(WEBVIEW_TEST_TIMEOUT) diff --git a/extension/src/test/suite/plots/paths/tree.test.ts b/extension/src/test/suite/plots/paths/tree.test.ts index 89fce5147e..28e1da26e1 100644 --- a/extension/src/test/suite/plots/paths/tree.test.ts +++ b/extension/src/test/suite/plots/paths/tree.test.ts @@ -151,7 +151,7 @@ suite('Plots Paths Tree Test Suite', () => { '1ba7bcd', '42b8736', '4fb124a', - 'main', + '53c3851', 'workspace' ) }).timeout(WEBVIEW_TEST_TIMEOUT)