From 91017137edfb235837066c80deec4b3c96290b3d Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 18 Oct 2022 16:45:18 +1100 Subject: [PATCH] Add util for identifying ValueTree type --- extension/src/cli/dvc/contract.ts | 5 +++++ .../experiments/columns/collect/metricsAndParams.ts | 11 ++++------- extension/src/experiments/model/modify/collect.ts | 6 +++--- extension/src/plots/model/collect.ts | 5 +++-- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/extension/src/cli/dvc/contract.ts b/extension/src/cli/dvc/contract.ts index 1484ded5fb..8ea7601ed6 100644 --- a/extension/src/cli/dvc/contract.ts +++ b/extension/src/cli/dvc/contract.ts @@ -39,6 +39,11 @@ export interface ValueTreeNode { export type ValueTree = ValueTreeRoot | ValueTreeNode +export const isValueTree = ( + value: Value | ValueTree +): value is NonNullable => + !!(value && !Array.isArray(value) && typeof value === 'object') + export enum ExperimentStatus { FAILED = 'Failed', QUEUED = 'Queued', diff --git a/extension/src/experiments/columns/collect/metricsAndParams.ts b/extension/src/experiments/columns/collect/metricsAndParams.ts index 2bc86ee64b..a95a132737 100644 --- a/extension/src/experiments/columns/collect/metricsAndParams.ts +++ b/extension/src/experiments/columns/collect/metricsAndParams.ts @@ -9,9 +9,9 @@ import { import { ColumnType } from '../../webview/contract' import { ExperimentFields, + isValueTree, Value, ValueTree, - ValueTreeNode, ValueTreeOrError, ValueTreeRoot } from '../../../cli/dvc/contract' @@ -57,9 +57,6 @@ const collectMetricOrParam = ( ) } -const notLeaf = (value: ValueTreeNode): boolean => - value && !Array.isArray(value) && typeof value === 'object' - const walkValueTree = ( acc: ColumnAccumulator, type: ColumnType, @@ -67,7 +64,7 @@ const walkValueTree = ( ancestors: string[] = [] ) => { for (const [label, value] of Object.entries(tree)) { - if (notLeaf(value)) { + if (isValueTree(value)) { walkValueTree(acc, type, value, [...ancestors, label]) } else { collectMetricOrParam(acc, type, ancestors, label, value) @@ -110,8 +107,8 @@ const collectChange = ( commitData: ExperimentFields, ancestors: string[] = [] ) => { - if (value && !Array.isArray(value) && typeof value === 'object') { - for (const [childKey, childValue] of Object.entries(value as ValueTree)) { + if (isValueTree(value)) { + for (const [childKey, childValue] of Object.entries(value)) { collectChange(changes, type, file, childKey, childValue, commitData, [ ...ancestors, key diff --git a/extension/src/experiments/model/modify/collect.ts b/extension/src/experiments/model/modify/collect.ts index db9a094650..a6cad1796c 100644 --- a/extension/src/experiments/model/modify/collect.ts +++ b/extension/src/experiments/model/modify/collect.ts @@ -1,4 +1,4 @@ -import { Value, ValueTree } from '../../../cli/dvc/contract' +import { isValueTree, Value, ValueTree } from '../../../cli/dvc/contract' import { appendColumnToPath } from '../../columns/paths' import { MetricOrParamColumns } from '../../webview/contract' @@ -15,8 +15,8 @@ const collectFromParamsFile = ( ) => { const pathArray = [...ancestors, key].filter(Boolean) as string[] - if (!Array.isArray(value) && typeof value === 'object') { - for (const [childKey, childValue] of Object.entries(value as ValueTree)) { + if (isValueTree(value)) { + for (const [childKey, childValue] of Object.entries(value)) { collectFromParamsFile(acc, childKey, childValue, pathArray) } return diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts index 4a155c9967..f8df5e1265 100644 --- a/extension/src/plots/model/collect.ts +++ b/extension/src/plots/model/collect.ts @@ -19,6 +19,7 @@ import { ExperimentsBranchOutput, ExperimentsOutput, ExperimentStatus, + isValueTree, PlotsOutput, Value, ValueTree @@ -58,8 +59,8 @@ const collectFromMetricsFile = ( ) => { const pathArray = [...ancestors, key].filter(Boolean) as string[] - if (typeof value === 'object') { - for (const [childKey, childValue] of Object.entries(value as ValueTree)) { + if (isValueTree(value)) { + for (const [childKey, childValue] of Object.entries(value)) { collectFromMetricsFile( acc, name,