From dc9770634d2f047ba6ce8ad7b435b905e641cfb1 Mon Sep 17 00:00:00 2001
From: mattseddon <37993418+mattseddon@users.noreply.github.com>
Date: Wed, 19 Oct 2022 04:20:54 +1100
Subject: [PATCH] Add util for identifying ValueTree type (#2619)

---
 extension/src/cli/dvc/contract.ts                     |  5 +++++
 .../experiments/columns/collect/metricsAndParams.ts   | 11 ++++-------
 extension/src/experiments/model/modify/collect.ts     |  6 +++---
 extension/src/plots/model/collect.ts                  |  5 +++--
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/extension/src/cli/dvc/contract.ts b/extension/src/cli/dvc/contract.ts
index 1484ded5fb..8ea7601ed6 100644
--- a/extension/src/cli/dvc/contract.ts
+++ b/extension/src/cli/dvc/contract.ts
@@ -39,6 +39,11 @@ export interface ValueTreeNode {
 
 export type ValueTree = ValueTreeRoot | ValueTreeNode
 
+export const isValueTree = (
+  value: Value | ValueTree
+): value is NonNullable<ValueTree> =>
+  !!(value && !Array.isArray(value) && typeof value === 'object')
+
 export enum ExperimentStatus {
   FAILED = 'Failed',
   QUEUED = 'Queued',
diff --git a/extension/src/experiments/columns/collect/metricsAndParams.ts b/extension/src/experiments/columns/collect/metricsAndParams.ts
index 2bc86ee64b..a95a132737 100644
--- a/extension/src/experiments/columns/collect/metricsAndParams.ts
+++ b/extension/src/experiments/columns/collect/metricsAndParams.ts
@@ -9,9 +9,9 @@ import {
 import { ColumnType } from '../../webview/contract'
 import {
   ExperimentFields,
+  isValueTree,
   Value,
   ValueTree,
-  ValueTreeNode,
   ValueTreeOrError,
   ValueTreeRoot
 } from '../../../cli/dvc/contract'
@@ -57,9 +57,6 @@ const collectMetricOrParam = (
   )
 }
 
-const notLeaf = (value: ValueTreeNode): boolean =>
-  value && !Array.isArray(value) && typeof value === 'object'
-
 const walkValueTree = (
   acc: ColumnAccumulator,
   type: ColumnType,
@@ -67,7 +64,7 @@ const walkValueTree = (
   ancestors: string[] = []
 ) => {
   for (const [label, value] of Object.entries(tree)) {
-    if (notLeaf(value)) {
+    if (isValueTree(value)) {
       walkValueTree(acc, type, value, [...ancestors, label])
     } else {
       collectMetricOrParam(acc, type, ancestors, label, value)
@@ -110,8 +107,8 @@ const collectChange = (
   commitData: ExperimentFields,
   ancestors: string[] = []
 ) => {
-  if (value && !Array.isArray(value) && typeof value === 'object') {
-    for (const [childKey, childValue] of Object.entries(value as ValueTree)) {
+  if (isValueTree(value)) {
+    for (const [childKey, childValue] of Object.entries(value)) {
       collectChange(changes, type, file, childKey, childValue, commitData, [
         ...ancestors,
         key
diff --git a/extension/src/experiments/model/modify/collect.ts b/extension/src/experiments/model/modify/collect.ts
index db9a094650..a6cad1796c 100644
--- a/extension/src/experiments/model/modify/collect.ts
+++ b/extension/src/experiments/model/modify/collect.ts
@@ -1,4 +1,4 @@
-import { Value, ValueTree } from '../../../cli/dvc/contract'
+import { isValueTree, Value, ValueTree } from '../../../cli/dvc/contract'
 import { appendColumnToPath } from '../../columns/paths'
 import { MetricOrParamColumns } from '../../webview/contract'
 
@@ -15,8 +15,8 @@ const collectFromParamsFile = (
 ) => {
   const pathArray = [...ancestors, key].filter(Boolean) as string[]
 
-  if (!Array.isArray(value) && typeof value === 'object') {
-    for (const [childKey, childValue] of Object.entries(value as ValueTree)) {
+  if (isValueTree(value)) {
+    for (const [childKey, childValue] of Object.entries(value)) {
       collectFromParamsFile(acc, childKey, childValue, pathArray)
     }
     return
diff --git a/extension/src/plots/model/collect.ts b/extension/src/plots/model/collect.ts
index 4a155c9967..f8df5e1265 100644
--- a/extension/src/plots/model/collect.ts
+++ b/extension/src/plots/model/collect.ts
@@ -19,6 +19,7 @@ import {
   ExperimentsBranchOutput,
   ExperimentsOutput,
   ExperimentStatus,
+  isValueTree,
   PlotsOutput,
   Value,
   ValueTree
@@ -58,8 +59,8 @@ const collectFromMetricsFile = (
 ) => {
   const pathArray = [...ancestors, key].filter(Boolean) as string[]
 
-  if (typeof value === 'object') {
-    for (const [childKey, childValue] of Object.entries(value as ValueTree)) {
+  if (isValueTree(value)) {
+    for (const [childKey, childValue] of Object.entries(value)) {
       collectFromMetricsFile(
         acc,
         name,