Skip to content

Commit

Permalink
remove duplicate code in explanation dashboard (#1266)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft authored Mar 15, 2022
1 parent 67209e5 commit f3279d3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 54 deletions.
2 changes: 1 addition & 1 deletion libs/core-ui/src/lib/util/JointDataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ export class JointDataset {
}
}

private static buildLocalFeatureMatrix(
public static buildLocalFeatureMatrix(
localExplanationRaw: number[][] | number[][][],
modelType: ModelTypes
): number[][][] {
Expand Down
58 changes: 5 additions & 53 deletions libs/interpret/src/lib/MLIDashboard/ExplanationDashboard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -108,27 +108,6 @@ export class ExplanationDashboard extends React.Component<
"ICE"
];

private static transposeLocalImportanceMatrix: (
input: number[][][]
) => number[][][] = memoize((input: number[][][]): number[][][] => {
const numClasses = input.length;
const numRows = input[0].length;
const numFeatures = input[0][0].length;
const result: number[][][] = new Array(numRows)
.fill(0)
.map(() =>
new Array(numFeatures).fill(0).map(() => new Array(numClasses).fill(0))
);
input.forEach((rowByFeature, classIndex) => {
rowByFeature.forEach((featureArray, rowIndex) => {
featureArray.forEach((value, featureIndex) => {
result[rowIndex][featureIndex][classIndex] = value;
});
});
});
return result;
});

private static buildWeightDropdownOptions: (
explanationContext: IExplanationContext
) => IDropdownOption[] = memoize(
Expand Down Expand Up @@ -345,7 +324,7 @@ export class ExplanationDashboard extends React.Component<
const weighting = props.predictedY
? WeightVectors.Predicted
: WeightVectors.AbsAvg;
const localFeatureMatrix = ExplanationDashboard.buildLocalFeatureMatrix(
const localFeatureMatrix = JointDataset.buildLocalFeatureMatrix(
props.precomputedExplanations.localFeatureImportance.scores,
modelMetadata.modelType
);
Expand Down Expand Up @@ -492,32 +471,6 @@ export class ExplanationDashboard extends React.Component<
}
}

private static buildLocalFeatureMatrix(
localExplanationRaw: number[][] | number[][][],
modelType: ModelTypes
): number[][][] {
switch (modelType) {
case ModelTypes.Regression: {
return (localExplanationRaw as number[][]).map((featureArray) =>
featureArray.map((val) => [val])
);
}
case ModelTypes.Binary: {
return ExplanationDashboard.transposeLocalImportanceMatrix(
localExplanationRaw as number[][][]
).map((featuresByClasses) =>
featuresByClasses.map((classArray) => classArray.slice(0, 1))
);
}
case ModelTypes.Multiclass:
default: {
return ExplanationDashboard.transposeLocalImportanceMatrix(
localExplanationRaw as number[][][]
);
}
}
}

private static buildLocalFlattenMatrix(
localExplanations: number[][][] | undefined,
modelType: ModelTypes,
Expand Down Expand Up @@ -999,11 +952,10 @@ export class ExplanationDashboard extends React.Component<
this.setState((prevState) => {
const weighting =
prevState.dashboardContext.weightContext.selectedKey;
const localFeatureMatrix =
ExplanationDashboard.buildLocalFeatureMatrix(
result,
modelMetadata.modelType
);
const localFeatureMatrix = JointDataset.buildLocalFeatureMatrix(
result,
modelMetadata.modelType
);
const flattenedFeatureMatrix =
ExplanationDashboard.buildLocalFlattenMatrix(
localFeatureMatrix,
Expand Down

0 comments on commit f3279d3

Please sign in to comment.