From 5d89b41b1a5bc8e549b4c5fa504476026d368c14 Mon Sep 17 00:00:00 2001
From: Pluto
Date: Tue, 2 Jul 2024 12:07:24 +0200
Subject: [PATCH] Fix not counting false negatives and false positives in table metrics (#3300)

This pull request fixes the table metrics calculation for three cases:
- False Negatives: when a table exists in the ground truth but none of the
  predicted tables matches it, that table now counts as 0 and the file is no
  longer skipped entirely (previously the result was np.NaN).
- False Positives: when a predicted table does not match any ground truth
  table, it is now counted as 0; previously it was skipped during processing
  (matched_indices == -1).
- The file is skipped entirely only when there are no tables in either the
  ground truth or the prediction.

In short, the previous metric calculation did not account for object
detection (OD) mistakes. A small standalone sketch of the new averaging
behaviour is appended after the diff.
---
 CHANGELOG.md                                  |   4 +-
 .../metrics/test_table_alignment.py           |  14 ++
 .../metrics/test_table_structure.py           | 155 ++++++++++++++++++
 unstructured/__version__.py                   |   2 +-
 unstructured/metrics/table/table_alignment.py |  50 ++++--
 unstructured/metrics/table/table_eval.py      |  72 +++++---
 6 files changed, 250 insertions(+), 47 deletions(-)
 create mode 100644 test_unstructured/metrics/test_table_alignment.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 52caf2f090..1abbc38e3f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,13 +1,13 @@
-## 0.14.10-dev3
+## 0.14.10-dev4

 ### Enhancements

 * **`.doc` files are now supported in the `arm64` image.**. `libreoffice24` is added to the `arm64` image, meaning `.doc` files are now supported. We have follow on work planned to investigate adding `.ppt` support for `arm64` as well.

-
 ### Features

 ### Fixes

+- Fix counting false negatives and false positives in table structure evaluation
 * **Fix Slack CI test** Change channel that Slack test is pointing to because previous test bot expired

diff --git a/test_unstructured/metrics/test_table_alignment.py b/test_unstructured/metrics/test_table_alignment.py
new file mode 100644
index 0000000000..1f012cff72
--- /dev/null
+++ b/test_unstructured/metrics/test_table_alignment.py
@@ -0,0 +1,14 @@
+from unstructured.metrics.table.table_alignment import TableAlignment
+
+
+def test_get_element_level_alignment_when_no_match():
+    example_table = [{"row_index": 0, "col_index": 0, "content": "a"}]
+    metrics = TableAlignment.get_element_level_alignment(
+        predicted_table_data=[example_table],
+        ground_truth_table_data=[example_table],
+        matched_indices=[-1],
+    )
+    assert metrics["col_index_acc"] == 0
+    assert metrics["row_index_acc"] == 0
+    assert metrics["row_content_acc"] == 0
+    assert metrics["col_content_acc"] == 0
diff --git a/test_unstructured/metrics/test_table_structure.py b/test_unstructured/metrics/test_table_structure.py
index def97b3792..3c684be5a4 100644
--- a/test_unstructured/metrics/test_table_structure.py
+++ b/test_unstructured/metrics/test_table_structure.py
@@ -1,5 +1,9 @@
+from unittest import mock
+
+import numpy as np
 import pytest

+from unstructured.metrics.table.table_alignment import TableAlignment
 from unstructured.metrics.table.table_eval import TableEvalProcessor
 from unstructured.metrics.table_structure import (
     eval_table_transformer_for_file,
@@ -542,3 +546,154 @@ def test_table_eval_processor_merged_cells():
     assert result.element_col_level_index_acc == 1.0
     assert result.element_row_level_content_acc == 1.0
     assert result.element_col_level_content_acc == 1.0
+
+
+def test_table_eval_processor_when_no_match_with_pred():
+    prediction = [
+        {
+            "type": "Table",
+            "metadata": {
{"text_as_html": """
Some cell
"""}, + } + ] + + ground_truth = [ + { + "type": "Table", + "text": [ + { + "id": "ee862c7a-d27e-4484-92de-4faa42a63f3b", + "x": 0, + "y": 0, + "w": 1, + "h": 1, + "content": "11", + }, + { + "id": "6237ac7b-bfc8-40d2-92f2-d138277205e2", + "x": 0, + "y": 1, + "w": 1, + "h": 1, + "content": "21", + }, + { + "id": "9d0933a9-5984-4cad-80d9-6752bf9bc4df", + "x": 1, + "y": 0, + "w": 1, + "h": 1, + "content": "12", + }, + { + "id": "1152d043-5ead-4ab8-8b88-888d48831ac2", + "x": 1, + "y": 1, + "w": 1, + "h": 1, + "content": "22", + }, + ], + } + ] + + with mock.patch.object(TableAlignment, "get_table_level_alignment") as align_fn: + align_fn.return_value = [-1] + te_processor = TableEvalProcessor(prediction, ground_truth) + result = te_processor.process_file() + + assert result.total_tables == 1 + assert result.table_level_acc == 0 + assert result.element_row_level_index_acc == 0 + assert result.element_col_level_index_acc == 0 + assert result.element_row_level_content_acc == 0 + assert result.element_col_level_content_acc == 0 + + +def test_table_eval_processor_when_no_tables(): + prediction = [{}] + + ground_truth = [{}] + + te_processor = TableEvalProcessor(prediction, ground_truth) + result = te_processor.process_file() + assert result.total_tables == 0 + assert result.table_level_acc == 1 + assert np.isnan(result.element_row_level_index_acc) + assert np.isnan(result.element_col_level_index_acc) + assert np.isnan(result.element_row_level_content_acc) + assert np.isnan(result.element_col_level_content_acc) + + +def test_table_eval_processor_when_only_gt(): + prediction = [] + + ground_truth = [ + { + "type": "Table", + "text": [ + { + "id": "ee862c7a-d27e-4484-92de-4faa42a63f3b", + "x": 0, + "y": 0, + "w": 1, + "h": 1, + "content": "11", + }, + { + "id": "6237ac7b-bfc8-40d2-92f2-d138277205e2", + "x": 0, + "y": 1, + "w": 1, + "h": 1, + "content": "21", + }, + { + "id": "9d0933a9-5984-4cad-80d9-6752bf9bc4df", + "x": 1, + "y": 0, + "w": 1, + "h": 1, + "content": "12", + }, + { + "id": "1152d043-5ead-4ab8-8b88-888d48831ac2", + "x": 1, + "y": 1, + "w": 1, + "h": 1, + "content": "22", + }, + ], + } + ] + + te_processor = TableEvalProcessor(prediction, ground_truth) + result = te_processor.process_file() + + assert result.total_tables == 1 + assert result.table_level_acc == 0 + assert result.element_row_level_index_acc == 0 + assert result.element_col_level_index_acc == 0 + assert result.element_row_level_content_acc == 0 + assert result.element_col_level_content_acc == 0 + + +def test_table_eval_processor_when_only_pred(): + prediction = [ + { + "type": "Table", + "metadata": {"text_as_html": """
+                <table><tr><td>Some cell</td></tr></table>
"""}, + } + ] + + ground_truth = [{}] + + te_processor = TableEvalProcessor(prediction, ground_truth) + result = te_processor.process_file() + + assert result.total_tables == 0 + assert result.table_level_acc == 0 + assert result.element_row_level_index_acc == 0 + assert result.element_col_level_index_acc == 0 + assert result.element_row_level_content_acc == 0 + assert result.element_col_level_content_acc == 0 diff --git a/unstructured/__version__.py b/unstructured/__version__.py index 8486f2ef35..70050cabeb 100644 --- a/unstructured/__version__.py +++ b/unstructured/__version__.py @@ -1 +1 @@ -__version__ = "0.14.10-dev3" # pragma: no cover +__version__ = "0.14.10-dev4" # pragma: no cover diff --git a/unstructured/metrics/table/table_alignment.py b/unstructured/metrics/table/table_alignment.py index 35acc2625a..66ded1b698 100644 --- a/unstructured/metrics/table/table_alignment.py +++ b/unstructured/metrics/table/table_alignment.py @@ -74,14 +74,17 @@ def get_element_level_alignment( A dictionary with column and row alignment accuracies. """ - aligned_element_col_count = 0 - aligned_element_row_count = 0 - total_element_count = 0 content_diff_cols = [] content_diff_rows = [] + col_index_acc = [] + row_index_acc = [] for idx, td in zip(matched_indices, predicted_table_data): if idx == -1: + content_diff_cols.append(0) + content_diff_rows.append(0) + col_index_acc.append(0) + row_index_acc.append(0) continue ground_truth_td = ground_truth_table_data[idx] @@ -96,6 +99,9 @@ def get_element_level_alignment( content_diff_cols.append(table_content_diff["by_col_token_ratio"]) content_diff_rows.append(table_content_diff["by_row_token_ratio"]) + aligned_element_col_count = 0 + aligned_element_row_count = 0 + total_element_count = 0 # Get row and col index accuracy ground_truth_td_contents_list = [gtd["content"].lower() for gtd in ground_truth_td] used_indices = set() @@ -148,17 +154,27 @@ def get_element_level_alignment( aligned_element_col_count += 1 total_element_count += 1 - if total_element_count > 0: - col_index_acc = round(aligned_element_col_count / total_element_count, 2) - row_index_acc = round(aligned_element_row_count / total_element_count, 2) - col_content_acc = round(np.mean(content_diff_cols) / 100.0, 2) - row_content_acc = round(np.mean(content_diff_rows) / 100.0, 2) - - return { - "col_index_acc": col_index_acc, - "row_index_acc": row_index_acc, - "col_content_acc": col_content_acc, - "row_content_acc": row_content_acc, - } - - return {} + table_col_index_acc = 0 + table_row_index_acc = 0 + if total_element_count > 0: + table_col_index_acc = round(aligned_element_col_count / total_element_count, 2) + table_row_index_acc = round(aligned_element_row_count / total_element_count, 2) + + col_index_acc.append(table_col_index_acc) + row_index_acc.append(table_row_index_acc) + + not_found_gt_table_indexes = [ + id for id in range(len(ground_truth_table_data)) if id not in matched_indices + ] + for _ in not_found_gt_table_indexes: + content_diff_cols.append(0) + content_diff_rows.append(0) + col_index_acc.append(0) + row_index_acc.append(0) + + return { + "col_index_acc": round(np.mean(col_index_acc), 2), + "row_index_acc": round(np.mean(row_index_acc), 2), + "col_content_acc": round(np.mean(content_diff_cols) / 100.0, 2), + "row_content_acc": round(np.mean(content_diff_cols) / 100.0, 2), + } diff --git a/unstructured/metrics/table/table_eval.py b/unstructured/metrics/table/table_eval.py index 89d6bfa526..a25cf30c3d 100644 --- a/unstructured/metrics/table/table_eval.py +++ 
@@ -200,37 +200,55 @@ def process_file(self) -> TableEvaluation:
         predicted_table_data = extract_and_convert_tables_from_prediction(
             file_elements=self.prediction, source_type=self.source_type
         )
-
-        matched_indices = TableAlignment.get_table_level_alignment(
-            predicted_table_data,
-            ground_truth_table_data,
-        )
-        if matched_indices:
+        is_table_in_gt = bool(ground_truth_table_data)
+        is_table_predicted = bool(predicted_table_data)
+        if not is_table_in_gt:
+            # There is no table data in ground truth, you either got perfect score or 0
+            score = 0 if is_table_predicted else np.nan
+            table_acc = 1 if not is_table_predicted else 0
+            return TableEvaluation(
+                total_tables=0,
+                table_level_acc=table_acc,
+                element_col_level_index_acc=score,
+                element_row_level_index_acc=score,
+                element_col_level_content_acc=score,
+                element_row_level_content_acc=score,
+            )
+        if is_table_in_gt and not is_table_predicted:
+            return TableEvaluation(
+                total_tables=len(ground_truth_table_data),
+                table_level_acc=0,
+                element_col_level_index_acc=0,
+                element_row_level_index_acc=0,
+                element_col_level_content_acc=0,
+                element_row_level_content_acc=0,
+            )
+        else:
+            # We have both ground truth tables and predicted tables
+            matched_indices = TableAlignment.get_table_level_alignment(
+                predicted_table_data,
+                ground_truth_table_data,
+            )
             predicted_table_acc = np.mean(
                 table_level_acc(predicted_table_data, ground_truth_table_data, matched_indices)
             )
-        elif ground_truth_table_data:
-            # no matching prediction but has actual table -> total failure
-            predicted_table_acc = 0
-        else:
-            # no predicted and no actual table -> good job
-            predicted_table_acc = 1
-
-        metrics = TableAlignment.get_element_level_alignment(
-            predicted_table_data,
-            ground_truth_table_data,
-            matched_indices,
-            cutoff=self.cutoff,
-        )
-        return TableEvaluation(
-            total_tables=len(ground_truth_table_data),
-            table_level_acc=predicted_table_acc,
-            element_col_level_index_acc=metrics.get("col_index_acc", np.nan),
-            element_row_level_index_acc=metrics.get("row_index_acc", np.nan),
-            element_col_level_content_acc=metrics.get("col_content_acc", np.nan),
-            element_row_level_content_acc=metrics.get("row_content_acc", np.nan),
-        )
+            metrics = TableAlignment.get_element_level_alignment(
+                predicted_table_data,
+                ground_truth_table_data,
+                matched_indices,
+                cutoff=self.cutoff,
+            )
+
+            evaluation = TableEvaluation(
+                total_tables=len(ground_truth_table_data),
+                table_level_acc=predicted_table_acc,
+                element_col_level_index_acc=metrics.get("col_index_acc", 0),
+                element_row_level_index_acc=metrics.get("row_index_acc", 0),
+                element_col_level_content_acc=metrics.get("col_content_acc", 0),
+                element_row_level_content_acc=metrics.get("row_content_acc", 0),
+            )
+            return evaluation


 @click.command()
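
A minimal standalone sketch of the averaging behaviour described in the commit message above. This is not the library implementation; the table counts and the 1.0 score assumed for the matched table are illustrative values only.

```python
import numpy as np

# One prediction matched ground-truth table 0; the second prediction matched
# nothing (a false positive). Three ground-truth tables exist in total, so
# tables 1 and 2 are false negatives. All scores here are made up.
matched_indices = [0, -1]
ground_truth_table_count = 3

per_table_scores = []
for idx in matched_indices:
    # A matched table keeps its own score (assumed 1.0 here); an unmatched
    # prediction now contributes 0 instead of being skipped.
    per_table_scores.append(1.0 if idx != -1 else 0.0)

# Ground-truth tables that no prediction matched now also contribute 0.
unmatched_gt = [i for i in range(ground_truth_table_count) if i not in matched_indices]
per_table_scores.extend(0.0 for _ in unmatched_gt)

print(per_table_scores)                     # [1.0, 0.0, 0.0, 0.0]
print(round(np.mean(per_table_scores), 2))  # 0.25, where the old code reported 1.0
```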