Skip to content

Commit

Permalink
Patch plots for branches containing path separators (#1949)
Browse files Browse the repository at this point in the history
* use short sha to fetch HEAD plots data (workaround branch names containing path separators)

* refactor
mattseddon authored Jun 27, 2022
1 parent 4d78b9e commit c83e763
Showing 13 changed files with 173 additions and 68 deletions.
16 changes: 13 additions & 3 deletions extension/src/plots/data/index.ts
Original file line number Diff line number Diff line change
@@ -46,8 +46,7 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> {
return
}

const args = sameContents(revs, ['workspace']) ? [] : revs

const args = this.getArgs(revs)
const data = await this.internalCommands.executeCommand<PlotsOutput>(
AvailableCommands.PLOTS_DIFF,
this.dvcRoot,
@@ -58,7 +57,7 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> {

this.compareFiles(files)

return this.notifyChanged({ data, revs })
return this.notifyChanged({ data, revs: args })
}

public managedUpdate() {
@@ -72,4 +71,15 @@ export class PlotsData extends BaseData<{ data: PlotsOutput; revs: string[] }> {
public setModel(model: PlotsModel) {
this.model = model
}

private getArgs(revs: string[]) {
if (
this.model &&
(sameContents(revs, ['workspace']) || sameContents(revs, []))
) {
return this.model.getDefaultRevs()
}

return revs
}
}
23 changes: 20 additions & 3 deletions extension/src/plots/model/collect.test.ts
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ import {
uniqueValues
} from '../../util/array'
import { TemplatePlot } from '../webview/contract'
import { getCLIBranchId } from '../../test/fixtures/plotsDiff/util'

const logsLossPath = join('logs', 'loss.tsv')

@@ -226,7 +227,17 @@ describe('collectMetricOrder', () => {

describe('collectData', () => {
it('should return the expected output from the test fixture', () => {
const { revisionData, comparisonData } = collectData(plotsDiffFixture)
const mapping = {
'1ba7bcd': '1ba7bcd',
'42b8736': '42b8736',
'4fb124a': '4fb124a',
'53c3851': 'main',
workspace: 'workspace'
}
const { revisionData, comparisonData } = collectData(
plotsDiffFixture,
mapping
)
const revisions = ['workspace', 'main', '42b8736', '1ba7bcd', '4fb124a']

const values =
@@ -237,7 +248,7 @@ describe('collectData', () => {
expect(isEmpty(values)).toBeFalsy()

for (const revision of revisions) {
const expectedValues = values[revision].map(value => ({
const expectedValues = values[getCLIBranchId(revision)].map(value => ({
...value,
rev: revision
}))
@@ -287,7 +298,13 @@ describe('collectTemplates', () => {
})

describe('collectWorkspaceRaceConditionData', () => {
const { comparisonData, revisionData } = collectData(plotsDiffFixture)
const { comparisonData, revisionData } = collectData(plotsDiffFixture, {
'1ba7bcd': '1ba7bcd',
'42b8736': '42b8736',
'4fb124a': '4fb124a',
'53c3851': 'main',
workspace: 'workspace'
})

it('should return no overwrite data if there is no selected checkpoint experiment running in the workspace', () => {
const { overwriteComparisonData, overwriteRevisionData } =
56 changes: 37 additions & 19 deletions extension/src/plots/model/collect.ts
Original file line number Diff line number Diff line change
@@ -312,31 +312,39 @@ export const collectMetricOrder = (
type RevisionPathData = { [path: string]: Record<string, unknown>[] }

export type RevisionData = {
[revision: string]: RevisionPathData
[label: string]: RevisionPathData
}

export type ComparisonData = {
[revision: string]: {
[label: string]: {
[path: string]: ImagePlot
}
}

export type CLIRevisionIdToLabel = { [shortSha: string]: string }

const collectImageData = (
acc: ComparisonData,
path: string,
plot: ImagePlot
plot: ImagePlot,
cliIdToLabel: CLIRevisionIdToLabel
) => {
const rev = plot.revisions?.[0]

if (!rev) {
return
}

if (!acc[rev]) {
acc[rev] = {}
const label = cliIdToLabel[rev]

if (!label) {
return
}

if (!acc[label]) {
acc[label] = {}
}

acc[rev][path] = plot
acc[label][path] = plot
}

const collectDatapoints = (
@@ -353,15 +361,17 @@ const collectDatapoints = (
const collectPlotData = (
acc: RevisionData,
path: string,
plot: TemplatePlot
plot: TemplatePlot,
cliIdToLabel: CLIRevisionIdToLabel
) => {
for (const rev of plot.revisions || []) {
if (!acc[rev]) {
acc[rev] = {}
for (const id of plot.revisions || []) {
const label = cliIdToLabel[id]
if (!acc[label]) {
acc[label] = {}
}
acc[rev][path] = []
acc[label][path] = []

collectDatapoints(acc, path, rev, plot.datapoints?.[rev])
collectDatapoints(acc, path, label, plot.datapoints?.[id])
}
}

@@ -370,25 +380,33 @@ type DataAccumulator = {
comparisonData: ComparisonData
}

const collectPathData = (acc: DataAccumulator, path: string, plots: Plot[]) => {
const collectPathData = (
acc: DataAccumulator,
path: string,
plots: Plot[],
cliIdToLabel: CLIRevisionIdToLabel
) => {
for (const plot of plots) {
if (isImagePlot(plot)) {
collectImageData(acc.comparisonData, path, plot)
collectImageData(acc.comparisonData, path, plot, cliIdToLabel)
continue
}

collectPlotData(acc.revisionData, path, plot)
collectPlotData(acc.revisionData, path, plot, cliIdToLabel)
}
}

export const collectData = (data: PlotsOutput): DataAccumulator => {
export const collectData = (
data: PlotsOutput,
cliIdToLabel: CLIRevisionIdToLabel
): DataAccumulator => {
const acc = {
comparisonData: {},
revisionData: {}
} as DataAccumulator

for (const [path, plots] of Object.entries(data)) {
collectPathData(acc, path, plots)
collectPathData(acc, path, plots, cliIdToLabel)
}

return acc
@@ -543,7 +561,7 @@ export const collectBranchRevisionDetails = (
const branchRevisions: Record<string, string> = {}
for (const { id, sha } of branchShas) {
if (sha) {
branchRevisions[id] = sha
branchRevisions[id] = shortenForLabel(sha)
}
}
return branchRevisions
37 changes: 30 additions & 7 deletions extension/src/plots/model/index.ts
Original file line number Diff line number Diff line change
@@ -96,10 +96,15 @@ export class PlotsModel extends ModelWithPersistence {
}

public async transformAndSetPlots(data: PlotsOutput, revs: string[]) {
this.fetchedRevs = new Set([...this.fetchedRevs, ...revs])
const cliIdToLabel = this.getCLIIdToLabel()

this.fetchedRevs = new Set([
...this.fetchedRevs,
...revs.map(rev => cliIdToLabel[rev])
])

const [{ comparisonData, revisionData }, templates] = await Promise.all([
collectData(data),
collectData(data, cliIdToLabel),
collectTemplates(data)
])

@@ -168,9 +173,9 @@ export class PlotsModel extends ModelWithPersistence {
...Object.keys(this.revisionData)
])

return this.getSelectedRevisions().filter(
revision => !cachedRevisions.has(revision)
)
return this.getSelectedRevisions()
.filter(label => !cachedRevisions.has(label))
.map(label => this.getCLIId(label))
}

public getMutableRevisions() {
@@ -186,16 +191,20 @@ export class PlotsModel extends ModelWithPersistence {
this.comparisonOrder,
this.experiments
.getSelectedRevisions()
.map(({ label: revision, displayColor, logicalGroupName, id }) => ({
.map(({ label, displayColor, logicalGroupName, id }) => ({
displayColor,
group: logicalGroupName,
id,
revision
revision: label
})),
'revision'
)
}

public getDefaultRevs() {
return ['workspace', ...Object.values(this.branchRevisions)]
}

public getTemplatePlots(order: TemplateOrder | undefined) {
if (!definedAndNonEmpty(order)) {
return
@@ -330,6 +339,20 @@ export class PlotsModel extends ModelWithPersistence {
this.fetchedRevs.delete(id)
}

private getCLIIdToLabel() {
const mapping: { [shortSha: string]: string } = {}

for (const rev of this.getSelectedRevisions()) {
mapping[this.getCLIId(rev)] = rev
}

return mapping
}

private getCLIId(label: string) {
return this.branchRevisions[label] || label
}

private getSelectedRevisions() {
return this.experiments.getSelectedRevisions().map(({ label }) => label)
}
29 changes: 16 additions & 13 deletions extension/src/test/fixtures/plotsDiff/index.ts
Original file line number Diff line number Diff line change
@@ -14,12 +14,13 @@ import {
} from '../../../plots/webview/contract'
import { join } from '../../util/path'
import { copyOriginalColors } from '../../../experiments/model/status/colors'
import { getCLIBranchId, replaceBranchCLIId } from './util'

const basicVega = {
[join('logs', 'loss.tsv')]: [
{
type: PlotsType.VEGA,
revisions: ['workspace', 'main', '42b8736', '1ba7bcd', '4fb124a'],
revisions: ['workspace', '53c3851', '42b8736', '1ba7bcd', '4fb124a'],
datapoints: {
workspace: [
{
@@ -68,7 +69,7 @@ const basicVega = {
timestamp: '1641966351758'
}
],
main: [
'53c3851': [
{
loss: '2.298783302307129',
step: '0',
@@ -361,8 +362,8 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({
},
{
type: PlotsType.IMAGE,
revisions: ['main'],
url: joinFunc(baseUrl, 'main_plots_acc.png')
revisions: ['53c3851'],
url: joinFunc(baseUrl, '53c3851_plots_acc.png')
},
{
type: PlotsType.IMAGE,
@@ -388,8 +389,8 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({
},
{
type: PlotsType.IMAGE,
revisions: ['main'],
url: joinFunc(baseUrl, 'main_plots_heatmap.png')
revisions: ['53c3851'],
url: joinFunc(baseUrl, '53c3851_plots_heatmap.png')
},
{
type: PlotsType.IMAGE,
@@ -415,8 +416,8 @@ const getImageData = (baseUrl: string, joinFunc = join) => ({
},
{
type: PlotsType.IMAGE,
revisions: ['main'],
url: joinFunc(baseUrl, 'main_plots_loss.png')
revisions: ['53c3851'],
url: joinFunc(baseUrl, '53c3851_plots_loss.png')
},
{
type: PlotsType.IMAGE,
@@ -468,10 +469,12 @@ const extendedSpecs = (plotsOutput: TemplatePlots): TemplatePlotSection[] => {
data: {
values:
expectedRevisions.flatMap(revision =>
originalPlot.datapoints?.[revision].map(values => ({
...values,
rev: revision
}))
originalPlot.datapoints?.[getCLIBranchId(revision)].map(
values => ({
...values,
rev: revision
})
)
) || []
}
} as TopLevelSpec,
@@ -557,7 +560,7 @@ export const getComparisonWebviewMessage = (
for (const [path, plots] of Object.entries(getImageData(baseUrl, joinFunc))) {
const revisionsAcc: ComparisonRevisionData = {}
for (const { url, revisions } of plots) {
const revision = revisions?.[0]
const revision = replaceBranchCLIId(revisions?.[0])
if (!revision) {
continue
}
13 changes: 13 additions & 0 deletions extension/src/test/fixtures/plotsDiff/util.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
export const replaceBranchCLIId = (revision: string): string => {
if (revision === '53c3851') {
return 'main'
}
return revision
}

export const getCLIBranchId = (revision: string): string => {
if (revision === 'main') {
return '53c3851'
}
return revision
}
Loading

0 comments on commit c83e763

Please sign in to comment.