Skip to content

Commit

Permalink
Add options to show expanded node data summary on group node and labe…
Browse files Browse the repository at this point in the history
…l count columns in children stats table.

PiperOrigin-RevId: 704079793
  • Loading branch information
Google AI Edge authored and copybara-github committed Dec 8, 2024
1 parent d7232e5 commit cd61642
Show file tree
Hide file tree
Showing 13 changed files with 473 additions and 82 deletions.
12 changes: 12 additions & 0 deletions src/ui/src/components/visualizer/common/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ export const NODE_ATTRS_TABLE_VALUE_MAX_WIDTH = 200;
/** The height of attrs table row. */
export const NODE_ATTRS_TABLE_ROW_HEIGHT = 12;

/** The height of the summary row in node data provider. */
export const EXPANDED_NODE_DATA_PROVIDER_SUMMARY_ROW_HEIGHT = 14;

/** The top padding of the summary row in node data provider. */
export const EXPANDED_NODE_DATA_PROVIDER_SUMMARY_TOP_PADDING = 6;

/** The bottom padding of the summary row in node data provider. */
export const EXPANDED_NODE_DATA_PROVIDER_SUMMARY_BOTTOM_PADDING = 6;

/** The font size of the summary row in node data provider. */
export const EXPANDED_NODE_DATA_PROVIDER_SYUMMARY_FONT_SIZE = 9;

/** The maximum number of children nodes under a group node. */
export const DEFAULT_GROUP_NODE_CHILDREN_COUNT_THRESHOLD = IS_EXTERNAL
? 1000
Expand Down
28 changes: 28 additions & 0 deletions src/ui/src/components/visualizer/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,27 @@ export declare interface NodeDataProviderGraphData {
* The value for the hidden stat will be displayed as '-'.
*/
hideAggregatedStats?: AggregatedStat[];

/**
* Controls whether to display a detailed value distribution summary on the
* group node.
*
* By default, a color bar representing the value distribution of
* all descendant nodes is shown at the bottom of the group node. If this
* field is set to true, we will show a more detailed summary, with each
* value's label, percentage, and count shown on a separate line.
*
* For now this only works with non-numerical (e.g. string) node data values.
*/
showExpandedSummaryOnGroupNode?: boolean;

/**
* Whether to display the label count columns in the children stats table in
* the side panel.
*
* For now this only works with non-numerical (e.g. string) node data values.
*/
showLabelCountColumnsInChildrenStatsTable?: boolean;
}

/** The top level node data provider data, indexed by graph id. */
Expand Down Expand Up @@ -348,6 +369,13 @@ export declare interface NodeDataProviderRunData {
error?: string;
}

/** Info for a value in a node data provider run. */
export declare interface NodeDataProviderValueInfo {
label: string;
bgColor: string;
count: number;
}

