Skip to content

Commit

Permalink
Add ICE chart (#1283)
Browse files Browse the repository at this point in the history
* Fix

* ice chart

* ic

* test

* test
  • Loading branch information
zhb000 authored Mar 19, 2022
1 parent 2f550f5 commit db3cd4c
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export function describeSubLineChart(dataShape: IInterpretData): void {
props.chart.clickNthPoint(1);
});
it("should have more than one point", () => {
cy.get("#subPlotContainer svg g[class^='plot'] .points .point")
cy.get("#subPlotContainer svg g[class^='highcharts-series-group'] path")
.its("length")
.should("be.gte", 1);
});
Expand Down
4 changes: 2 additions & 2 deletions apps/widget-e2e/src/describer/modelAssessment/Constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ export enum Locators {
IFIYAxisValue = '#FeatureImportanceBar div[class^="rotatedVerticalBox-"]',
IFIXAxisValue = "#FeatureImportanceBar g.highcharts-xaxis-labels text",
ICEPlot = '#subPlotChoice label:contains("ICE")', // ICE - Individual Conditional Expectation
ICENoOfPoints = "#subPlotContainer svg g[class^='plot'] .points .point",
ICENoOfPoints = "#subPlotContainer svg g[class^='highcharts-series-group'] path",
IFITopFeaturesText = "div[class^='featureImportanceControls'] span[class^='sliderLabel']",
IFITopFeaturesValue = "div[class^='featureImportanceControls'] div.ms-Slider-container div.ms-Slider-slideBox",
IFIAbsoluteValuesToggleButton = "div[class^='featureImportanceLegend'] div.ms-Toggle",
IFIDataPointDropdown = "div[class^='featureImportanceLegend'] div[role='listbox']",
ICEFeatureDropdown = "div[class^='featureImportanceLegend'] div[class^='ms-ComboBox-container'] button[class^='ms-Button ms-Button--icon ms-ComboBox-CaretDown-button']",
ICEFeatureDropdownOption = "div[class^='featureImportanceLegend'] div[class^='ms-ComboBox-container'] button:contains('workclass')",
ICEXAxisNewValue = "#subPlotContainer text[class^='xtitle']",
ICEXAxisNewValue = "#subPlotContainer text[class^='highcharts-axis-title']",
ICEToolTipButton = "#subPlotContainer button:contains('How to read this chart')",
ICECalloutTitle = "#subPlotContainer div.ms-Callout-container span[class^='calloutTitle']",
ICECalloutBody = "#subPlotContainer div.ms-Callout-container div[class^='calloutInner']",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export function getIndividualChartOptions(
const series = data.map((d) => {
return {
data: d,
showInLegend: false
name: ""
};
});
return {
Expand Down
1 change: 1 addition & 0 deletions libs/core-ui/src/lib/Highchart/FeatureImportanceBar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ export class FeatureImportanceBar extends React.Component<
if (
this.props.unsortedSeries !== prevProps.unsortedSeries ||
this.props.sortArray !== prevProps.sortArray ||
this.props.topK !== prevProps.topK ||
this.props.chartType !== prevProps.chartType
) {
this.setState({
Expand Down
3 changes: 3 additions & 0 deletions libs/core-ui/src/lib/Highchart/getDefaultHighchartOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ export function getDefaultHighchartOptions(theme: ITheme): Highcharts.Options {
zoomType: "xy"
},
credits: undefined,
legend: {
enabled: false
},
plotOptions: {
area: {
marker: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ export class GlobalExplanationTab extends React.PureComponent<
/>
</div>
</Stack.Item>
<Stack.Item className={classNames.chartRightPart}>
<Stack.Item className={classNames.legendAndSort}>
{featureOptions && (
<ComboBox
id="DependencePlotFeatureSelection"
Expand All @@ -304,7 +304,6 @@ export class GlobalExplanationTab extends React.PureComponent<
selectedKey={this.state.dependenceProps?.xAxis.property}
onChange={this.onXSet}
calloutProps={FabricStyles.calloutProps}
styles={FabricStyles.defaultDropdownStyle}
/>
)}
{cohortOptions && (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,25 @@
// Licensed under the MIT License.

import {
isTwoDimArray,
ModelTypes,
JointDataset,
IExplanationModelMetadata,
ModelExplanationUtils,
FabricStyles
FabricStyles,
BasicHighChart
} from "@responsible-ai/core-ui";
import { localization } from "@responsible-ai/localization";
import {
IPlotlyProperty,
RangeTypes,
AccessibleChart,
PlotlyMode
} from "@responsible-ai/mlchartlib";
import _, { toNumber, map } from "lodash";
import { RangeTypes } from "@responsible-ai/mlchartlib";
import _, { toNumber } from "lodash";
import {
IComboBox,
IComboBoxOption,
ComboBox,
SpinButton,
Text,
getTheme
Text
} from "office-ui-fabric-react";
import { Data } from "plotly.js";
import React from "react";

import { NoDataMessage } from "../../SharedComponents/NoDataMessage";
import { getIceChartOption } from "../../utils/getIceChartOption";
import { IRangeView } from "../ICEPlot";

import { multiIcePlotStyles } from "./MultiICEPlot.styles";
Expand Down Expand Up @@ -70,101 +62,6 @@ export class MultiICEPlot extends React.PureComponent<

this.debounceFetchData = _.debounce(this.fetchData.bind(this), 500);
}
private static buildYAxis(
metadata: IExplanationModelMetadata,
selectedClass: number
): string {
if (metadata.modelType === ModelTypes.Regression) {
return localization.Interpret.IcePlot.prediction;
}
return `${
localization.Interpret.IcePlot.predictedProbability
}<br>${localization.formatString(
localization.Interpret.WhatIfTab.classLabel,
metadata.classNames[selectedClass]
)}`;
}
private static buildPlotlyProps(
metadata: IExplanationModelMetadata,
featureName: string,
selectedClass: number,
colors: string[],
rowNames: string[],
rangeType?: RangeTypes,
xData?: Array<number | string>,
yData?: number[][] | number[][][]
): IPlotlyProperty | undefined {
if (
yData === undefined ||
xData === undefined ||
yData.length === 0 ||
yData.some((row: number[] | number[][]) => row === undefined)
) {
return undefined;
}
const data: Data[] = map<number[] | number[][]>(
yData,
(singleRow: number[] | number[][], rowIndex: number) => {
const transposedY: number[][] = isTwoDimArray(singleRow)
? ModelExplanationUtils.transpose2DArray(singleRow)
: [singleRow];
const predictionLabel =
metadata.modelType === ModelTypes.Regression
? localization.Interpret.IcePlot.prediction
: `${localization.Interpret.IcePlot.predictedProbability}: ${metadata.classNames[selectedClass]}`;
const hovertemplate = `%{customdata.Name}<br>${featureName}: %{x}<br>${predictionLabel}: %{customdata.Yformatted}<br><extra></extra>`;
return {
customdata: transposedY[selectedClass].map((predY) => {
return {
Name: rowNames[rowIndex],
Yformatted: predY.toLocaleString(undefined, {
maximumFractionDigits: 3
})
};
}),
hoverinfo: "all",
hovertemplate,
marker: {
color: colors[rowIndex]
},
mode:
rangeType === RangeTypes.Categorical
? PlotlyMode.Markers
: PlotlyMode.LinesMarkers,
name: rowNames[rowIndex],
type: "scatter",
x: xData,
y: transposedY[selectedClass]
};
}
) as any;
return {
config: { displaylogo: false, displayModeBar: false, responsive: true },
data,
layout: {
autosize: true,
dragmode: false,
font: {
size: 10
},
hovermode: "closest",
margin: {
b: 30,
r: 10,
t: 10
},
showlegend: false,
xaxis: {
automargin: true,
title: featureName
},
yaxis: {
automargin: true,
title: MultiICEPlot.buildYAxis(metadata, selectedClass)
}
}
};
}

public componentDidMount(): void {
this.fetchData();
Expand Down Expand Up @@ -195,18 +92,18 @@ export class MultiICEPlot extends React.PureComponent<
const hasOutgoingRequest = this.state.abortControllers.some(
(x) => x !== undefined
);
const plotlyProps = this.state.rangeView
? MultiICEPlot.buildPlotlyProps(
this.props.metadata,
this.props.jointDataset.metaDict[this.props.feature].label,
this.props.selectedClass,
this.props.colors,
this.props.rowNames,
this.state.rangeView.type,
this.state.xAxisArray,
this.state.yAxes
)
: undefined;
const iceChartOption =
this.state.rangeView &&
getIceChartOption(
this.props.metadata,
this.props.jointDataset.metaDict[this.props.feature].label,
this.props.selectedClass,
this.props.colors,
this.props.rowNames,
this.state.rangeView.type,
this.state.xAxisArray,
this.state.yAxes
);
const hasError =
this.state.rangeView !== undefined &&
(this.state.rangeView.maxErrorMessage !== undefined ||
Expand Down Expand Up @@ -306,7 +203,7 @@ export class MultiICEPlot extends React.PureComponent<
<Text>{this.state.errorMessage}</Text>
</div>
)}
{plotlyProps === undefined && !hasOutgoingRequest && (
{!iceChartOption && !hasOutgoingRequest && (
<div className={classNames.placeholder}>
<Text>{localization.Interpret.IcePlot.submitPrompt}</Text>
</div>
Expand All @@ -316,9 +213,9 @@ export class MultiICEPlot extends React.PureComponent<
<Text>{localization.Interpret.IcePlot.topLevelErrorMessage}</Text>
</div>
)}
{plotlyProps !== undefined && !hasOutgoingRequest && !hasError && (
{iceChartOption && !hasOutgoingRequest && !hasError && (
<div className={classNames.chartWrapper}>
<AccessibleChart plotlyProps={plotlyProps} theme={getTheme()} />
<BasicHighChart configOverride={iceChartOption} />
</div>
)}
</div>
Expand Down
20 changes: 20 additions & 0 deletions libs/interpret/src/lib/MLIDashboard/utils/buildYAxis.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { IExplanationModelMetadata, ModelTypes } from "@responsible-ai/core-ui";
import { localization } from "@responsible-ai/localization";

export function buildYAxis(
metadata: IExplanationModelMetadata,
selectedClass: number
): string {
if (metadata.modelType === ModelTypes.Regression) {
return localization.Interpret.IcePlot.prediction;
}
return `${
localization.Interpret.IcePlot.predictedProbability
}<br>${localization.formatString(
localization.Interpret.WhatIfTab.classLabel,
metadata.classNames[selectedClass]
)}`;
}
Loading

0 comments on commit db3cd4c

Please sign in to comment.