diff --git a/libs/experimental/langchain_experimental/graph_transformers/relik.py b/libs/experimental/langchain_experimental/graph_transformers/relik.py index 94eeab10fa41a..a1e7dd0f6b072 100644 --- a/libs/experimental/langchain_experimental/graph_transformers/relik.py +++ b/libs/experimental/langchain_experimental/graph_transformers/relik.py @@ -1,4 +1,5 @@ -from typing import List, Sequence +import logging +from typing import Any, Dict, List, Sequence from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_core.documents import Document @@ -22,23 +23,33 @@ class RelikGraphTransformer: model (str): The name of the pretrained Relik model to use. Default is "relik-ie/relik-relation-extraction-small-wikipedia". relationship_confidence_threshold (float): The confidence threshold for - filtering relationships. Default is 0.0. + filtering relationships. Default is 0.1. + model_config (Dict[str, any]): Additional configuration options for the + Relik model. Default is an empty dictionary. + ignore_self_loops (bool): Whether to ignore relationships where the + source and target nodes are the same. Default is True. """ def __init__( self, - model: str = "relik-ie/relik-relation-extraction-small-wikipedia", - relationship_confidence_threshold: float = 0.0, + model: str = "relik-ie/relik-relation-extraction-small", + relationship_confidence_threshold: float = 0.1, + model_config: Dict[str, Any] = {}, + ignore_self_loops: bool = True, ) -> None: try: import relik # type: ignore + + # Remove default INFO logging + logging.getLogger("relik").setLevel(logging.WARNING) except ImportError: raise ImportError( "Could not import relik python package. " "Please install it with `pip install relik`." ) - self.relik_model = relik.Relik.from_pretrained(model) + self.relik_model = relik.Relik.from_pretrained(model, **model_config) self.relationship_confidence_threshold = relationship_confidence_threshold + self.ignore_self_loops = ignore_self_loops def process_document(self, document: Document) -> GraphDocument: relik_out = self.relik_model(document.page_content) @@ -60,6 +71,9 @@ def process_document(self, document: Document) -> GraphDocument: # Ignore relationship if below confidence threshold if triple.confidence < self.relationship_confidence_threshold: continue + # Ignore self loops + if self.ignore_self_loops and triple.subject.text == triple.object.text: + continue source_node = Node( id=triple.subject.text, type=DEFAULT_NODE_TYPE