Skip to content

Commit

Permalink
override revisions when there is a selected running checkpoint experi…
Browse files Browse the repository at this point in the history
…ment
  • Loading branch information
mattseddon committed Nov 29, 2022
1 parent 57d4c94 commit f1e790b
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 91 deletions.
2 changes: 0 additions & 2 deletions extension/src/plots/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,8 @@ export class Plots extends BaseRepository<TPlotsData> {
this.paths.hasPaths() &&
definedAndNonEmpty(this.plots.getUnfetchedRevisions())
) {
this.webviewMessages.sendCheckpointPlotsMessage()
this.data.managedUpdate()
}
// needs to wait for update if checkpoint experiment is running <= fix this in the way described in #2831

return this.webviewMessages.sendWebviewMessage()
}
Expand Down
88 changes: 64 additions & 24 deletions extension/src/plots/model/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ import {
MultiSourceEncoding,
MultiSourceVariations
} from '../multiSource/collect'
import { isRunning } from '../../experiments/webview/contract'
import { SelectedExperimentWithColor } from '../../experiments/model'

export class PlotsModel extends ModelWithPersistence {
private readonly experiments: Experiments
Expand Down Expand Up @@ -107,11 +109,6 @@ export class PlotsModel extends ModelWithPersistence {
public async transformAndSetPlots(data: PlotsOutput, revs: string[]) {
const cliIdToLabel = this.getCLIIdToLabel()

this.fetchedRevs = new Set([
...this.fetchedRevs,
...revs.map(rev => cliIdToLabel[rev])
])

const [{ comparisonData, revisionData }, templates, multiSourceVariations] =
await Promise.all([
collectData(data, cliIdToLabel),
Expand Down Expand Up @@ -144,6 +141,11 @@ export class PlotsModel extends ModelWithPersistence {

this.setComparisonOrder()

this.fetchedRevs = new Set([
...this.fetchedRevs,
...revs.map(rev => cliIdToLabel[rev])
])

this.deferred.resolve()
}

Expand Down Expand Up @@ -176,6 +178,34 @@ export class PlotsModel extends ModelWithPersistence {
this.deleteRevisionData(id)
}

public getOverrideRevisionDetails() {
const mapping: { [label: string]: string } = {}

const selectedWithFetchedRunningCheckpointRevs = this.experiments
.getSelectedRevisions()
.map(exp => {
const { label, status, displayColor, id } = exp
if (isRunning(status) && !this.fetchedRevs.has(label)) {
const mostRecent =
this.experiments
.getCheckpoints(id)
?.find(({ label }) => this.fetchedRevs.has(label)) || exp
mapping[label] = mostRecent.label
return {
...mostRecent,
displayColor
} as SelectedExperimentWithColor
}
mapping[label] = label
return exp
})

return this.getSelectedRevisionDetails(
this.comparisonOrder.map(label => mapping[label]).filter(Boolean),
selectedWithFetchedRunningCheckpointRevs
)
}

public getUnfetchedRevisions() {
return this.getSelectedRevisions().filter(
revision => !this.fetchedRevs.has(revision)
Expand All @@ -188,41 +218,48 @@ export class PlotsModel extends ModelWithPersistence {
...Object.keys(this.revisionData)
])

return this.getSelectedRevisions()
.filter(label => !cachedRevisions.has(label))
.map(label => this.getCLIId(label))
return this.experiments
.getSelectedRevisions()
.filter(({ label }) => !cachedRevisions.has(label))
.map(({ label }) => this.getCLIId(label))
}

public getMutableRevisions() {
return this.experiments.getMutableRevisions()
}

public getRevisionColors() {
return getColorScale(this.getSelectedRevisionDetails())
public getRevisionColors(overrideRevs?: Revision[]) {
return getColorScale(overrideRevs || this.getSelectedRevisionDetails())
}

public getSelectedRevisionDetails() {
public getSelectedRevisionDetails(
overrideOrder?: string[],
overrideRevs?: SelectedExperimentWithColor[]
) {
return reorderObjectList<Revision>(
this.comparisonOrder,
this.experiments
.getSelectedRevisions()
.map(({ label, displayColor, logicalGroupName, id }) => ({
overrideOrder || this.comparisonOrder,
(overrideRevs || this.experiments.getSelectedRevisions()).map(
({ label, displayColor, logicalGroupName, id }) => ({
displayColor,
fetched: this.fetchedRevs.has(label),
group: logicalGroupName,
id,
revision: label
})),
})
),
'revision'
)
}

public getTemplatePlots(order: TemplateOrder | undefined) {
public getTemplatePlots(
order: TemplateOrder | undefined,
overrideRevs?: Revision[]
) {
if (!definedAndNonEmpty(order)) {
return
}

const selectedRevisions = this.getSelectedRevisions()
const selectedRevisions = overrideRevs || this.getSelectedRevisionDetails()

if (!definedAndNonEmpty(selectedRevisions)) {
return
Expand All @@ -231,12 +268,15 @@ export class PlotsModel extends ModelWithPersistence {
return this.getSelectedTemplatePlots(order, selectedRevisions)
}

public getComparisonPlots(paths: string[] | undefined) {
public getComparisonPlots(
paths: string[] | undefined,
overrideRevs?: string[]
) {
if (!paths) {
return
}

const selectedRevisions = this.getSelectedRevisions()
const selectedRevisions = overrideRevs || this.getSelectedRevisions()
if (!definedAndNonEmpty(selectedRevisions)) {
return
}
Expand Down Expand Up @@ -369,7 +409,7 @@ export class PlotsModel extends ModelWithPersistence {
private getCLIIdToLabel() {
const mapping: { [shortSha: string]: string } = {}

for (const rev of this.getSelectedRevisions()) {
for (const rev of this.experiments.getRevisions()) {
mapping[this.getCLIId(rev)] = rev
}

Expand Down Expand Up @@ -440,15 +480,15 @@ export class PlotsModel extends ModelWithPersistence {

private getSelectedTemplatePlots(
order: TemplateOrder,
selectedRevisions: string[]
selectedRevisions: Revision[]
) {
return collectSelectedTemplatePlots(
order,
selectedRevisions,
selectedRevisions.map(({ revision }) => revision),
this.templates,
this.revisionData,
this.getPlotSize(Section.TEMPLATE_PLOTS),
this.getRevisionColors(),
this.getRevisionColors(selectedRevisions),
this.multiSourceEncoding
)
}
Expand Down
30 changes: 17 additions & 13 deletions extension/src/plots/webview/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
ComparisonPlot,
ComparisonRevisionData,
PlotsData as TPlotsData,
Revision,
Section,
SectionCollapsed
} from './contract'
Expand Down Expand Up @@ -48,14 +49,17 @@ export class WebviewMessages {
}

public sendWebviewMessage() {
const selectedRevisions = this.plots.getOverrideRevisionDetails()
const selectedLabels = selectedRevisions.map(({ revision }) => revision)

this.getWebview()?.show({
checkpoint: this.getCheckpointPlots(),
comparison: this.getComparisonPlots(),
comparison: this.getComparisonPlots(selectedLabels),
hasPlots: !!this.paths?.hasPaths(),
hasSelectedPlots: definedAndNonEmpty(this.paths.getSelected()),
sectionCollapsed: this.plots.getSectionCollapsed(),
selectedRevisions: this.plots.getSelectedRevisionDetails(),
template: this.getTemplatePlots()
selectedRevisions,
template: this.getTemplatePlots(selectedRevisions)
})
}

Expand Down Expand Up @@ -213,14 +217,14 @@ export class WebviewMessages {

private sendSectionCollapsed() {
this.getWebview()?.show({
sectionCollapsed: this.plots?.getSectionCollapsed()
sectionCollapsed: this.plots.getSectionCollapsed()
})
}

private sendComparisonPlots() {
this.getWebview()?.show({
comparison: this.getComparisonPlots(),
selectedRevisions: this.plots?.getSelectedRevisionDetails()
selectedRevisions: this.plots.getSelectedRevisionDetails()
})
}

Expand All @@ -230,11 +234,11 @@ export class WebviewMessages {
})
}

private getTemplatePlots() {
const paths = this.paths?.getTemplateOrder()
const plots = this.plots?.getTemplatePlots(paths)
private getTemplatePlots(overrideRevs?: Revision[]) {
const paths = this.paths.getTemplateOrder()
const plots = this.plots.getTemplatePlots(paths, overrideRevs)

if (!this.plots || !plots || isEmpty(plots)) {
if (!plots || isEmpty(plots)) {
return null
}

Expand All @@ -244,10 +248,10 @@ export class WebviewMessages {
}
}

private getComparisonPlots() {
private getComparisonPlots(overrideRevs?: string[]) {
const paths = this.paths.getComparisonPaths()
const comparison = this.plots.getComparisonPlots(paths)
if (!this.plots || !comparison || isEmpty(comparison)) {
const comparison = this.plots.getComparisonPlots(paths, overrideRevs)
if (!comparison || isEmpty(comparison)) {
return null
}

Expand Down Expand Up @@ -285,6 +289,6 @@ export class WebviewMessages {
}

private getCheckpointPlots() {
return this.plots?.getCheckpointPlots() || null
return this.plots.getCheckpointPlots() || null
}
}
11 changes: 0 additions & 11 deletions extension/src/util/array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,6 @@ export const splitMatchedOrdered = <T>(values: T[], existingOrder: T[]) => {
export const reorderListSubset = <T>(subset: T[], supersetOrder: T[]): T[] =>
splitMatchedOrdered(subset, supersetOrder)[0]

export const performOrderedUpdate = (
order: string[],
items: { [key: string]: unknown }[],
key: string
): string[] => {
const current = reorderObjectList(order, items, key)
const added = items.filter(item => !order.includes(item[key] as string))

return [...current, ...added].map(item => item?.[key]) as string[]
}

export const performSimpleOrderedUpdate = (
order: string[],
items: string[]
Expand Down
32 changes: 0 additions & 32 deletions webview/src/plots/components/App.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1818,38 +1818,6 @@ describe('App', () => {
type: MessageFromWebviewType.REFRESH_REVISIONS
})
})

it('should not reorder the ribbon when comparison plots are reordered', () => {
renderAppWithOptionalData({
comparison: comparisonTableFixture,
selectedRevisions: plotsRevisionsFixture
})

const expectedRevisions = plotsRevisionsFixture.map(rev =>
rev.group ? rev.group.slice(1, -1) + rev.revision : rev.revision
)

expect(getDisplayedRevisionOrder()).toStrictEqual(expectedRevisions)

sendSetDataMessage({
comparison: comparisonTableFixture,
selectedRevisions: [
{
displayColor: '#f56565',
fetched: true,
group: undefined,
id: 'new-revision',
revision: 'new-revision'
},
...[...plotsRevisionsFixture].reverse()
]
})

expect(getDisplayedRevisionOrder()).toStrictEqual([
...expectedRevisions,
'new-revision'
])
})
})

describe('Vega panels', () => {
Expand Down
11 changes: 2 additions & 9 deletions webview/src/plots/components/ribbon/Ribbon.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import cx from 'classnames'
import { MessageFromWebviewType } from 'dvc/src/webview/contract'
import React, { useEffect, useState } from 'react'
import React from 'react'
import { useSelector } from 'react-redux'
import { performOrderedUpdate, reorderObjectList } from 'dvc/src/util/array'
import { useInView } from 'react-intersection-observer'
import styles from './styles.module.scss'
import { RibbonBlock } from './RibbonBlock'
Expand All @@ -23,12 +22,6 @@ export const Ribbon: React.FC = () => {
const revisions = useSelector(
(state: PlotsState) => state.webview.selectedRevisions
)
const [order, setOrder] = useState<string[]>([])
const reorderId = 'id'

useEffect(() => {
setOrder(pastOrder => performOrderedUpdate(pastOrder, revisions, reorderId))
}, [revisions])

const removeRevision = (revision: string) => {
sendMessage({
Expand Down Expand Up @@ -70,7 +63,7 @@ export const Ribbon: React.FC = () => {
appearance="secondary"
/>
</li>
{reorderObjectList(order, revisions, reorderId).map(revision => (
{revisions.map(revision => (
<RibbonBlock
revision={revision}
key={revision.revision}
Expand Down

0 comments on commit f1e790b

Please sign in to comment.