From db3cd4c5705b009c24d4b27d0925f860880258da Mon Sep 17 00:00:00 2001 From: Bo Zhang <71688188+zhb000@users.noreply.github.com> Date: Sat, 19 Mar 2022 11:24:55 -0700 Subject: [PATCH] Add ICE chart (#1283) * Fix * ice chart * ic * test * test --- .../describeSubLineChart.ts | 2 +- .../describer/modelAssessment/Constants.ts | 4 +- .../getIndividualChartOptions.ts | 2 +- .../lib/Highchart/FeatureImportanceBar.tsx | 1 + .../Highchart/getDefaultHighchartOptions.ts | 3 + .../GlobalExplanationTab.tsx | 3 +- .../Controls/MultiICEPlot/MultiICEPlot.tsx | 145 +++--------------- .../src/lib/MLIDashboard/utils/buildYAxis.ts | 20 +++ .../MLIDashboard/utils/getIceChartOption.ts | 138 +++++++++++++++++ .../src/lib/MLIDashboard/utils/mergeXYData.ts | 24 +++ 10 files changed, 212 insertions(+), 130 deletions(-) create mode 100644 libs/interpret/src/lib/MLIDashboard/utils/buildYAxis.ts create mode 100644 libs/interpret/src/lib/MLIDashboard/utils/getIceChartOption.ts create mode 100644 libs/interpret/src/lib/MLIDashboard/utils/mergeXYData.ts diff --git a/apps/dashboard-e2e/src/describer/interpret/individualFeatureImportance/describeSubLineChart.ts b/apps/dashboard-e2e/src/describer/interpret/individualFeatureImportance/describeSubLineChart.ts index 452ea68d9a..00a02d0418 100644 --- a/apps/dashboard-e2e/src/describer/interpret/individualFeatureImportance/describeSubLineChart.ts +++ b/apps/dashboard-e2e/src/describer/interpret/individualFeatureImportance/describeSubLineChart.ts @@ -20,7 +20,7 @@ export function describeSubLineChart(dataShape: IInterpretData): void { props.chart.clickNthPoint(1); }); it("should have more than one point", () => { - cy.get("#subPlotContainer svg g[class^='plot'] .points .point") + cy.get("#subPlotContainer svg g[class^='highcharts-series-group'] path") .its("length") .should("be.gte", 1); }); diff --git a/apps/widget-e2e/src/describer/modelAssessment/Constants.ts b/apps/widget-e2e/src/describer/modelAssessment/Constants.ts index 7097bdfb40..853b17d2aa 100644 --- a/apps/widget-e2e/src/describer/modelAssessment/Constants.ts +++ b/apps/widget-e2e/src/describer/modelAssessment/Constants.ts @@ -13,14 +13,14 @@ export enum Locators { IFIYAxisValue = '#FeatureImportanceBar div[class^="rotatedVerticalBox-"]', IFIXAxisValue = "#FeatureImportanceBar g.highcharts-xaxis-labels text", ICEPlot = '#subPlotChoice label:contains("ICE")', // ICE - Individual Conditional Expectation - ICENoOfPoints = "#subPlotContainer svg g[class^='plot'] .points .point", + ICENoOfPoints = "#subPlotContainer svg g[class^='highcharts-series-group'] path", IFITopFeaturesText = "div[class^='featureImportanceControls'] span[class^='sliderLabel']", IFITopFeaturesValue = "div[class^='featureImportanceControls'] div.ms-Slider-container div.ms-Slider-slideBox", IFIAbsoluteValuesToggleButton = "div[class^='featureImportanceLegend'] div.ms-Toggle", IFIDataPointDropdown = "div[class^='featureImportanceLegend'] div[role='listbox']", ICEFeatureDropdown = "div[class^='featureImportanceLegend'] div[class^='ms-ComboBox-container'] button[class^='ms-Button ms-Button--icon ms-ComboBox-CaretDown-button']", ICEFeatureDropdownOption = "div[class^='featureImportanceLegend'] div[class^='ms-ComboBox-container'] button:contains('workclass')", - ICEXAxisNewValue = "#subPlotContainer text[class^='xtitle']", + ICEXAxisNewValue = "#subPlotContainer text[class^='highcharts-axis-title']", ICEToolTipButton = "#subPlotContainer button:contains('How to read this chart')", ICECalloutTitle = "#subPlotContainer div.ms-Callout-container span[class^='calloutTitle']", ICECalloutBody = "#subPlotContainer div.ms-Callout-container div[class^='calloutInner']", diff --git a/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalIndividualView/getIndividualChartOptions.ts b/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalIndividualView/getIndividualChartOptions.ts index a372e0391d..1a5ce5a185 100644 --- a/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalIndividualView/getIndividualChartOptions.ts +++ b/libs/causality/src/lib/CausalAnalysisDashboard/Controls/CausalAnalysisView/CausalIndividualView/getIndividualChartOptions.ts @@ -35,7 +35,7 @@ export function getIndividualChartOptions( const series = data.map((d) => { return { data: d, - showInLegend: false + name: "" }; }); return { diff --git a/libs/core-ui/src/lib/Highchart/FeatureImportanceBar.tsx b/libs/core-ui/src/lib/Highchart/FeatureImportanceBar.tsx index 898de7cde2..21138785df 100644 --- a/libs/core-ui/src/lib/Highchart/FeatureImportanceBar.tsx +++ b/libs/core-ui/src/lib/Highchart/FeatureImportanceBar.tsx @@ -57,6 +57,7 @@ export class FeatureImportanceBar extends React.Component< if ( this.props.unsortedSeries !== prevProps.unsortedSeries || this.props.sortArray !== prevProps.sortArray || + this.props.topK !== prevProps.topK || this.props.chartType !== prevProps.chartType ) { this.setState({ diff --git a/libs/core-ui/src/lib/Highchart/getDefaultHighchartOptions.ts b/libs/core-ui/src/lib/Highchart/getDefaultHighchartOptions.ts index c9c3ed4176..975287142c 100644 --- a/libs/core-ui/src/lib/Highchart/getDefaultHighchartOptions.ts +++ b/libs/core-ui/src/lib/Highchart/getDefaultHighchartOptions.ts @@ -34,6 +34,9 @@ export function getDefaultHighchartOptions(theme: ITheme): Highcharts.Options { zoomType: "xy" }, credits: undefined, + legend: { + enabled: false + }, plotOptions: { area: { marker: { diff --git a/libs/interpret/src/lib/MLIDashboard/Controls/GlobalExplanationTab/GlobalExplanationTab.tsx b/libs/interpret/src/lib/MLIDashboard/Controls/GlobalExplanationTab/GlobalExplanationTab.tsx index a0f55f24c5..17f891cbc7 100644 --- a/libs/interpret/src/lib/MLIDashboard/Controls/GlobalExplanationTab/GlobalExplanationTab.tsx +++ b/libs/interpret/src/lib/MLIDashboard/Controls/GlobalExplanationTab/GlobalExplanationTab.tsx @@ -287,7 +287,7 @@ export class GlobalExplanationTab extends React.PureComponent< /> - + {featureOptions && ( )} {cohortOptions && ( diff --git a/libs/interpret/src/lib/MLIDashboard/Controls/MultiICEPlot/MultiICEPlot.tsx b/libs/interpret/src/lib/MLIDashboard/Controls/MultiICEPlot/MultiICEPlot.tsx index f70dc57f51..74d179ac0e 100644 --- a/libs/interpret/src/lib/MLIDashboard/Controls/MultiICEPlot/MultiICEPlot.tsx +++ b/libs/interpret/src/lib/MLIDashboard/Controls/MultiICEPlot/MultiICEPlot.tsx @@ -2,33 +2,25 @@ // Licensed under the MIT License. import { - isTwoDimArray, - ModelTypes, JointDataset, IExplanationModelMetadata, - ModelExplanationUtils, - FabricStyles + FabricStyles, + BasicHighChart } from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; -import { - IPlotlyProperty, - RangeTypes, - AccessibleChart, - PlotlyMode -} from "@responsible-ai/mlchartlib"; -import _, { toNumber, map } from "lodash"; +import { RangeTypes } from "@responsible-ai/mlchartlib"; +import _, { toNumber } from "lodash"; import { IComboBox, IComboBoxOption, ComboBox, SpinButton, - Text, - getTheme + Text } from "office-ui-fabric-react"; -import { Data } from "plotly.js"; import React from "react"; import { NoDataMessage } from "../../SharedComponents/NoDataMessage"; +import { getIceChartOption } from "../../utils/getIceChartOption"; import { IRangeView } from "../ICEPlot"; import { multiIcePlotStyles } from "./MultiICEPlot.styles"; @@ -70,101 +62,6 @@ export class MultiICEPlot extends React.PureComponent< this.debounceFetchData = _.debounce(this.fetchData.bind(this), 500); } - private static buildYAxis( - metadata: IExplanationModelMetadata, - selectedClass: number - ): string { - if (metadata.modelType === ModelTypes.Regression) { - return localization.Interpret.IcePlot.prediction; - } - return `${ - localization.Interpret.IcePlot.predictedProbability - }
${localization.formatString( - localization.Interpret.WhatIfTab.classLabel, - metadata.classNames[selectedClass] - )}`; - } - private static buildPlotlyProps( - metadata: IExplanationModelMetadata, - featureName: string, - selectedClass: number, - colors: string[], - rowNames: string[], - rangeType?: RangeTypes, - xData?: Array, - yData?: number[][] | number[][][] - ): IPlotlyProperty | undefined { - if ( - yData === undefined || - xData === undefined || - yData.length === 0 || - yData.some((row: number[] | number[][]) => row === undefined) - ) { - return undefined; - } - const data: Data[] = map( - yData, - (singleRow: number[] | number[][], rowIndex: number) => { - const transposedY: number[][] = isTwoDimArray(singleRow) - ? ModelExplanationUtils.transpose2DArray(singleRow) - : [singleRow]; - const predictionLabel = - metadata.modelType === ModelTypes.Regression - ? localization.Interpret.IcePlot.prediction - : `${localization.Interpret.IcePlot.predictedProbability}: ${metadata.classNames[selectedClass]}`; - const hovertemplate = `%{customdata.Name}
${featureName}: %{x}
${predictionLabel}: %{customdata.Yformatted}
`; - return { - customdata: transposedY[selectedClass].map((predY) => { - return { - Name: rowNames[rowIndex], - Yformatted: predY.toLocaleString(undefined, { - maximumFractionDigits: 3 - }) - }; - }), - hoverinfo: "all", - hovertemplate, - marker: { - color: colors[rowIndex] - }, - mode: - rangeType === RangeTypes.Categorical - ? PlotlyMode.Markers - : PlotlyMode.LinesMarkers, - name: rowNames[rowIndex], - type: "scatter", - x: xData, - y: transposedY[selectedClass] - }; - } - ) as any; - return { - config: { displaylogo: false, displayModeBar: false, responsive: true }, - data, - layout: { - autosize: true, - dragmode: false, - font: { - size: 10 - }, - hovermode: "closest", - margin: { - b: 30, - r: 10, - t: 10 - }, - showlegend: false, - xaxis: { - automargin: true, - title: featureName - }, - yaxis: { - automargin: true, - title: MultiICEPlot.buildYAxis(metadata, selectedClass) - } - } - }; - } public componentDidMount(): void { this.fetchData(); @@ -195,18 +92,18 @@ export class MultiICEPlot extends React.PureComponent< const hasOutgoingRequest = this.state.abortControllers.some( (x) => x !== undefined ); - const plotlyProps = this.state.rangeView - ? MultiICEPlot.buildPlotlyProps( - this.props.metadata, - this.props.jointDataset.metaDict[this.props.feature].label, - this.props.selectedClass, - this.props.colors, - this.props.rowNames, - this.state.rangeView.type, - this.state.xAxisArray, - this.state.yAxes - ) - : undefined; + const iceChartOption = + this.state.rangeView && + getIceChartOption( + this.props.metadata, + this.props.jointDataset.metaDict[this.props.feature].label, + this.props.selectedClass, + this.props.colors, + this.props.rowNames, + this.state.rangeView.type, + this.state.xAxisArray, + this.state.yAxes + ); const hasError = this.state.rangeView !== undefined && (this.state.rangeView.maxErrorMessage !== undefined || @@ -306,7 +203,7 @@ export class MultiICEPlot extends React.PureComponent< {this.state.errorMessage} )} - {plotlyProps === undefined && !hasOutgoingRequest && ( + {!iceChartOption && !hasOutgoingRequest && (
{localization.Interpret.IcePlot.submitPrompt}
@@ -316,9 +213,9 @@ export class MultiICEPlot extends React.PureComponent< {localization.Interpret.IcePlot.topLevelErrorMessage} )} - {plotlyProps !== undefined && !hasOutgoingRequest && !hasError && ( + {iceChartOption && !hasOutgoingRequest && !hasError && (
- +
)} diff --git a/libs/interpret/src/lib/MLIDashboard/utils/buildYAxis.ts b/libs/interpret/src/lib/MLIDashboard/utils/buildYAxis.ts new file mode 100644 index 0000000000..a1c9d1b14b --- /dev/null +++ b/libs/interpret/src/lib/MLIDashboard/utils/buildYAxis.ts @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { IExplanationModelMetadata, ModelTypes } from "@responsible-ai/core-ui"; +import { localization } from "@responsible-ai/localization"; + +export function buildYAxis( + metadata: IExplanationModelMetadata, + selectedClass: number +): string { + if (metadata.modelType === ModelTypes.Regression) { + return localization.Interpret.IcePlot.prediction; + } + return `${ + localization.Interpret.IcePlot.predictedProbability + }
${localization.formatString( + localization.Interpret.WhatIfTab.classLabel, + metadata.classNames[selectedClass] + )}`; +} diff --git a/libs/interpret/src/lib/MLIDashboard/utils/getIceChartOption.ts b/libs/interpret/src/lib/MLIDashboard/utils/getIceChartOption.ts new file mode 100644 index 0000000000..3d9a3d10d1 --- /dev/null +++ b/libs/interpret/src/lib/MLIDashboard/utils/getIceChartOption.ts @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { + IExplanationModelMetadata, + isTwoDimArray, + ModelExplanationUtils, + ModelTypes +} from "@responsible-ai/core-ui"; +import { localization } from "@responsible-ai/localization"; +import { RangeTypes } from "@responsible-ai/mlchartlib"; +import { map } from "lodash"; +import { Data } from "plotly.js"; + +import { buildYAxis } from "./buildYAxis"; +import { mergeXYData } from "./mergeXYData"; + +export interface IIceChartData { + x?: number; + y: number; + customData: any[]; +} +export function getIceChartOption( + metadata: IExplanationModelMetadata, + featureName: string, + selectedClass: number, + colors: string[], + rowNames: string[], + rangeType?: RangeTypes, + xData?: Array, + yData?: number[][] | number[][][] +): any { + if ( + yData === undefined || + xData === undefined || + yData.length === 0 || + yData.some((row: number[] | number[][]) => row === undefined) + ) { + return undefined; + } + const data: Data[] = map( + yData, + (singleRow: number[] | number[][], rowIndex: number) => { + const transposedY: number[][] = isTwoDimArray(singleRow) + ? ModelExplanationUtils.transpose2DArray(singleRow) + : [singleRow]; + const predictionLabel = + metadata.modelType === ModelTypes.Regression + ? localization.Interpret.IcePlot.prediction + : `${localization.Interpret.IcePlot.predictedProbability}: ${metadata.classNames[selectedClass]}`; + + const customData = transposedY[selectedClass].map((predY) => { + return { + Name: rowNames[rowIndex], + template: "", + Yformatted: predY.toLocaleString(undefined, { + maximumFractionDigits: 3 + }) + }; + }); + customData.forEach((c, index) => { + c.template = `${c.Name}
${featureName}: ${xData[index]}
${predictionLabel}: ${c.Yformatted}
`; + }); + return { + customdata: customData, + marker: { + color: colors[rowIndex] + }, + name: rowNames[rowIndex], + x: xData, + y: transposedY[selectedClass] + }; + } + ) as any; + const xAxisSetting = + rangeType === RangeTypes.Categorical + ? { categories: data[0]?.x, title: { text: featureName } } + : { + title: { + text: featureName + } + }; + const dataSeries: any = data.map((d) => { + return { + color: d.marker?.color, + data: mergeXYData( + d.x, + d.y, + d.customdata, + rangeType === RangeTypes.Categorical + ), + name: d.name + }; + }); + return { + chart: { + type: rangeType === RangeTypes.Categorical ? "scatter" : "" + }, + plotOptions: { + line: { + marker: { + states: { + hover: { + enabled: true + } + } + }, + tooltip: { + headerFormat: "", + pointFormat: `{point.customData.template}` + } + }, + scatter: { + marker: { + states: { + hover: { + enabled: true + } + } + }, + tooltip: { + headerFormat: "", + pointFormat: `{point.customData.template}` + } + } + }, + series: dataSeries, + title: { + text: "" + }, + xAxis: xAxisSetting, + yAxis: { + title: { + text: buildYAxis(metadata, selectedClass) + } + } + }; +} diff --git a/libs/interpret/src/lib/MLIDashboard/utils/mergeXYData.ts b/libs/interpret/src/lib/MLIDashboard/utils/mergeXYData.ts new file mode 100644 index 0000000000..c78023bbe9 --- /dev/null +++ b/libs/interpret/src/lib/MLIDashboard/utils/mergeXYData.ts @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { IIceChartData } from "./getIceChartOption"; + +export function mergeXYData( + xData: any, + yData: any, + customData: any, + isCategorical: boolean +): IIceChartData[] { + if (xData.length !== yData.length) { + return []; + } + const data: IIceChartData[] = []; + xData.forEach((x: any, index: number) => { + data.push( + isCategorical + ? { customData: customData[index], y: yData[index] } + : { customData: customData[index], x, y: yData[index] } + ); + }); + return data; +}