Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate collectCustomPlots #3466

Merged
merged 13 commits into from
Mar 20, 2023
4 changes: 4 additions & 0 deletions extension/src/experiments/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,10 @@ export class Experiments extends BaseRepository<TableData> {
return this.experiments.getExperimentCount()
}

public getExperimentsWithCheckpoints() {
return this.experiments.getExperimentsWithCheckpoints()
}

public async selectExperiments() {
const experiments = this.experiments.getExperimentsWithCheckpoints()

Expand Down
4 changes: 4 additions & 0 deletions extension/src/experiments/model/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ export type ExperimentWithCheckpoints = Experiment & {
checkpoints?: Experiment[]
}

export type ExperimentWithDefinedCheckpoints = Experiment & {
checkpoints: Experiment[]
}

export enum ExperimentType {
WORKSPACE = 'workspace',
COMMIT = 'commit',
Expand Down
9 changes: 4 additions & 5 deletions extension/src/plots/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import { Experiments } from '../experiments'
import { Resource } from '../resourceLocator'
import { InternalCommands } from '../commands/internal'
import { definedAndNonEmpty } from '../util/array'
import { ExperimentsOutput } from '../cli/dvc/contract'
import { TEMP_PLOTS_DIR } from '../cli/dvc/constants'
import { removeDir } from '../fileSystem'
import { Toast } from '../vscode/toast'
Expand Down Expand Up @@ -173,7 +172,7 @@ export class Plots extends BaseRepository<TPlotsData> {
waitForInitialExpData.dispose()
this.data.setMetricFiles(data)
this.setupExperimentsListener(experiments)
void this.initializeData(data)
void this.initializeData()
}
})
)
Expand All @@ -184,7 +183,7 @@ export class Plots extends BaseRepository<TPlotsData> {
experiments.onDidChangeExperiments(async data => {
if (data) {
await Promise.all([
this.plots.transformAndSetExperiments(data),
this.plots.transformAndSetExperiments(),
this.data.setMetricFiles(data)
])
}
Expand All @@ -200,8 +199,8 @@ export class Plots extends BaseRepository<TPlotsData> {
)
}

private async initializeData(data: ExperimentsOutput) {
await this.plots.transformAndSetExperiments(data)
private async initializeData() {
await this.plots.transformAndSetExperiments()
void this.data.managedUpdate()
await Promise.all([
this.data.isReady(),
Expand Down
207 changes: 51 additions & 156 deletions extension/src/plots/model/collect.test.ts
Original file line number Diff line number Diff line change
@@ -1,154 +1,93 @@
import { join } from 'path'
import omit from 'lodash.omit'
import isEmpty from 'lodash.isempty'
import {
collectData,
collectTemplates,
collectOverrideRevisionDetails,
collectCustomPlots,
collectCustomCheckpointPlots,
collectCustomPlotData
collectCustomPlots
} from './collect'
import { isCheckpointPlot } from './custom'
import plotsDiffFixture from '../../test/fixtures/plotsDiff/output'
import customPlotsFixture, {
customPlotsOrderFixture,
checkpointPlotsFixture
experimentsWithCheckpoints
} from '../../test/fixtures/expShow/base/customPlots'
import {
ExperimentStatus,
EXPERIMENT_WORKSPACE_ID
} from '../../cli/dvc/contract'
import { sameContents } from '../../util/array'
import {
CheckpointPlot,
CustomPlot,
CustomPlotData,
CustomPlotType,
DEFAULT_NB_ITEMS_PER_ROW,
DEFAULT_PLOT_HEIGHT,
TemplatePlot
} from '../webview/contract'
import { getCLICommitId } from '../../test/fixtures/plotsDiff/util'
import expShowFixture from '../../test/fixtures/expShow/base/output'
import modifiedFixture from '../../test/fixtures/expShow/modified/output'
import { SelectedExperimentWithColor } from '../../experiments/model'
import { Experiment } from '../../experiments/webview/contract'

const logsLossPath = join('logs', 'loss.tsv')

const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot

const getCustomPlotFromCustomPlotData = ({
id,
metric,
param,
type,
values
}: CustomPlotData) =>
({
id,
metric,
param,
type,
values
} as CustomPlot)

describe('collectCustomPlots', () => {
const defaultFuncArgs = {
experiments: experimentsWithCheckpoints,
hasCheckpoints: true,
height: DEFAULT_PLOT_HEIGHT,
nbItemsPerRow: DEFAULT_NB_ITEMS_PER_ROW,
plotsOrderValues: customPlotsOrderFixture,
selectedRevisions: customPlotsFixture.colors?.domain
}

it('should return the expected data from the test fixture', () => {
const expectedOutput: CustomPlot[] = customPlotsFixture.plots.map(
getCustomPlotFromCustomPlotData
)
const data = collectCustomPlots(
customPlotsOrderFixture,
checkpointPlotsFixture,
[
{
id: '12345',
label: '123',
metrics: {
'summary.json': {
accuracy: 0.3724166750907898,
loss: 2.0205044746398926
}
},
name: 'exp-e7a67',
params: { 'params.yaml': { dropout: 0.15, epochs: 2 } }
},
{
id: '12345',
label: '123',
metrics: {
'summary.json': {
accuracy: 0.4668000042438507,
loss: 1.9293040037155151
}
},
name: 'test-branch',
params: { 'params.yaml': { dropout: 0.122, epochs: 2 } }
},
{
id: '12345',
label: '123',
metrics: {
'summary.json': {
accuracy: 0.5926499962806702,
loss: 1.775016188621521
}
},
name: 'exp-83425',
params: { 'params.yaml': { dropout: 0.124, epochs: 5 } }
}
]
const expectedOutput: CustomPlotData[] = customPlotsFixture.plots
const data = collectCustomPlots(defaultFuncArgs)
expect(data).toStrictEqual(expectedOutput)
})

it('should return only custom plots if there no selected revisions', () => {
const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter(
plot => plot.type !== CustomPlotType.CHECKPOINT
)
const data = collectCustomPlots({
...defaultFuncArgs,
selectedRevisions: undefined
})

expect(data).toStrictEqual(expectedOutput)
})
})

describe('collectCustomPlotData', () => {
it('should return the expected data from test fixture', () => {
const expectedMetricVsParamPlotData = customPlotsFixture.plots[0]
const expectedCheckpointsPlotData = customPlotsFixture.plots[2]
const metricVsParamPlot = getCustomPlotFromCustomPlotData(
expectedMetricVsParamPlotData
)
const checkpointsPlot = getCustomPlotFromCustomPlotData(
expectedCheckpointsPlotData
it('should return only custom plots if checkpoints are not enabled', () => {
const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter(
plot => plot.type !== CustomPlotType.CHECKPOINT
)
const data = collectCustomPlots({
...defaultFuncArgs,
hasCheckpoints: false
})

const metricVsParamData = collectCustomPlotData(
metricVsParamPlot,
customPlotsFixture.colors,
customPlotsFixture.nbItemsPerRow,
customPlotsFixture.height
)
expect(data).toStrictEqual(expectedOutput)
})

const checkpointsData = collectCustomPlotData(
{
...checkpointsPlot,
values: [
...checkpointsPlot.values,
{
group: 'exp-123',
iteration: 1,
y: 1.4534177053451538
},
{
group: 'exp-123',
iteration: 2,
y: 1.757687
},
{
group: 'exp-123',
iteration: 3,
y: 1.989894
}
]
} as CheckpointPlot,
customPlotsFixture.colors,
customPlotsFixture.nbItemsPerRow,
customPlotsFixture.height
)
it('should return checkpoint plots with values only containing selected experiments data', () => {
const domain = customPlotsFixture.colors?.domain.slice(1) as string[]

const expectedOutput = customPlotsFixture.plots.map(plot => ({
...plot,
values: isCheckpointPlot(plot)
? plot.values.filter(value => domain.includes(value.group))
: plot.values
}))

expect(metricVsParamData).toStrictEqual(expectedMetricVsParamPlotData)
expect(checkpointsData).toStrictEqual(expectedCheckpointsPlotData)
const data = collectCustomPlots({
...defaultFuncArgs,
selectedRevisions: domain
})

expect(data).toStrictEqual(expectedOutput)
})
})

Expand Down Expand Up @@ -215,50 +154,6 @@ describe('collectData', () => {
})
})

describe('collectCustomCheckpointPlotsData', () => {
it('should return the expected data from the test fixture', () => {
const data = collectCustomCheckpointPlots(expShowFixture)

expect(data).toStrictEqual(checkpointPlotsFixture)
})

it('should provide a continuous series for a modified experiment', () => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[F] This logic has been removed on purpose. Change is being discussed in [WIP] exp: refactor show behavior PR. Link is further down.

const data = collectCustomCheckpointPlots(modifiedFixture)

for (const { values } of Object.values(data)) {
const initialExperiment = values.filter(
point => point.group === 'exp-908bd'
)
const modifiedExperiment = values.find(
point => point.group === 'exp-01b3a'
)

const lastIterationInitial = initialExperiment?.slice(-1)[0]
const firstIterationModified = modifiedExperiment

expect(lastIterationInitial).not.toStrictEqual(firstIterationModified)
expect(omit(lastIterationInitial, 'group')).toStrictEqual(
omit(firstIterationModified, 'group')
)

const baseExperiment = values.filter(point => point.group === 'exp-920fc')
const restartedExperiment = values.find(
point => point.group === 'exp-9bc1b'
)

const iterationRestartedFrom = baseExperiment?.slice(5)[0]
const firstIterationAfterRestart = restartedExperiment

expect(iterationRestartedFrom).not.toStrictEqual(
firstIterationAfterRestart
)
expect(omit(iterationRestartedFrom, 'group')).toStrictEqual(
omit(firstIterationAfterRestart, 'group')
)
}
})
})

describe('collectTemplates', () => {
it('should return the expected output from the test fixture', () => {
const { content } = logsLossPlot
Expand Down
Loading