diff --git a/.changeset/yellow-insects-dress.md b/.changeset/yellow-insects-dress.md new file mode 100644 index 0000000000000..518d93bf2495f --- /dev/null +++ b/.changeset/yellow-insects-dress.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:HighlightedText throws an error depending on the model (entities keyed by "entity_group" instead of "entity") diff --git a/gradio/components/highlighted_text.py b/gradio/components/highlighted_text.py index 8d8c9b0d0b34e..690a1f0e8e7d0 100644 --- a/gradio/components/highlighted_text.py +++ b/gradio/components/highlighted_text.py @@ -25,7 +25,7 @@ class HighlightedText(Changeable, Selectable, IOComponent, JSONSerializable): """ Displays text that contains spans that are highlighted by category or numerical value. Preprocessing: this component does *not* accept input. - Postprocessing: expects a {List[Tuple[str, float | str]]]} consisting of spans of text and their associated labels, or a {Dict} with two keys: (1) "text" whose value is the complete text, and "entities", which is a list of dictionaries, each of which have the keys: "entity" (consisting of the entity label), "start" (the character index where the label starts), and "end" (the character index where the label ends). Entities should not overlap. + Postprocessing: expects a {List[Tuple[str, float | str]]} consisting of spans of text and their associated labels, or a {Dict} with two keys: (1) "text" whose value is the complete text, and (2) "entities", which is a list of dictionaries, each of which has the keys: "entity" (the entity label; this key may alternatively be named "entity_group"), "start" (the character index where the label starts), and "end" (the character index where the label ends). Entities should not overlap. 
Demos: diff_texts, text_analysis Guides: named-entity-recognition @@ -135,7 +135,7 @@ def postprocess( ) -> list[tuple[str, str | float | None]] | None: """ Parameters: - y: List of (word, category) tuples + y: List of (word, category) tuples, or a dictionary with two keys: "text" and "entities", where "entities" is a list of dictionaries, each of which has the keys: "entity" (or "entity_group"), "start", and "end" Returns: List of (word, category) tuples """ @@ -158,8 +158,9 @@ def postprocess( entities = sorted(entities, key=lambda x: x["start"]) for entity in entities: list_format.append((text[index : entity["start"]], None)) + entity_category = entity.get("entity") or entity.get("entity_group") list_format.append( - (text[entity["start"] : entity["end"]], entity["entity"]) + (text[entity["start"] : entity["end"]], entity_category) ) index = entity["end"] list_format.append((text[index:], None)) diff --git a/test/test_components.py b/test/test_components.py index 61becfe779060..d6bb500ea4a97 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -1733,6 +1733,14 @@ def test_postprocess(self): result_ = component.postprocess({"text": text, "entities": entities}) assert result == result_ + text = "Wolfgang lives in Berlin" + entities = [ + {"entity_group": "PER", "start": 0, "end": 8}, + {"entity": "LOC", "start": 18, "end": 24}, + ] + result_ = component.postprocess({"text": text, "entities": entities}) + assert result == result_ + # Test split entity is merged when combine adjacent is set text = "Wolfgang lives in Berlin" entities = [