Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "image by step" plots to comparison section #4319

Merged
merged 18 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions extension/src/cli/dvc/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ export const DEFAULT_CURRENT_BRANCH_COMMITS_TO_SHOW = 3
export const DEFAULT_OTHER_BRANCH_COMMITS_TO_SHOW = 1
export const NUM_OF_COMMITS_TO_INCREASE = 2

export const MULTI_IMAGE_PATH_REG = /[^/]+[/\\]\d+\.[a-z]+$/i

export enum Command {
ADD = 'add',
CHECKOUT = 'checkout',
Expand Down
50 changes: 49 additions & 1 deletion extension/src/cli/dvc/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { join } from 'path'
import { EventEmitter } from 'vscode'
import { Disposable, Disposer } from '@hediet/std/disposable'
import { DvcCli } from '.'
import { Command } from './constants'
import { Command, MULTI_IMAGE_PATH_REG } from './constants'
import { CliResult, CliStarted, typeCheckCommands } from '..'
import { getProcessEnv } from '../../env'
import { createProcess } from '../../process/execution'
Expand Down Expand Up @@ -52,6 +53,53 @@ describe('typeCheckCommands', () => {
})
})

describe('Comparison Multi Image Regex', () => {
it('should match a nested image group directory', () => {
expect(
MULTI_IMAGE_PATH_REG.test(
join(
'extremely',
'super',
'super',
'super',
'nested',
'image',
'768.svg'
)
)
).toBe(true)
})

it('should match directories with spaces or special characters', () => {
expect(MULTI_IMAGE_PATH_REG.test(join('mis classified', '5.png'))).toBe(
true
)

expect(MULTI_IMAGE_PATH_REG.test(join('misclassified#^', '5.png'))).toBe(
true
)
})

it('should match different types of images', () => {
const imageFormats = ['svg', 'png', 'jpg', 'jpeg']
for (const format of imageFormats) {
expect(
MULTI_IMAGE_PATH_REG.test(join('misclassified', `5.${format}`))
).toBe(true)
}
})

it('should not match files that include none digits or do not have a file extension', () => {
expect(MULTI_IMAGE_PATH_REG.test(join('misclassified', 'five.png'))).toBe(
false
)
expect(MULTI_IMAGE_PATH_REG.test(join('misclassified', '5 4.png'))).toBe(
false
)
expect(MULTI_IMAGE_PATH_REG.test(join('misclassified', '5'))).toBe(false)
})
})

