Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create custom plots when the data is requested #3491

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 51 additions & 69 deletions extension/src/plots/model/collect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import {
collectData,
collectTemplates,
collectOverrideRevisionDetails,
collectCustomPlots,
collectCustomPlotData
collectCustomPlots
} from './collect'
import { isCheckpointPlot } from './custom'
import plotsDiffFixture from '../../test/fixtures/plotsDiff/output'
import customPlotsFixture, {
customPlotsOrderFixture,
Expand All @@ -18,9 +18,10 @@ import {
} from '../../cli/dvc/contract'
import { sameContents } from '../../util/array'
import {
CheckpointPlot,
CustomPlot,
CustomPlotData,
CustomPlotType,
DEFAULT_NB_ITEMS_PER_ROW,
DEFAULT_PLOT_HEIGHT,
TemplatePlot
} from '../webview/contract'
import { getCLICommitId } from '../../test/fixtures/plotsDiff/util'
Expand All @@ -31,81 +32,62 @@ const logsLossPath = join('logs', 'loss.tsv')

const logsLossPlot = (plotsDiffFixture[logsLossPath][0] || {}) as TemplatePlot

const getCustomPlotFromCustomPlotData = ({
id,
metric,
param,
type,
values
}: CustomPlotData) =>
({
id,
metric,
param,
type,
values
} as CustomPlot)

describe('collectCustomPlots', () => {
const defaultFuncArgs = {
experiments: experimentsWithCheckpoints,
hasCheckpoints: true,
height: DEFAULT_PLOT_HEIGHT,
nbItemsPerRow: DEFAULT_NB_ITEMS_PER_ROW,
plotsOrderValues: customPlotsOrderFixture,
selectedRevisions: customPlotsFixture.colors?.domain
}

it('should return the expected data from the test fixture', () => {
const expectedOutput: CustomPlot[] = customPlotsFixture.plots.map(
getCustomPlotFromCustomPlotData
)
const data = collectCustomPlots(
customPlotsOrderFixture,
experimentsWithCheckpoints
)
const expectedOutput: CustomPlotData[] = customPlotsFixture.plots
const data = collectCustomPlots(defaultFuncArgs)
expect(data).toStrictEqual(expectedOutput)
})
})

describe('collectCustomPlotData', () => {
it('should return the expected data from test fixture', () => {
const expectedMetricVsParamPlotData = customPlotsFixture.plots[0]
const expectedCheckpointsPlotData = customPlotsFixture.plots[2]
const metricVsParamPlot = getCustomPlotFromCustomPlotData(
expectedMetricVsParamPlotData
)
const checkpointsPlot = getCustomPlotFromCustomPlotData(
expectedCheckpointsPlotData
it('should return only custom plots if there no selected revisions', () => {
const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter(
plot => plot.type !== CustomPlotType.CHECKPOINT
)
const data = collectCustomPlots({
...defaultFuncArgs,
selectedRevisions: undefined
})

const metricVsParamData = collectCustomPlotData(
metricVsParamPlot,
customPlotsFixture.colors,
customPlotsFixture.nbItemsPerRow,
customPlotsFixture.height
)
expect(data).toStrictEqual(expectedOutput)
})

const checkpointsData = collectCustomPlotData(
{
...checkpointsPlot,
values: [
...checkpointsPlot.values,
{
group: 'exp-123',
iteration: 1,
y: 1.4534177053451538
},
{
group: 'exp-123',
iteration: 2,
y: 1.757687
},
{
group: 'exp-123',
iteration: 3,
y: 1.989894
}
]
} as CheckpointPlot,
customPlotsFixture.colors,
customPlotsFixture.nbItemsPerRow,
customPlotsFixture.height
it('should return only custom plots if checkpoints are not enabled', () => {
const expectedOutput: CustomPlotData[] = customPlotsFixture.plots.filter(
plot => plot.type !== CustomPlotType.CHECKPOINT
)
const data = collectCustomPlots({
...defaultFuncArgs,
hasCheckpoints: false
})

expect(metricVsParamData).toStrictEqual(expectedMetricVsParamPlotData)
expect(checkpointsData).toStrictEqual(expectedCheckpointsPlotData)
expect(data).toStrictEqual(expectedOutput)
})

it('should return checkpoint plots with values only containing selected experiments data', () => {
const domain = customPlotsFixture.colors?.domain.slice(1) as string[]

const expectedOutput = customPlotsFixture.plots.map(plot => ({
...plot,
values: isCheckpointPlot(plot)
? plot.values.filter(value => domain.includes(value.group))
: plot.values
}))

const data = collectCustomPlots({
...defaultFuncArgs,
selectedRevisions: domain
})

expect(data).toStrictEqual(expectedOutput)
})
})

Expand Down
76 changes: 45 additions & 31 deletions extension/src/plots/model/collect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {
getFullValuePath,
CHECKPOINTS_PARAM,
CustomPlotsOrderValue,
isCheckpointPlot,
isCheckpointValue
} from './custom'
import { getRevisionFirstThreeColumns } from './util'
Expand All @@ -20,7 +19,6 @@ import {
TemplatePlotSection,
PlotsType,
Revision,
CustomPlot,
CustomPlotData,
MetricVsParamPlotValues
} from '../webview/contract'
Expand Down Expand Up @@ -126,10 +124,13 @@ const getMetricVsParamValues = (
return values
}

const getCustomPlot = (
const getCustomPlotData = (
orderValue: CustomPlotsOrderValue,
experiments: ExperimentWithCheckpoints[]
): CustomPlot => {
experiments: ExperimentWithCheckpoints[],
selectedRevisions: string[] | undefined = [],
height: number,
nbItemsPerRow: number
): CustomPlotData => {
const { metric, param, type } = orderValue
const metricPath = getFullValuePath(
ColumnType.METRICS,
Expand All @@ -139,46 +140,59 @@ const getCustomPlot = (

const paramPath = getFullValuePath(ColumnType.PARAMS, param, FILE_SEPARATOR)

const selectedExperiments = experiments.filter(({ name, label }) =>
selectedRevisions.includes(name || label)
)

const values = isCheckpointValue(type)
? getCheckpointValues(experiments, metricPath)
? getCheckpointValues(selectedExperiments, metricPath)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[F] This will be the best place to push the data into the spec.

: getMetricVsParamValues(experiments, metricPath, paramPath)

return {
id: getCustomPlotId(metric, param),
metric,
param,
type,
values
} as CustomPlot
values,
yTitle: truncateVerticalTitle(metric, nbItemsPerRow, height) as string
} as CustomPlotData
}

export const collectCustomPlots = (
plotsOrderValues: CustomPlotsOrderValue[],
export const collectCustomPlots = ({
plotsOrderValues,
experiments,
hasCheckpoints,
selectedRevisions,
height,
nbItemsPerRow
}: {
plotsOrderValues: CustomPlotsOrderValue[]
experiments: ExperimentWithCheckpoints[]
): CustomPlot[] => {
return plotsOrderValues.map(plotOrderValue =>
getCustomPlot(plotOrderValue, experiments)
)
}

export const collectCustomPlotData = (
plot: CustomPlot,
colors: ColorScale | undefined,
nbItemsPerRow: number,
hasCheckpoints: boolean
selectedRevisions: string[] | undefined
height: number
): CustomPlotData => {
const selectedExperiments = colors?.domain
const filteredValues = isCheckpointPlot(plot)
? plot.values.filter(value =>
(selectedExperiments as string[]).includes(value.group)
nbItemsPerRow: number
}): CustomPlotData[] => {
const plots = []
const shouldSkipCheckpointPlots = !hasCheckpoints || !selectedRevisions

for (const value of plotsOrderValues) {
if (shouldSkipCheckpointPlots && isCheckpointValue(value.type)) {
continue
}

plots.push(
getCustomPlotData(
value,
experiments,
selectedRevisions,
height,
nbItemsPerRow
)
: plot.values
)
}

return {
...plot,
values: filteredValues,
yTitle: truncateVerticalTitle(plot.metric, nbItemsPerRow, height) as string
} as CustomPlotData
return plots
}

type RevisionPathData = { [path: string]: Record<string, unknown>[] }
Expand Down
Loading