From f12bfe29300ec8d1ae132eccb29438f87e7f2eae Mon Sep 17 00:00:00 2001 From: Michael Chin Date: Mon, 16 Dec 2024 23:41:58 -0800 Subject: [PATCH] Add Neptune chains --- libs/aws/langchain_aws/__init__.py | 3 + libs/aws/langchain_aws/chains/__init__.py | 27 ++ .../langchain_aws/chains/graph_qa/__init__.py | 7 + .../chains/graph_qa/neptune_cypher.py | 247 ++++++++++++++++++ .../chains/graph_qa/neptune_sparql.py | 236 +++++++++++++++++ .../langchain_aws/chains/graph_qa/prompts.py | 83 ++++++ libs/aws/langchain_aws/graphs/__init__.py | 3 +- libs/aws/pyproject.toml | 1 + 8 files changed, 606 insertions(+), 1 deletion(-) create mode 100644 libs/aws/langchain_aws/chains/__init__.py create mode 100644 libs/aws/langchain_aws/chains/graph_qa/__init__.py create mode 100644 libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py create mode 100644 libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py create mode 100644 libs/aws/langchain_aws/chains/graph_qa/prompts.py diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py index 6683e118..53cc6c3d 100644 --- a/libs/aws/langchain_aws/__init__.py +++ b/libs/aws/langchain_aws/__init__.py @@ -1,3 +1,4 @@ +from langchain_aws.chains.graph_qa import NeptuneOpenCypherQAChain, NeptuneSparqlQAChain from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse from langchain_aws.embeddings import BedrockEmbeddings from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph @@ -19,6 +20,8 @@ "SagemakerEndpoint", "AmazonKendraRetriever", "AmazonKnowledgeBasesRetriever", + "NeptuneOpenCypherQAChain", + "NeptuneSparqlQAChain", "NeptuneAnalyticsGraph", "NeptuneGraph", "InMemoryVectorStore", diff --git a/libs/aws/langchain_aws/chains/__init__.py b/libs/aws/langchain_aws/chains/__init__.py new file mode 100644 index 00000000..7554fda6 --- /dev/null +++ b/libs/aws/langchain_aws/chains/__init__.py @@ -0,0 +1,27 @@ +import importlib +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from langchain_aws.chains.graph_qa.neptune_cypher import ( + NeptuneOpenCypherQAChain + ) + from langchain_aws.chains.graph_qa.neptune_sparql import ( + NeptuneSparqlQAChain + ) + +__all__ = [ + "NeptuneOpenCypherQAChain", + "NeptuneSparqlQAChain" +] + +_module_lookup = { + "NeptuneOpenCypherQAChain": "langchain_aws.chains.graph_qa.neptune_cypher", + "NeptuneSparqlQAChain": "langchain_aws.chains.graph_qa.neptune_sparql", +} + + +def __getattr__(name: str) -> Any: + if name in _module_lookup: + module = importlib.import_module(_module_lookup[name]) + return getattr(module, name) + raise AttributeError(f"module {__name__} has no attribute {name}") \ No newline at end of file diff --git a/libs/aws/langchain_aws/chains/graph_qa/__init__.py b/libs/aws/langchain_aws/chains/graph_qa/__init__.py new file mode 100644 index 00000000..cf842e02 --- /dev/null +++ b/libs/aws/langchain_aws/chains/graph_qa/__init__.py @@ -0,0 +1,7 @@ +from .neptune_cypher import NeptuneOpenCypherQAChain +from .neptune_sparql import NeptuneSparqlQAChain + +__all__ = [ + "NeptuneOpenCypherQAChain", + "NeptuneSparqlQAChain" +] \ No newline at end of file diff --git a/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py b/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py new file mode 100644 index 00000000..0a5fd3c0 --- /dev/null +++ b/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.prompt_selector import ConditionalPromptSelector +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts.base import BasePromptTemplate +from pydantic import Field +from .prompts import ( + CYPHER_QA_PROMPT, + NEPTUNE_OPENCYPHER_GENERATION_PROMPT, + NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT, +) +from langchain_aws.graphs import BaseNeptuneGraph + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + + +def trim_query(query: str) -> str: + """Trim the query to only include Cypher keywords.""" + keywords = ( + "CALL", + "CREATE", + "DELETE", + "DETACH", + "LIMIT", + "MATCH", + "MERGE", + "OPTIONAL", + "ORDER", + "REMOVE", + "RETURN", + "SET", + "SKIP", + "UNWIND", + "WITH", + "WHERE", + "//", + ) + + lines = query.split("\n") + new_query = "" + + for line in lines: + if line.strip().upper().startswith(keywords): + new_query += line + "\n" + + return new_query + + +def extract_cypher(text: str) -> str: + """Extract Cypher code from text using Regex.""" + # The pattern to find Cypher code enclosed in triple backticks + pattern = r"```(.*?)```" + + # Find all matches in the input text + matches = re.findall(pattern, text, re.DOTALL) + + return matches[0] if matches else text + + +def use_simple_prompt(llm: BaseLanguageModel) -> bool: + """Decides whether to use the simple prompt""" + if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore + return True + + # Bedrock anthropic + if hasattr(llm, "model_id") and "anthropic" in llm.model_id: # type: ignore + return True + + return False + + +PROMPT_SELECTOR = ConditionalPromptSelector( + default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT, + conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)], +) + + +class NeptuneOpenCypherQAChain(Chain): + """Chain for question-answering against a Neptune graph + by generating openCypher statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + + Example: + .. code-block:: python + + chain = NeptuneOpenCypherQAChain.from_llm( + llm=llm, + graph=graph + ) + response = chain.run(query) + """ + + graph: BaseNeptuneGraph = Field(exclude=True) + cypher_generation_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + top_k: int = 10 + return_intermediate_steps: bool = False + """Whether or not to return the intermediate steps along with the final answer.""" + return_direct: bool = False + """Whether or not to return the result of querying the graph directly.""" + extra_instructions: Optional[str] = None + """Extra instructions by the appended to the query generation prompt.""" + + allow_dangerous_requests: bool = False + """Forced user opt-in to acknowledge that the chain can make dangerous requests. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + def __init__(self, **kwargs: Any) -> None: + """Initialize the chain.""" + super().__init__(**kwargs) + if self.allow_dangerous_requests is not True: + raise ValueError( + "In order to use this chain, you must acknowledge that it can make " + "dangerous requests by setting `allow_dangerous_requests` to `True`." + "You must narrowly scope the permissions of the database connection " + "to only include necessary permissions. Failure to do so may result " + "in data corruption or loss or reading sensitive data if such data is " + "present in the database." + "Only use this chain if you understand the risks and have taken the " + "necessary precautions. " + "See https://python.langchain.com/docs/security for more information." + ) + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the output keys. + + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, + cypher_prompt: Optional[BasePromptTemplate] = None, + extra_instructions: Optional[str] = None, + **kwargs: Any, + ) -> NeptuneOpenCypherQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + + _cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm) + cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt) + + return cls( + qa_chain=qa_chain, + cypher_generation_chain=cypher_generation_chain, + extra_instructions=extra_instructions, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + """Generate Cypher statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + + intermediate_steps: List = [] + + generated_cypher = self.cypher_generation_chain.run( + { + "question": question, + "schema": self.graph.get_schema, + "extra_instructions": self.extra_instructions or "", + }, + callbacks=callbacks, + ) + + # Extract Cypher code if it is wrapped in backticks + generated_cypher = extract_cypher(generated_cypher) + generated_cypher = trim_query(generated_cypher) + + _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_cypher, color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"query": generated_cypher}) + + context = self.graph.query(generated_cypher) + + if self.return_direct: + final_result = context + else: + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"context": context}) + + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=callbacks, + ) + final_result = result[self.qa_chain.output_key] + + chain_result: Dict[str, Any] = {self.output_key: final_result} + if self.return_intermediate_steps: + chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps + + return chain_result diff --git a/libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py b/libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py new file mode 100644 index 00000000..3cc1b9b5 --- /dev/null +++ b/libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py @@ -0,0 +1,236 @@ +""" +Question answering over an RDF or OWL graph using SPARQL. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain_core.callbacks.manager import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from pydantic import Field + +from langchain_aws.chains.graph_qa.prompts import SPARQL_QA_PROMPT +from langchain_aws.graphs import NeptuneRdfGraph + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + +SPARQL_GENERATION_TEMPLATE = """ +Task: Generate a SPARQL SELECT statement for querying a graph database. +For instance, to find all email addresses of John Doe, the following +query in backticks would be suitable: +``` +PREFIX foaf: +SELECT ?email +WHERE {{ + ?person foaf:name "John Doe" . + ?person foaf:mbox ?email . +}} +``` +Instructions: +Use only the node types and properties provided in the schema. +Do not use any node types and properties that are not explicitly provided. +Include all necessary prefixes. + +Examples: + +Schema: +{schema} +Note: Be as concise as possible. +Do not include any explanations or apologies in your responses. +Do not respond to any questions that ask for anything else than +for you to construct a SPARQL query. +Do not include any text except the SPARQL query generated. + +The question is: +{prompt}""" + +SPARQL_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE +) + + +def extract_sparql(query: str) -> str: + """Extract SPARQL code from a text. + + Args: + query: Text to extract SPARQL code from. + + Returns: + SPARQL code extracted from the text. + """ + query = query.strip() + querytoks = query.split("```") + if len(querytoks) == 3: + query = querytoks[1] + + if query.startswith("sparql"): + query = query[6:] + elif query.startswith("") and query.endswith(""): + query = query[8:-9] + return query + + +class NeptuneSparqlQAChain(Chain): + """Chain for question-answering against a Neptune graph + by generating SPARQL statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + + Example: + .. code-block:: python + + chain = NeptuneSparqlQAChain.from_llm( + llm=llm, + graph=graph + ) + response = chain.invoke(query) + """ + + graph: NeptuneRdfGraph = Field(exclude=True) + sparql_generation_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + top_k: int = 10 + return_intermediate_steps: bool = False + """Whether or not to return the intermediate steps along with the final answer.""" + return_direct: bool = False + """Whether or not to return the result of querying the graph directly.""" + extra_instructions: Optional[str] = None + """Extra instructions by the appended to the query generation prompt.""" + + allow_dangerous_requests: bool = False + """Forced user opt-in to acknowledge that the chain can make dangerous requests. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + def __init__(self, **kwargs: Any) -> None: + """Initialize the chain.""" + super().__init__(**kwargs) + if self.allow_dangerous_requests is not True: + raise ValueError( + "In order to use this chain, you must acknowledge that it can make " + "dangerous requests by setting `allow_dangerous_requests` to `True`." + "You must narrowly scope the permissions of the database connection " + "to only include necessary permissions. Failure to do so may result " + "in data corruption or loss or reading sensitive data if such data is " + "present in the database." + "Only use this chain if you understand the risks and have taken the " + "necessary precautions. " + "See https://python.langchain.com/docs/security for more information." + ) + + @property + def input_keys(self) -> List[str]: + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT, + sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT, + examples: Optional[str] = None, + **kwargs: Any, + ) -> NeptuneSparqlQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + template_to_use = SPARQL_GENERATION_TEMPLATE + if examples: + template_to_use = template_to_use.replace( + "Examples:", "Examples: " + examples + ) + sparql_prompt = PromptTemplate( + input_variables=["schema", "prompt"], template=template_to_use + ) + sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt) + + return cls( # type: ignore[call-arg] + qa_chain=qa_chain, + sparql_generation_chain=sparql_generation_chain, + examples=examples, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """ + Generate SPARQL query, use it to retrieve a response from the gdb and answer + the question. + """ + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + prompt = inputs[self.input_key] + + intermediate_steps: List = [] + + generated_sparql = self.sparql_generation_chain.run( + {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks + ) + + # Extract SPARQL + generated_sparql = extract_sparql(generated_sparql) + + _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_sparql, color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"query": generated_sparql}) + + context = self.graph.query(generated_sparql) + + if self.return_direct: + final_result = context + else: + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"context": context}) + + result = self.qa_chain( + {"prompt": prompt, "context": context}, + callbacks=callbacks, + ) + final_result = result[self.qa_chain.output_key] + + chain_result: Dict[str, Any] = {self.output_key: final_result} + if self.return_intermediate_steps: + chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps + + return chain_result diff --git a/libs/aws/langchain_aws/chains/graph_qa/prompts.py b/libs/aws/langchain_aws/chains/graph_qa/prompts.py new file mode 100644 index 00000000..4335edb0 --- /dev/null +++ b/libs/aws/langchain_aws/chains/graph_qa/prompts.py @@ -0,0 +1,83 @@ +# flake8: noqa +from langchain_core.prompts.prompt import PromptTemplate + +CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database. +Instructions: +Use only the provided relationship types and properties in the schema. +Do not use any other relationship types or properties that are not provided. +Schema: +{schema} +Note: Do not include any explanations or apologies in your responses. +Do not respond to any questions that might ask anything else than for you to construct a Cypher statement. +Do not include any text except the generated Cypher statement. + +The question is: +{question}""" +CYPHER_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE +) + +CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers. +The information part contains the provided information that you must use to construct an answer. +The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it. +Make the answer sound as a response to the question. Do not mention that you based the result on the given information. +Here is an example: + +Question: Which managers own Neo4j stocks? +Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC] +Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks. + +Follow this example when generating answers. +If the provided information is empty, say that you don't know the answer. +Information: +{context} + +Question: {question} +Helpful Answer:""" +CYPHER_QA_PROMPT = PromptTemplate( + input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE +) + +SPARQL_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query. +You are an assistant that creates well-written and human understandable answers. +The information part contains the information provided, which you can use to construct an answer. +The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it. +Make your response sound like the information is coming from an AI assistant, but don't add any information. +Information: +{context} + +Question: {prompt} +Helpful Answer:""" +SPARQL_QA_PROMPT = PromptTemplate( + input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE +) + +NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """ +Instructions: +Generate the query in openCypher format and follow these rules: +Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions. +Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results. +Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions} +\n""" + +NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace( + "Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS +) + +NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "question", "extra_instructions"], + template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE, +) + +NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """ +Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions} +Question: "{question}". +Here is the property graph schema: +{schema} +\n""" + +NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate( + input_variables=["schema", "question", "extra_instructions"], + template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE, +) + diff --git a/libs/aws/langchain_aws/graphs/__init__.py b/libs/aws/langchain_aws/graphs/__init__.py index 1aa136a4..174ba471 100644 --- a/libs/aws/langchain_aws/graphs/__init__.py +++ b/libs/aws/langchain_aws/graphs/__init__.py @@ -3,5 +3,6 @@ NeptuneAnalyticsGraph, NeptuneGraph, ) +from langchain_aws.graphs.neptune_rdf_graph import NeptuneRdfGraph -__all__ = ["BaseNeptuneGraph", "NeptuneAnalyticsGraph", "NeptuneGraph"] +__all__ = ["BaseNeptuneGraph", "NeptuneAnalyticsGraph", "NeptuneGraph", "NeptuneRdfGraph"] diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml index 9c7d35db..7e0cc881 100644 --- a/libs/aws/pyproject.toml +++ b/libs/aws/pyproject.toml @@ -13,6 +13,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.9,<4.0" langchain-core = ">=0.3.15,<0.4" +langchain = ">=0.3.11,<0.4" boto3 = ">=1.35.74" pydantic = ">=2,<3" numpy = [