from typing import Dict, List, Union


def get_element_type_frequency(
    elements: List,
) -> Union[Dict[str, Dict[str, int]], Dict]:
    """Count elements per category, broken down by category depth.

    Args:
        elements: Objects exposing a ``category`` attribute and a
            ``metadata.category_depth`` attribute. Any iterable works at
            runtime (lists, iterators, generators) because the input is
            only iterated once and never measured with ``len()``.

    Returns:
        A nested mapping ``{category: {str(category_depth): count}}``.
        The depth is always stringified, so a ``category_depth`` of
        ``None`` is counted under the key ``"None"``. An empty input
        yields an empty dict.
    """
    frequency: Dict[str, Dict[str, int]] = {}
    for element in elements:
        category = element.category
        # Stringify the depth so None and integer depths share one key type.
        depth_key = str(element.metadata.category_depth)
        depth_counts = frequency.setdefault(category, {})
        depth_counts[depth_key] = depth_counts.get(depth_key, 0) + 1
    return frequency