From 656f9e588c71624a60a1bf81736fbbcb89aeecd5 Mon Sep 17 00:00:00 2001 From: Alex Arvanitidis Date: Thu, 7 Nov 2024 12:35:15 +0200 Subject: [PATCH] fix: confusion matrix for binary classification --- src/app/api.schema.d.ts | 25 +++++------------ .../scores/BinaryClassificationScoreCard.tsx | 12 ++++++++- .../components/scores/ConfusionMatrix.tsx | 27 ++++++++++--------- .../MulticlassClassificationScoreCard.tsx | 2 +- 4 files changed, 32 insertions(+), 34 deletions(-) diff --git a/src/app/api.schema.d.ts b/src/app/api.schema.d.ts index 5a2260f..21f71b7 100644 --- a/src/app/api.schema.d.ts +++ b/src/app/api.schema.d.ts @@ -241,7 +241,6 @@ export interface components { test?: components["schemas"]["Scores"][]; crossValidation?: components["schemas"]["Scores"][]; }; - extraConfig?: components["schemas"]["ModelExtraConfig"]; /** * Format: date-time * @description The date and time when the feature was created. @@ -291,6 +290,7 @@ export interface components { }; RegressionScores: { yName: string; + folds?: number; /** Format: float */ r2?: number; /** Format: float */ @@ -311,6 +311,7 @@ export interface components { BinaryClassificationScores: { labels?: string[]; yName: string; + folds?: number; /** Format: float */ accuracy?: number; /** Format: float */ @@ -326,6 +327,7 @@ export interface components { MulticlassClassificationScores: { labels?: string[]; yName: string; + folds?: number; /** Format: float */ accuracy?: number; /** Format: float */ @@ -349,14 +351,6 @@ export interface components { }; /** @enum {string} */ ModelType: "SKLEARN" | "TORCH_ONNX" | "TORCHSCRIPT" | "R_BNLEARN_DISCRETE" | "R_CARET" | "R_GBM" | "R_NAIVE_BAYES" | "R_PBPK" | "R_RF" | "R_RPART" | "R_SVM" | "R_TREE_CLASS" | "R_TREE_REGR" | "QSAR_TOOLBOX_CALCULATOR" | "QSAR_TOOLBOX_QSAR_MODEL" | "QSAR_TOOLBOX_PROFILER"; - /** @description A JSON object containing extra configuration for the model */ - ModelExtraConfig: { - torchConfig?: { - [key: string]: components["schemas"]["AnyValue"]; - }; - preprocessors?: components["schemas"]["Transformer"][]; - featurizers?: components["schemas"]["Transformer"][]; - }; /** @description A preprocessor for the model */ Transformer: { /** Format: int64 */ @@ -416,8 +410,9 @@ export interface components { /** Format: int64 */ id?: number; method: components["schemas"]["DoaMethod"]; - /** @description The doa calculated data */ - data: components["schemas"]["LeverageDoa"] | components["schemas"]["BoundingBoxDoa"] | components["schemas"]["KernelBasedDoa"] | components["schemas"]["MeanVarDoa"] | components["schemas"]["MahalanobisDoa"] | components["schemas"]["CityBlockDoa"]; + data: { + [key: string]: components["schemas"]["AnyValue"]; + }; /** * Format: date-time * @description The date and time when the feature was created. @@ -679,10 +674,6 @@ export interface components { torchConfig?: { [key: string]: components["schemas"]["AnyValue"]; } | null; - /** @description Additional configuration settings, optional */ - extraConfig?: { - [key: string]: components["schemas"]["AnyValue"]; - } | null; /** @description Legacy additional information settings, optional */ legacyAdditionalInfo?: { [key: string]: components["schemas"]["AnyValue"]; @@ -693,10 +684,6 @@ export interface components { PredictionRequest: { model: components["schemas"]["PredictionModel"]; dataset: components["schemas"]["Dataset"]; - /** @description Optional configuration for additional settings. */ - extraConfig?: { - [key: string]: components["schemas"]["AnyValue"]; - }; }; PredictionResponse: { predictions: components["schemas"]["AnyValue"][]; diff --git a/src/app/dashboard/models/[modelId]/components/scores/BinaryClassificationScoreCard.tsx b/src/app/dashboard/models/[modelId]/components/scores/BinaryClassificationScoreCard.tsx index f9cc8bc..883211f 100644 --- a/src/app/dashboard/models/[modelId]/components/scores/BinaryClassificationScoreCard.tsx +++ b/src/app/dashboard/models/[modelId]/components/scores/BinaryClassificationScoreCard.tsx @@ -6,6 +6,10 @@ interface BinaryClassificationScoreCardProps { score: BinaryClassificationDto; } +function transposeMatrix(matrix: number[][]) { + return matrix[0].map((col, i) => matrix.map((row) => row[i])); +} + export default function BinaryClassificationScoreCard({ score, }: BinaryClassificationScoreCardProps) { @@ -37,7 +41,13 @@ export default function BinaryClassificationScoreCard({ matthewsCorrCoef: {score!.matthewsCorrCoef}
- +
diff --git a/src/app/dashboard/models/[modelId]/components/scores/ConfusionMatrix.tsx b/src/app/dashboard/models/[modelId]/components/scores/ConfusionMatrix.tsx index 2237774..b329286 100644 --- a/src/app/dashboard/models/[modelId]/components/scores/ConfusionMatrix.tsx +++ b/src/app/dashboard/models/[modelId]/components/scores/ConfusionMatrix.tsx @@ -2,11 +2,10 @@ import CustomizedTreemapContent from '@/app/dashboard/models/[modelId]/component interface ConfusionMatrixProps { matrix: number[][][] | undefined; - classNames: string[]; + classNames?: string[]; } import React from 'react'; -import { Treemap, ResponsiveContainer } from 'recharts'; import Heatmap from '@/app/dashboard/models/[modelId]/components/scores/Heatmap'; const data = [ @@ -61,17 +60,19 @@ export default function ConfusionMatrix({

Confusion matrix

- {classNames.map((className, index) => ( -
- -
- ))} + {classNames?.map((className, index) => { + return ( +
+ +
+ ); + })} ); } diff --git a/src/app/dashboard/models/[modelId]/components/scores/MulticlassClassificationScoreCard.tsx b/src/app/dashboard/models/[modelId]/components/scores/MulticlassClassificationScoreCard.tsx index ae8f677..67fb778 100644 --- a/src/app/dashboard/models/[modelId]/components/scores/MulticlassClassificationScoreCard.tsx +++ b/src/app/dashboard/models/[modelId]/components/scores/MulticlassClassificationScoreCard.tsx @@ -40,7 +40,7 @@ export default function MulticlassClassificationScoreCard({