Skip to content

Commit

Permalink
Add relik transformer config (langchain-ai#25019)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasonjo authored Aug 3, 2024
1 parent 1dcee68 commit f9a11a9
Showing 1 changed file with 19 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Sequence
import logging
from typing import Any, Dict, List, Sequence

from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document
Expand All @@ -22,23 +23,33 @@ class RelikGraphTransformer:
model (str): The name of the pretrained Relik model to use.
Default is "relik-ie/relik-relation-extraction-small-wikipedia".
relationship_confidence_threshold (float): The confidence threshold for
filtering relationships. Default is 0.0.
filtering relationships. Default is 0.1.
model_config (Dict[str, any]): Additional configuration options for the
Relik model. Default is an empty dictionary.
ignore_self_loops (bool): Whether to ignore relationships where the
source and target nodes are the same. Default is True.
"""

def __init__(
self,
model: str = "relik-ie/relik-relation-extraction-small-wikipedia",
relationship_confidence_threshold: float = 0.0,
model: str = "relik-ie/relik-relation-extraction-small",
relationship_confidence_threshold: float = 0.1,
model_config: Dict[str, Any] = {},
ignore_self_loops: bool = True,
) -> None:
try:
import relik # type: ignore

# Remove default INFO logging
logging.getLogger("relik").setLevel(logging.WARNING)
except ImportError:
raise ImportError(
"Could not import relik python package. "
"Please install it with `pip install relik`."
)
self.relik_model = relik.Relik.from_pretrained(model)
self.relik_model = relik.Relik.from_pretrained(model, **model_config)
self.relationship_confidence_threshold = relationship_confidence_threshold
self.ignore_self_loops = ignore_self_loops

def process_document(self, document: Document) -> GraphDocument:
relik_out = self.relik_model(document.page_content)
Expand All @@ -60,6 +71,9 @@ def process_document(self, document: Document) -> GraphDocument:
# Ignore relationship if below confidence threshold
if triple.confidence < self.relationship_confidence_threshold:
continue
# Ignore self loops
if self.ignore_self_loops and triple.subject.text == triple.object.text:
continue
source_node = Node(
id=triple.subject.text,
type=DEFAULT_NODE_TYPE
Expand Down

0 comments on commit f9a11a9

Please sign in to comment.