Skip to content

Commit

Permalink
Consolidate collectCustomPlots (#3466)
Browse files Browse the repository at this point in the history
  • Loading branch information
julieg18 authored Mar 20, 2023
1 parent 9ae29fd commit 1ab409c
Show file tree
Hide file tree
Showing 9 changed files with 393 additions and 594 deletions.
4 changes: 4 additions & 0 deletions extension/src/experiments/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,10 @@ export class Experiments extends BaseRepository<TableData> {
return this.experiments.getExperimentCount()
}

public getExperimentsWithCheckpoints() {
return this.experiments.getExperimentsWithCheckpoints()
}

public async selectExperiments() {
const experiments = this.experiments.getExperimentsWithCheckpoints()

Expand Down
4 changes: 4 additions & 0 deletions extension/src/experiments/model/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ export type ExperimentWithCheckpoints = Experiment & {
checkpoints?: Experiment[]
}

export type ExperimentWithDefinedCheckpoints = Experiment & {
checkpoints: Experiment[]
}

export enum ExperimentType {
WORKSPACE = 'workspace',
COMMIT = 'commit',
Expand Down
9 changes: 4 additions & 5 deletions extension/src/plots/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import { Experiments } from '../experiments'
import { Resource } from '../resourceLocator'
import { InternalCommands } from '../commands/internal'
import { definedAndNonEmpty } from '../util/array'
import { ExperimentsOutput } from '../cli/dvc/contract'
import { TEMP_PLOTS_DIR } from '../cli/dvc/constants'
import { removeDir } from '../fileSystem'
import { Toast } from '../vscode/toast'
Expand Down Expand Up @@ -173,7 +172,7 @@ export class Plots extends BaseRepository<TPlotsData> {
waitForInitialExpData.dispose()
this.data.setMetricFiles(data)
this.setupExperimentsListener(experiments)
void this.initializeData(data)
void this.initializeData()
}
})
)
Expand All @@ -184,7 +183,7 @@ export class Plots extends BaseRepository<TPlotsData> {
experiments.onDidChangeExperiments(async data => {
if (data) {
await Promise.all([
this.plots.transformAndSetExperiments(data),
this.plots.transformAndSetExperiments(),
this.data.setMetricFiles(data)
])
}
Expand All @@ -200,8 +199,8 @@ export class Plots extends BaseRepository<TPlotsData> {
)
}

private async initializeData(data: ExperimentsOutput) {
await this.plots.transformAndSetExperiments(data)
private async initializeData() {
await this.plots.transformAndSetExperiments()
void this.data.managedUpdate()
await Promise.all([
this.data.isReady(),
Expand Down
207 changes: 51 additions & 156 deletions extension/src/plots/model/collect.test.ts
Original file line number Diff line number Diff line change
@@ -1,154 +1,93 @@
import { join } from 'path'
import omit from 'lodash.omit'
import isEmpty from 'lodash.isempty'
import {
collectData,
collectTemplates,
collectOverrideRevisionDetails,
collectCustomPlots,
collectCustomCheckpointPlots,
collectCustomPlotData
collectCustomPlots
} from './collect'
import { isCheckpointPlot } from './custom'
import plotsDiffFixture from '../../test/fixtures/plotsDiff/output'
import customPlotsFixture, {
customPlotsOrderFixture,
checkpointPlotsFixture
experimentsWithCheckpoints
} from '../../test/fixtures/expShow/base/customPlots'
import {
ExperimentStatus,
EXPERIMENT_WORKSPACE_ID
} from '../../cli/dvc/contract'
import { sameContents } from '../../util/array'
import {
CheckpointPlot,
CustomPlot,
CustomPlotData,
CustomPlotType,
DEFAULT_NB_ITEMS_PER_ROW,
DEFAULT_PLOT_HEIGHT,
TemplatePlot
} from '../webview/contract'
import { getCLICommitId } from '../../test/fixtures/plotsDiff/util'
import expShowFixture from '../../test/fixtures/expShow/base/output'
import modifiedFixture from '../../test/fixtures/expShow/modified/output'
import { SelectedExperimentWithColor } from '../../experiments/model'
import { Experiment } from '../../experiments/webview/contract'

const logsLossPath = join('logs', 'loss.tsv')

const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot

const getCustomPlotFromCustomPlotData = ({
id,
metric,
param,
type,
values
}: CustomPlotData) =>
({
id,
metric,
param,
type,
values
} as CustomPlot)

describe('collectCustomPlots', () => {
const defaultFuncArgs = {
experiments: experimentsWithCheckpoints,
hasCheckpoints: true,
height: DEFAULT_PLOT_HEIGHT,
nbItemsPerRow: DEFAULT_NB_ITEMS_PER_ROW,
plotsOrderValues: customPlotsOrderFixture,
selectedRevisions: customPlotsFixture.colors?.domain
}

it('should return the expected data from the test fixture', () => {
const expectedOutput: CustomPlot[] = customPlotsFixture.plots.map(
getCustomPlotFromCustomPlotData
)
const data = collectCustomPlots(
customPlotsOrderFixture,
checkpointPlotsFixture,
[
{
id: '12345',
label: '123',
metrics: {
'summary.json': {
accuracy: 0.3724166750907898,
loss: 2.0205044746398926
}
},
name: 'exp-e7a67',
params: { 'params.yaml': { dropout: 0.15, epochs: 2 } }
},
{
id: '12345',
label: '123',
metrics: {
'summary.json': {
accuracy: 0.4668000042438507,
loss: 1.9293040037155151
}
},
name: 'test-branch',
params: { 'params.yaml': { dropout: 0.122, epochs: 2 } }
},
{
id: '12345',
label: '123',
metrics: {
'summary.json': {
accuracy: 0.5926499962806702,
loss: 1.775016188621521
}
},
name: 'exp-83425',
params: { 'params.yaml': { dropout: 0.124, epochs: 5 } }
}
]
const expectedOutput: CustomPlotData[] = customPlotsFixture.plots
const data = collectCustomPlots(defaultFuncArgs)
expect(data).toStrictEqual(expectedOutput)
})

it('should return only custom plots if there no selected revisions', () => {
const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter(
plot => plot.type !== CustomPlotType.CHECKPOINT
)
const data = collectCustomPlots({
...defaultFuncArgs,
selectedRevisions: undefined
})

expect(data).toStrictEqual(expectedOutput)
})
})

describe('collectCustomPlotData', () => {
it('should return the expected data from test fixture', () => {
const expectedMetricVsParamPlotData = customPlotsFixture.plots[0]
const expectedCheckpointsPlotData = customPlotsFixture.plots[2]
const metricVsParamPlot = getCustomPlotFromCustomPlotData(
expectedMetricVsParamPlotData
)
const checkpointsPlot = getCustomPlotFromCustomPlotData(
expectedCheckpointsPlotData
it('should return only custom plots if checkpoints are not enabled', () => {
const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter(
plot => plot.type !== CustomPlotType.CHECKPOINT
)
const data = collectCustomPlots({
...defaultFuncArgs,
hasCheckpoints: false
})

const metricVsParamData = collectCustomPlotData(
metricVsParamPlot,
customPlotsFixture.colors,
customPlotsFixture.nbItemsPerRow,
customPlotsFixture.height
)
expect(data).toStrictEqual(expectedOutput)
})

const checkpointsData = collectCustomPlotData(
{
...checkpointsPlot,
values: [
...checkpointsPlot.values,
{
group: 'exp-123',
iteration: 1,
y: 1.4534177053451538
},
{
group: 'exp-123',
iteration: 2,
y: 1.757687
},
{
group: 'exp-123',
iteration: 3,
y: 1.989894
}
]
} as CheckpointPlot,
customPlotsFixture.colors,
customPlotsFixture.nbItemsPerRow,
customPlotsFixture.height
)
it('should return checkpoint plots with values only containing selected experiments data', () => {
const domain = customPlotsFixture.colors?.domain.slice(1) as string[]

const expectedOutput = customPlotsFixture.plots.map(plot => ({
...plot,
values: isCheckpointPlot(plot)
? plot.values.filter(value => domain.includes(value.group))
: plot.values
}))

expect(metricVsParamData).toStrictEqual(expectedMetricVsParamPlotData)
expect(checkpointsData).toStrictEqual(expectedCheckpointsPlotData)
const data = collectCustomPlots({
...defaultFuncArgs,
selectedRevisions: domain
})

expect(data).toStrictEqual(expectedOutput)
})
})

Expand Down Expand Up @@ -215,50 +154,6 @@ describe('collectData', () => {
})
})

describe('collectCustomCheckpointPlotsData', () => {
it('should return the expected data from the test fixture', () => {
const data = collectCustomCheckpointPlots(expShowFixture)

expect(data).toStrictEqual(checkpointPlotsFixture)
})

it('should provide a continuous series for a modified experiment', () => {
const data = collectCustomCheckpointPlots(modifiedFixture)

for (const { values } of Object.values(data)) {
const initialExperiment = values.filter(
point => point.group === 'exp-908bd'
)
const modifiedExperiment = values.find(
point => point.group === 'exp-01b3a'
)

const lastIterationInitial = initialExperiment?.slice(-1)[0]
const firstIterationModified = modifiedExperiment

expect(lastIterationInitial).not.toStrictEqual(firstIterationModified)
expect(omit(lastIterationInitial, 'group')).toStrictEqual(
omit(firstIterationModified, 'group')
)

const baseExperiment = values.filter(point => point.group === 'exp-920fc')
const restartedExperiment = values.find(
point => point.group === 'exp-9bc1b'
)

const iterationRestartedFrom = baseExperiment?.slice(5)[0]
const firstIterationAfterRestart = restartedExperiment

expect(iterationRestartedFrom).not.toStrictEqual(
firstIterationAfterRestart
)
expect(omit(iterationRestartedFrom, 'group')).toStrictEqual(
omit(firstIterationAfterRestart, 'group')
)
}
})
})

describe('collectTemplates', () => {
it('should return the expected output from the test fixture', () => {
const { content } = logsLossPlot
Expand Down
Loading

0 comments on commit 1ab409c

Please sign in to comment.