describe('executeDvcProcess', () => {
it('should pass the correct details to the underlying process given no path to the cli or python binary path', async () => {
const existingPath = joinEnvPath(
Expand Down
4 changes: 3 additions & 1 deletion extension/src/fileSystem/util.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { sep } from 'path'
import { sep, parse } from 'path'

export const getPathArray = (path: string): string[] => path.split(sep)

Expand All @@ -18,3 +18,5 @@ export const getParent = (pathArray: string[], idx: number) => {

export const removeTrailingSlash = (path: string): string =>
path.endsWith(sep) ? path.slice(0, -1) : path

export const getFileNameWithoutExt = (path: string) => parse(path).name
7 changes: 4 additions & 3 deletions extension/src/plots/model/collect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,18 @@ describe('collectData', () => {
expect(Object.keys(comparisonData.main)).toStrictEqual([
join('plots', 'acc.png'),
heatmapPlot,
join('plots', 'loss.png')
join('plots', 'loss.png'),
join('plots', 'image')
])

const testBranchHeatmap = comparisonData['test-branch'][heatmapPlot]

expect(testBranchHeatmap).toBeDefined()
expect(testBranchHeatmap).toStrictEqual(
expect(testBranchHeatmap).toStrictEqual([
plotsDiffFixture.data[heatmapPlot].find(({ revisions }) =>
sameContents(revisions as string[], ['test-branch'])
)
)
])
})
})

Expand Down
113 changes: 110 additions & 3 deletions extension/src/plots/model/collect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import {
TemplatePlotSection,
PlotsType,
CustomPlotData,
CustomPlotValues
CustomPlotValues,
ComparisonRevisionData,
ComparisonPlotImg
} from '../webview/contract'
import { PlotsOutput } from '../../cli/dvc/contract'
import { splitColumnPath } from '../../experiments/columns/paths'
Expand All @@ -34,6 +36,12 @@ import {
import { StrokeDashEncoding } from '../multiSource/constants'
import { exists } from '../../fileSystem'
import { hasKey } from '../../util/object'
import { MULTI_IMAGE_PATH_REG } from '../../cli/dvc/constants'
import {
getFileNameWithoutExt,
getParent,
getPathArray
} from '../../fileSystem/util'

export const getCustomPlotId = (metric: string, param: string) =>
`custom-${metric}-${param}`
Expand Down Expand Up @@ -126,18 +134,31 @@ export type RevisionData = {
[label: string]: RevisionPathData
}

type ComparisonDataImgPlot = ImagePlot & { ind?: number }

export type ComparisonData = {
[label: string]: {
[path: string]: ImagePlot
[path: string]: ComparisonDataImgPlot[]
}
}

const getMultiImagePath = (path: string) =>
getParent(getPathArray(path), 0) as string

const getMultiImageInd = (path: string) => {
const fileName = getFileNameWithoutExt(path)
return Number(fileName)
}

const collectImageData = (
acc: ComparisonData,
path: string,
plot: ImagePlot
) => {
const isMultiImgPlot = MULTI_IMAGE_PATH_REG.test(path)
const pathLabel = isMultiImgPlot ? getMultiImagePath(path) : path
const id = plot.revisions?.[0]

if (!id) {
return
}
Expand All @@ -146,7 +167,17 @@ const collectImageData = (
acc[id] = {}
}

acc[id][path] = plot
if (!acc[id][pathLabel]) {
acc[id][pathLabel] = []
}

const imgPlot: ComparisonDataImgPlot = { ...plot }

if (isMultiImgPlot) {
imgPlot.ind = getMultiImageInd(path)
}

acc[id][pathLabel].push(imgPlot)
}

const collectDatapoints = (
Expand Down Expand Up @@ -202,6 +233,16 @@ const collectPathData = (acc: DataAccumulator, path: string, plots: Plot[]) => {
}
}

const sortComparisonImgPaths = (acc: DataAccumulator) => {
for (const [label, paths] of Object.entries(acc.comparisonData)) {
for (const path of Object.keys(paths)) {
acc.comparisonData[label][path].sort(
(img1, img2) => (img1.ind || 0) - (img2.ind || 0)
)
}
}
}

export const collectData = (output: PlotsOutput): DataAccumulator => {
const { data } = output
const acc = {
Expand All @@ -213,6 +254,72 @@ export const collectData = (output: PlotsOutput): DataAccumulator => {
collectPathData(acc, path, plots)
}

sortComparisonImgPaths(acc)

return acc
}

type ComparisonPlotsAcc = { path: string; revisions: ComparisonRevisionData }[]

type GetComparisonPlotImg = (
img: ImagePlot,
id: string,
path: string
) => ComparisonPlotImg

const collectSelectedPathComparisonPlots = ({
acc,
comparisonData,
path,
selectedRevisionIds,
getComparisonPlotImg
}: {
acc: ComparisonPlotsAcc
comparisonData: ComparisonData
path: string
selectedRevisionIds: string[]
getComparisonPlotImg: GetComparisonPlotImg
}) => {
const pathRevisions = {
path,
revisions: {} as ComparisonRevisionData
}

for (const id of selectedRevisionIds) {
const imgs = comparisonData[id]?.[path]
pathRevisions.revisions[id] = {
id,
imgs: imgs
? imgs.map(img => getComparisonPlotImg(img, id, path))
: [{ errors: undefined, loading: false, url: undefined }]
}
}
acc.push(pathRevisions)
}

export const collectSelectedComparisonPlots = ({
comparisonData,
paths,
selectedRevisionIds,
getComparisonPlotImg
}: {
comparisonData: ComparisonData
paths: string[]
selectedRevisionIds: string[]
getComparisonPlotImg: GetComparisonPlotImg
}) => {
const acc: ComparisonPlotsAcc = []

for (const path of paths) {
collectSelectedPathComparisonPlots({
acc,
comparisonData,
getComparisonPlotImg,
path,
selectedRevisionIds
})
}

return acc
}

Expand Down
56 changes: 21 additions & 35 deletions extension/src/plots/model/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import {
collectImageUrl,
collectIdShas,
collectSelectedTemplatePlotRawData,
collectCustomPlotRawData
collectCustomPlotRawData,
collectSelectedComparisonPlots
} from './collect'
import { getRevisionSummaryColumns } from './util'
import {
Expand All @@ -21,9 +22,7 @@ import {
CustomPlotsOrderValue
} from './custom'
import {
ComparisonPlots,
Revision,
ComparisonRevisionData,
DEFAULT_SECTION_COLLAPSED,
DEFAULT_SECTION_NB_ITEMS_PER_ROW_OR_WIDTH,
PlotsSection,
Expand All @@ -33,7 +32,8 @@ import {
DEFAULT_HEIGHT,
DEFAULT_NB_ITEMS_PER_ROW,
PlotHeight,
SmoothPlotValues
SmoothPlotValues,
ImagePlot
} from '../webview/contract'
import {
EXPERIMENT_WORKSPACE_ID,
Expand Down Expand Up @@ -427,37 +427,23 @@ export class PlotsModel extends ModelWithPersistence {
paths: string[],
selectedRevisionIds: string[]
) {
const acc: ComparisonPlots = []
for (const path of paths) {
this.collectSelectedPathComparisonPlots(acc, path, selectedRevisionIds)
}
return acc
}

private collectSelectedPathComparisonPlots(
acc: ComparisonPlots,
path: string,
selectedRevisionIds: string[]
) {
const pathRevisions = {
path,
revisions: {} as ComparisonRevisionData
}

for (const id of selectedRevisionIds) {
const image = this.comparisonData?.[id]?.[path]
const errors = this.errors.getImageErrors(path, id)
const fetched = this.fetchedRevs.has(id)
const url = collectImageUrl(image, fetched)
const loading = !fetched && !url
pathRevisions.revisions[id] = {
errors,
id,
loading,
url
}
}
acc.push(pathRevisions)
return collectSelectedComparisonPlots({
comparisonData: this.comparisonData,
getComparisonPlotImg: (image: ImagePlot, id: string, path: string) => {
const errors = this.errors.getImageErrors(path, id)
const fetched = this.fetchedRevs.has(id)
const url = collectImageUrl(image, fetched)
const loading = !fetched && !url

return {
errors,
loading,
url
}
},
paths,
selectedRevisionIds
})
}

private getSelectedTemplatePlots(
Expand Down
7 changes: 7 additions & 0 deletions extension/src/plots/paths/collect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ describe('collectPaths', () => {
revisions: new Set(REVISIONS),
type: new Set(['comparison'])
},
{
hasChildren: false,
parentPath: 'plots',
path: join('plots', 'image'),
revisions: new Set(REVISIONS),
type: new Set(['comparison'])
},
{
hasChildren: false,
parentPath: 'logs',
Expand Down
Loading