diff --git a/x-pack/plugins/ml/common/types/feature_importance.ts b/x-pack/plugins/ml/common/types/feature_importance.ts index 4f5619cf3ab7b..1ae4c7832390c 100644 --- a/x-pack/plugins/ml/common/types/feature_importance.ts +++ b/x-pack/plugins/ml/common/types/feature_importance.ts @@ -8,10 +8,13 @@ export interface ClassFeatureImportance { class_name: string | boolean; importance: number; } + +// TODO We should separate the interface because classes/importance +// isn't both optional but either/or. export interface FeatureImportance { feature_name: string; - importance?: number; classes?: ClassFeatureImportance[]; + importance?: number; } export interface TopClass { diff --git a/x-pack/plugins/ml/public/application/components/data_grid/common.test.ts b/x-pack/plugins/ml/public/application/components/data_grid/common.test.ts index 4bb670ad02dfc..aaf6f90b00f4d 100644 --- a/x-pack/plugins/ml/public/application/components/data_grid/common.test.ts +++ b/x-pack/plugins/ml/public/application/components/data_grid/common.test.ts @@ -8,7 +8,7 @@ import { EuiDataGridSorting } from '@elastic/eui'; import { multiColumnSortFactory } from './common'; -describe('Transform: Define Pivot Common', () => { +describe('Data Frame Analytics: Data Grid Common', () => { test('multiColumnSortFactory()', () => { const data = [ { s: 'a', n: 1 }, diff --git a/x-pack/plugins/ml/public/application/components/data_grid/common.ts b/x-pack/plugins/ml/public/application/components/data_grid/common.ts index 642d0ae564b85..48a0a0c9ab126 100644 --- a/x-pack/plugins/ml/public/application/components/data_grid/common.ts +++ b/x-pack/plugins/ml/public/application/components/data_grid/common.ts @@ -24,7 +24,9 @@ import { KBN_FIELD_TYPES, } from '../../../../../../../src/plugins/data/public'; +import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics'; import { extractErrorMessage } from '../../../../common/util/errors'; +import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance'; import { BASIC_NUMERICAL_TYPES, @@ -158,6 +160,90 @@ export const getDataGridSchemaFromKibanaFieldType = ( return schema; }; +const getClassName = (className: string, isClassTypeBoolean: boolean) => { + if (isClassTypeBoolean) { + return className === 'true'; + } + + return className; +}; +/** + * Helper to transform feature importance flattened fields with arrays back to object structure + * + * @param row - EUI data grid data row + * @param mlResultsField - Data frame analytics results field + * @returns nested object structure of feature importance values + */ +export const getFeatureImportance = ( + row: Record, + mlResultsField: string, + isClassTypeBoolean = false +): FeatureImportance[] => { + const featureNames: string[] | undefined = + row[`${mlResultsField}.feature_importance.feature_name`]; + const classNames: string[] | undefined = + row[`${mlResultsField}.feature_importance.classes.class_name`]; + const classImportance: number[] | undefined = + row[`${mlResultsField}.feature_importance.classes.importance`]; + + if (featureNames === undefined) { + return []; + } + + // return object structure for classification job + if (classNames !== undefined && classImportance !== undefined) { + const overallClassNames = classNames?.slice(0, classNames.length / featureNames.length); + + return featureNames.map((fName, index) => { + const offset = overallClassNames.length * index; + const featureClassImportance = classImportance.slice( + offset, + offset + overallClassNames.length + ); + return { + feature_name: fName, + classes: overallClassNames.map((fClassName, fIndex) => { + return { + class_name: getClassName(fClassName, isClassTypeBoolean), + importance: featureClassImportance[fIndex], + }; + }), + }; + }); + } + + // return object structure for regression job + const importance: number[] = row[`${mlResultsField}.feature_importance.importance`]; + return featureNames.map((fName, index) => ({ + feature_name: fName, + importance: importance[index], + })); +}; + +/** + * Helper to transforms top classes flattened fields with arrays back to object structure + * + * @param row - EUI data grid data row + * @param mlResultsField - Data frame analytics results field + * @returns nested object structure of feature importance values + */ +export const getTopClasses = (row: Record, mlResultsField: string): TopClasses => { + const classNames: string[] | undefined = row[`${mlResultsField}.top_classes.class_name`]; + const classProbabilities: number[] | undefined = + row[`${mlResultsField}.top_classes.class_probability`]; + const classScores: number[] | undefined = row[`${mlResultsField}.top_classes.class_score`]; + + if (classNames === undefined || classProbabilities === undefined || classScores === undefined) { + return []; + } + + return classNames.map((className, index) => ({ + class_name: className, + class_probability: classProbabilities[index], + class_score: classScores[index], + })); +}; + export const useRenderCellValue = ( indexPattern: IndexPattern | undefined, pagination: IndexPagination, @@ -207,6 +293,15 @@ export const useRenderCellValue = ( return item[cId]; } + // For classification and regression results, we need to treat some fields with a custom transform. + if (cId === `${resultsField}.feature_importance`) { + return getFeatureImportance(fullItem, resultsField ?? DEFAULT_RESULTS_FIELD); + } + + if (cId === `${resultsField}.top_classes`) { + return getTopClasses(fullItem, resultsField ?? DEFAULT_RESULTS_FIELD); + } + // Try if the field name is available as a nested field. return getNestedProperty(tableItems[adjustedRowIndex], cId, null); } diff --git a/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx b/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx index fad2439f5d5ee..50e9cabc99c35 100644 --- a/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx +++ b/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx @@ -27,10 +27,15 @@ import { DEFAULT_SAMPLER_SHARD_SIZE } from '../../../../common/constants/field_h import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/common'; -import { euiDataGridStyle, euiDataGridToolbarSettings } from './common'; +import { + euiDataGridStyle, + euiDataGridToolbarSettings, + getFeatureImportance, + getTopClasses, +} from './common'; import { UseIndexDataReturnType } from './types'; import { DecisionPathPopover } from './feature_importance/decision_path_popover'; -import { TopClasses } from '../../../../common/types/feature_importance'; +import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance'; import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics'; import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics'; @@ -118,18 +123,28 @@ export const DataGrid: FC = memo( if (!row) return
; // if resultsField for some reason is not available then use ml const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD; - const parsedFIArray = row[mlResultsField].feature_importance; let predictedValue: string | number | undefined; let topClasses: TopClasses = []; if ( predictionFieldName !== undefined && row && - row[mlResultsField][predictionFieldName] !== undefined + row[`${mlResultsField}.${predictionFieldName}`] !== undefined ) { - predictedValue = row[mlResultsField][predictionFieldName]; - topClasses = row[mlResultsField].top_classes; + predictedValue = row[`${mlResultsField}.${predictionFieldName}`]; + topClasses = getTopClasses(row, mlResultsField); } + const isClassTypeBoolean = topClasses.reduce( + (p, c) => typeof c.class_name === 'boolean' || p, + false + ); + + const parsedFIArray: FeatureImportance[] = getFeatureImportance( + row, + mlResultsField, + isClassTypeBoolean + ); + return ( !field.name.includes(`${resultsField}.${FEATURE_IMPORTANCE}.`) + ); } if ((numTopClasses ?? 0) > 0) { @@ -221,6 +225,10 @@ export const getDefaultFieldsFromJobCaps = ( name: `${resultsField}.${TOP_CLASSES}`, type: KBN_FIELD_TYPES.UNKNOWN, }); + // remove flattened top classes fields + fields = fields.filter( + (field: any) => !field.name.includes(`${resultsField}.${TOP_CLASSES}.`) + ); } // Only need to add these fields if we didn't use dest index pattern to get the fields diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/common/get_index_data.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/common/get_index_data.ts index 8e50aab0914db..85f222109d408 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/common/get_index_data.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/common/get_index_data.ts @@ -53,7 +53,7 @@ export const getIndexData = async ( index: jobConfig.dest.index, body: { fields: ['*'], - _source: [], + _source: false, query: searchQuery, from: pageIndex * pageSize, size: pageSize, diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/exploration_results_table.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/exploration_results_table.tsx index a6e95269b3633..10e2ad5b5eb53 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/exploration_results_table.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/exploration_results_table.tsx @@ -29,7 +29,7 @@ interface Props { } export const ExplorationResultsTable: FC = React.memo( - ({ indexPattern, jobConfig, jobStatus, needsDestIndexPattern, searchQuery }) => { + ({ indexPattern, jobConfig, needsDestIndexPattern, searchQuery }) => { const { services: { mlServices: { mlApiServices },