/** The result data for a node in a node data provider run. */
export declare interface NodeDataProviderResultData {
/** The original value of the result. */
Expand Down
33 changes: 32 additions & 1 deletion src/ui/src/components/visualizer/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ import {
FieldLabel,
KeyValueList,
KeyValuePairs,
NodeDataProviderResultProcessedData,
NodeDataProviderRunData,
NodeDataProviderValueInfo,
NodeQuery,
NodeQueryType,
NodeStyleId,
Expand Down Expand Up @@ -546,7 +548,7 @@ export function getOpNodeDataProviderKeyValuePairsForAttrsTable(
runNames.includes(getRunName(run, {id: modelGraphId})),
);
for (const run of runs) {
const result = (run.results || {})?.[modelGraphId]?.[node.id];
const result = ((run.results || {})?.[modelGraphId] || {})[node.id];
if (config?.hideEmptyNodeDataEntries && !result) {
continue;
}
Expand Down Expand Up @@ -1079,3 +1081,32 @@ export function getRunName(
run.nodeDataProviderData?.[modelGraphIdLike?.id || '']?.name ?? run.runName
);
}

/** Generates the sorted value infos for the given group node. */
export function genSortedValueInfos(
groupNode: GroupNode | undefined,
modelGraph: ModelGraph,
results: Record<string, NodeDataProviderResultProcessedData>,
): NodeDataProviderValueInfo[] {
const bgColorToValueInfo: Record<string, NodeDataProviderValueInfo> = {};
const descendantsOpNodeIds =
groupNode?.descendantsOpNodeIds || modelGraph.nodes.map((node) => node.id);
for (const nodeId of descendantsOpNodeIds) {
const node = modelGraph.nodesById[nodeId];
const bgColor = results[node.id]?.bgColor || '';
if (bgColor) {
if (!bgColorToValueInfo[bgColor]) {
bgColorToValueInfo[bgColor] = {
label: `${results[nodeId]?.value || ''}`,
bgColor,
count: 1,
};
} else {
bgColorToValueInfo[bgColor].count++;
}
}
}
return Object.values(bgColorToValueInfo).sort((a, b) =>
a.bgColor.localeCompare(b.bgColor),
);
}
3 changes: 3 additions & 0 deletions src/ui/src/components/visualizer/common/worker_events.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export declare interface ExpandOrCollapseGroupNodeRequest
expand: boolean;
showOnNodeItemTypes: Record<string, ShowOnNodeItemData>;
nodeDataProviderRuns: Record<string, NodeDataProviderRunData>;
selectedNodeDataProviderRunId?: string;
rendererId: string;
paneId: string;
// Expand or collapse all groups under the selected group.
Expand Down Expand Up @@ -111,6 +112,7 @@ export declare interface RelayoutGraphRequest extends WorkerEventBase {
modelGraphId: string;
showOnNodeItemTypes: Record<string, ShowOnNodeItemData>;
nodeDataProviderRuns: Record<string, NodeDataProviderRunData>;
selectedNodeDataProviderRunId?: string;
targetDeepestGroupNodeIdsToExpand?: string[];
selectedNodeId: string;
rendererId: string;
Expand Down Expand Up @@ -142,6 +144,7 @@ export declare interface LocateNodeRequest extends WorkerEventBase {
modelGraphId: string;
showOnNodeItemTypes: Record<string, ShowOnNodeItemData>;
nodeDataProviderRuns: Record<string, NodeDataProviderRunData>;
selectedNodeDataProviderRunId?: string;
nodeId: string;
rendererId: string;
noNodeShake?: boolean;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
<div class="container">
<div class="index-container">
<div class="index-row" *ngFor="let runItem of runItems; let i = index; trackBy: trackByRunId"
[class.selected]="isRunItemSelected(runItem)">
[class.selected]="isRunItemSelected(runItem)"
(click)="handleClickToggleVisibility(runItem, $event)">
<div class="index-number-container">
<div class="index-number" *ngIf="runItem.done">{{i + 1}}</div>
<mat-spinner color="primary" diameter="16" *ngIf="!runItem.done">
Expand Down Expand Up @@ -137,7 +138,7 @@
(click)="handleClickChildrenStatsHeader(col.colIndex)">
<div class="header-content">
<div class="index-number">{{col.runIndex + 1}}</div>
<div class="stat-label">{{col.label}}</div>
<div class="stat-label" [class.multi-line]="col.multiLineHeader">{{col.label}}</div>
<mat-icon *ngIf="col.colIndex === curChildrenStatSortingColIndex" class="sort">
{{curChildrenStatSortingDirection === 'asc' ? 'arrow_upward' : 'arrow_downward'}}
</mat-icon>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
align-items: center;
overflow: hidden;
padding: 2px 8px;
cursor: pointer;

&.selected {
background-color: #fff2d5;
Expand Down Expand Up @@ -197,7 +198,7 @@
display: flex;
flex-direction: column;
transition: max-height 150ms ease-out;
overflow: hidden;
overflow-y: clip;

&.collapsed {
/* stylelint-disable-next-line declaration-no-important -- override element style */
Expand Down Expand Up @@ -289,6 +290,10 @@
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;

&.multi-line {
white-space: pre;
}
}
}

Expand Down
128 changes: 102 additions & 26 deletions src/ui/src/components/visualizer/node_data_provider_summary_panel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,17 @@ import {debounceTime} from 'rxjs/operators';
import {AppService} from './app_service';
import {NODE_DATA_PROVIDER_SHOW_ON_NODE_TYPE_PREFIX} from './common/consts';
import {GroupNode, ModelGraph, OpNode} from './common/model_graph';
import {AggregatedStat, NodeDataProviderRunData} from './common/types';
import {getRunName, isGroupNode, isOpNode} from './common/utils';
import {
AggregatedStat,
NodeDataProviderRunData,
NodeDataProviderValueInfo,
} from './common/types';
import {
genSortedValueInfos,
getRunName,
isGroupNode,
isOpNode,
} from './common/utils';
import {InfoPanelService, SortingDirection} from './info_panel_service';
import {NodeDataProviderExtensionService} from './node_data_provider_extension_service';
import {Paginator} from './paginator';
Expand Down Expand Up @@ -98,6 +107,7 @@ interface ChildrenStatsCol {
runIndex: number;
label: string;
hideInChildrenStatsTable?: boolean;
multiLineHeader?: boolean;
}

const CHILDREN_STATS = ['Sum %'];
Expand Down Expand Up @@ -716,15 +726,45 @@ export class NodeDataProviderSummaryPanel implements OnChanges {
// Generate children stats columns.
this.childrenStatsCols = [];
let childrenStatColIndex = 0;
const groupNode = this.curModelGraph.nodesById[
this.rootGroupNodeId ?? ''
] as GroupNode;
const runIdToValueInfos: Record<string, NodeDataProviderValueInfo[]> = {};
for (let i = 0; i < runs.length; i++) {
for (const childrenStat of CHILDREN_STATS) {
const run = runs[i];
let childrenStats = CHILDREN_STATS;
let valueInfos: NodeDataProviderValueInfo[] = [];
let multiLineHeader = false;
if (
(run.nodeDataProviderData ?? {})[this.curModelGraph.id]
?.showLabelCountColumnsInChildrenStatsTable
) {
valueInfos = genSortedValueInfos(
groupNode,
this.curModelGraph,
(run.results ?? {})[this.curModelGraph.id],
).sort((a, b) => a.label.localeCompare(b.label));
runIdToValueInfos[run.runId] = valueInfos;
childrenStats = valueInfos.map((valueInfo) => `#${valueInfo.label}`);
multiLineHeader = true;
}
for (const childrenStat of childrenStats) {
let label = childrenStat;
if (runs.length > 1) {
if (multiLineHeader) {
label = `${this.getRunName(runs[i])}\n${childrenStat}`;
} else {
label = `${this.getRunName(runs[i])}${childrenStat}`;
}
}
this.childrenStatsCols.push({
colIndex: childrenStatColIndex,
runIndex: i,
label: `${this.getRunName(runs[i])}${childrenStat}`,
label,
hideInChildrenStatsTable:
runs[i].nodeDataProviderData?.[this.curModelGraph.id]
?.hideInChildrenStatsTable,
multiLineHeader,
});
childrenStatColIndex++;
}
Expand All @@ -746,36 +786,72 @@ export class NodeDataProviderSummaryPanel implements OnChanges {
const run = runs[runIndex];
const curResults = run.results || {};
// Sum pct.
let sumPct = 0;
let hasValue = false;
if (isOpNode(node)) {
const nodeResult = (curResults[this.curModelGraph.id] || {})[nodeId];
const value = nodeResult?.value;
if (value != null && typeof value === 'number') {
sumPct = (value / stats[runIndex].sum) * 100;
hasValue = true;
}
} else if (isGroupNode(node)) {
let layerSum = 0;
const childrenIds = node.descendantsOpNodeIds || [];
for (const childNodeId of childrenIds) {
if (!runIdToValueInfos[run.runId]) {
let sumPct = 0;
let hasValue = false;
if (isOpNode(node)) {
const nodeResult = (curResults[this.curModelGraph.id] || {})[
childNodeId
nodeId
];
const value = nodeResult?.value;
if (value != null && typeof value === 'number') {
layerSum += value;
sumPct = (value / stats[runIndex].sum) * 100;
hasValue = true;
}
} else if (isGroupNode(node)) {
let layerSum = 0;
const childrenIds = node.descendantsOpNodeIds || [];
for (const childNodeId of childrenIds) {
const nodeResult = (curResults[this.curModelGraph.id] || {})[
childNodeId
];
const value = nodeResult?.value;
if (value != null && typeof value === 'number') {
layerSum += value;
hasValue = true;
}
}
sumPct = (layerSum / stats[runIndex].sum) * 100;
}
colValues.push(sumPct);
colStrs.push(hasValue ? sumPct.toFixed(1) : '-');
colHidden.push(
run.nodeDataProviderData?.[this.curModelGraph.id]
?.hideInChildrenStatsTable === true,
);
}
// Label counts.
else {
const valueInfos = runIdToValueInfos[run.runId];
const curResults = run.results || {};
const nodeResult = (curResults[this.curModelGraph.id] || {})[nodeId];
const value = nodeResult?.value || '';
for (const valueInfo of valueInfos) {
let count = 0;
if (isOpNode(node)) {
if (valueInfo.label === value) {
count = 1;
}
} else if (isGroupNode(node)) {
const childrenIds = node.descendantsOpNodeIds || [];
for (const childNodeId of childrenIds) {
const nodeResult = (curResults[this.curModelGraph.id] || {})[
childNodeId
];
const childValue = nodeResult?.value || '';
if (childValue === valueInfo.label) {
count++;
}
}
}
colValues.push(count);
colStrs.push(`${count}`);
colHidden.push(
run.nodeDataProviderData?.[this.curModelGraph.id]
?.hideInChildrenStatsTable === true,
);
}
sumPct = (layerSum / stats[runIndex].sum) * 100;
}
colValues.push(sumPct);
colStrs.push(hasValue ? sumPct.toFixed(1) : '-');
colHidden.push(
run.nodeDataProviderData?.[this.curModelGraph.id]
?.hideInChildrenStatsTable === true,
);
}
this.curChildrenStatRows.push({
id: nodeId,
Expand Down
Loading

0 comments on commit cd61642

Please sign in to comment.