Skip to content

Commit

Permalink
heavily populate ConfusionMatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
patel-zeel committed Aug 28, 2024
1 parent 54ef6f6 commit edcfb19
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 7 deletions.
132 changes: 127 additions & 5 deletions garuda/od.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataclasses import dataclass

from supervision.metrics.detection import ConfusionMatrix as SVConfusionMatrix
from supervision.metrics.detection import ConfusionMatrix as SVConfusionMatrix, MeanAveragePrecision as SVMeanAveragePrecision

from garuda.ops import webm_pixel_to_geo, geo_to_webm_pixel, local_to_geo, label_studio_csv_to_obb, obb_iou
from beartype import beartype
Expand Down Expand Up @@ -167,7 +167,7 @@ def process_row(row):
return obb
except Exception as e:
warnings.warn(f"Error processing row: {row}\n{e}")
return None
return np.zeros((0, 9))

df["obb"] = df.apply(process_row, axis=1)
return df
Expand All @@ -183,8 +183,8 @@ class ConfusionMatrix(SVConfusionMatrix):
@jaxtyped(typechecker=beartype)
def from_obb_tensors(
cls,
predictions: List[Float[ndarray, "... 10"]],
targets: List[Float[ndarray, "... 9"]],
predictions: List[Float[ndarray, "_ 10"]],
targets: List[Float[ndarray, "_ 9"]],
classes: List[str],
conf_threshold: float,
iou_threshold: float,
Expand Down Expand Up @@ -295,4 +295,126 @@ def evaluate_detection_obb_batch(
if not any(matched_detection_idx == i):
result_matrix[num_classes, detection_class_value] += 1 # FP

return result_matrix
return result_matrix

@property
@jaxtyped(typechecker=beartype)
def true_positives(self) -> Int[ndarray, "{len(self.classes)}"]:
"""
Calculate True Positives (TP) for each class.
Returns
-------
np.ndarray: True Positives for each class.
"""
return self.matrix.diagonal()[:-1].astype(int)

@property
@jaxtyped(typechecker=beartype)
def predicted_positives(self) -> Int[ndarray, "{len(self.classes)}"]:
"""
Calculate Predicted Positives (PP) for each class.
Returns
-------
np.ndarray: Predicted Positives for each class.
"""
return self.matrix.sum(axis=0)[:-1].astype(int)

@property
@jaxtyped(typechecker=beartype)
def false_positives(self) -> Int[ndarray, "{len(self.classes)}"]:
"""
Calculate False Positives (FP) for each class.
Returns
-------
np.ndarray: False Positives for each class.
"""
return self.predicted_positives - self.true_positives

@property
@jaxtyped(typechecker=beartype)
def actual_positives(self) -> Int[ndarray, "{len(self.classes)}"]:
"""
Calculate Actual Positives (AP) for each class.
Returns
-------
np.ndarray: Actual Positives for each class.
"""
return self.matrix.sum(axis=1)[:-1].astype(int)

@property
@jaxtyped(typechecker=beartype)
def false_negatives(self) -> Int[ndarray, "{len(self.classes)}"]:
"""
Calculate False Negatives (FN) for each class.
Returns
-------
np.ndarray: False Negatives for each class.
"""
return self.actual_positives - self.true_positives

@property
@jaxtyped(typechecker=beartype)
def precision(self) -> Float[ndarray, "{len(self.classes)}"]:
"""
Calculate precision for each class.
Returns
-------
np.ndarray: Precision for each class.
"""
precision = self.true_positives / self.predicted_positives
return precision

@property
@jaxtyped(typechecker=beartype)
def recall(self) -> Float[ndarray, "{len(self.classes)}"]:
"""
Calculate recall for each class.
Returns
-------
np.ndarray: Recall for each class.
"""
recall = self.true_positives / self.actual_positives
return recall

@property
@jaxtyped(typechecker=beartype)
def f1_score(self) -> Float[ndarray, "{len(self.classes)}"]:
"""
Calculate F1 score for each class.
Returns
-------
np.ndarray: F1 score for each class.
"""
# f1_score = 2 * (self.precision * self.recall) / (self.precision + self.recall)
# OR more efficiently
f1_score = 2 * self.true_positives / (self.predicted_positives + self.actual_positives)
return f1_score

@property
def summary(self) -> pd.DataFrame:
"""
Generate a summary DataFrame.
Returns
-------
pd.DataFrame: Summary DataFrame.
"""
summary_df = pd.DataFrame(columns=self.classes)

summary_df.loc["Actual Positives", self.classes] = self.actual_positives
summary_df.loc["Predicted Positives", self.classes] = self.predicted_positives
summary_df.loc["True Positives", self.classes] = self.true_positives
summary_df.loc["False Positives", self.classes] = self.false_positives
summary_df.loc["False Negatives", self.classes] = self.false_negatives
summary_df.loc["Precision", self.classes] = self.precision
summary_df.loc["Recall", self.classes] = self.recall
summary_df.loc["F1 Score", self.classes] = self.f1_score
return summary_df
3 changes: 3 additions & 0 deletions garuda/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,9 @@ def xyxyxyxy2xywhr(xyxyxyxy: Float[ndarray, "n 8"]) -> Float[ndarray, "n 5"]:
xywhr: Oriented Bounding Boxes in [x_c, y_c, w, h, r] format.
"""

if xyxyxyxy.shape[0] == 0:
return np.zeros((0, 5))

points = xyxyxyxy.reshape(len(xyxyxyxy), -1, 2)
rboxes = []
for pts in points:
Expand Down
59 changes: 57 additions & 2 deletions lab/testing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -25,7 +25,62 @@
"from pystac.extensions.eo import EOExtension as eo\n",
"from shapely.geometry import box\n",
"from rioxarray.merge import merge_arrays\n",
"import rioxarray"
"import rioxarray\n",
"\n",
"from supervision.metrics.detection import ConfusionMatrix as SVConfusionMatrix"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ConfusionMatrix(matrix=array([[0., 0., 6.],\n",
" [0., 0., 0.],\n",
" [2., 0., 0.]]), classes=['A', 'B'], conf_threshold=0.3, iou_threshold=0.5)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"SVConfusionMatrix.from_tensors(predictions=[np.random.rand(3, 6), np.random.rand(0, 6)], targets=[np.random.rand(3, 5)]*2, classes = [\"A\", \"B\"])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n",
"[0.15822135 0.01644572 0.63779084]\n"
]
}
],
"source": [
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class ABC:\n",
" array = [1, 2, 3]\n",
" length: int\n",
" \n",
" @jaxtyped(typechecker=beartype)\n",
" def check(self) -> Float[ndarray, \"{len(self.array)}\"]:\n",
" return np.random.rand(self.length)\n",
"\n",
"abc = ABC(3)\n",
"print(abc.length)\n",
"print(abc.check())"
]
},
{
Expand Down

0 comments on commit edcfb19

Please sign in to comment.