diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py
index 6683e118..94f1f447 100644
--- a/libs/aws/langchain_aws/__init__.py
+++ b/libs/aws/langchain_aws/__init__.py
@@ -1,3 +1,4 @@
+from langchain_aws.chains.graph_qa import create_neptune_sparql_qa_chain, create_neptune_opencypher_qa_chain
 from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse
 from langchain_aws.embeddings import BedrockEmbeddings
 from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
@@ -19,6 +20,8 @@
     "SagemakerEndpoint",
     "AmazonKendraRetriever",
     "AmazonKnowledgeBasesRetriever",
+    "create_neptune_opencypher_qa_chain",
+    "create_neptune_sparql_qa_chain",
     "NeptuneAnalyticsGraph",
     "NeptuneGraph",
     "InMemoryVectorStore",
diff --git a/libs/aws/langchain_aws/chains/__init__.py b/libs/aws/langchain_aws/chains/__init__.py
new file mode 100644
index 00000000..13c9984d
--- /dev/null
+++ b/libs/aws/langchain_aws/chains/__init__.py
@@ -0,0 +1,27 @@
+import importlib
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from langchain_aws.chains.graph_qa.neptune_cypher import (
+        create_neptune_opencypher_qa_chain
+    )
+    from langchain_aws.chains.graph_qa.neptune_sparql import (
+        create_neptune_sparql_qa_chain
+    )
+
+__all__ = [
+    "create_neptune_opencypher_qa_chain",
+    "create_neptune_sparql_qa_chain"
+]
+
+_module_lookup = {
+    "create_neptune_opencypher_qa_chain": "langchain_aws.chains.graph_qa.neptune_cypher",
+    "create_neptune_sparql_qa_chain": "langchain_aws.chains.graph_qa.neptune_sparql",
+}
+
+
+def __getattr__(name: str) -> Any:
+    if name in _module_lookup:
+        module = importlib.import_module(_module_lookup[name])
+        return getattr(module, name)
+    raise AttributeError(f"module {__name__} has no attribute {name}")
\ No newline at end of file
diff --git a/libs/aws/langchain_aws/chains/graph_qa/__init__.py b/libs/aws/langchain_aws/chains/graph_qa/__init__.py
new file mode 100644
index 00000000..cd327ecb
--- /dev/null
+++ b/libs/aws/langchain_aws/chains/graph_qa/__init__.py
@@ -0,0 +1,7 @@
+from .neptune_cypher import create_neptune_opencypher_qa_chain
+from .neptune_sparql import create_neptune_sparql_qa_chain
+
+__all__ = [
+    "create_neptune_opencypher_qa_chain",
+    "create_neptune_sparql_qa_chain"
+]
\ No newline at end of file
diff --git a/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py b/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
new file mode 100644
index 00000000..054b5916
--- /dev/null
+++ b/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
@@ -0,0 +1,187 @@
+from __future__ import annotations
+
+import re
+from typing import Any, Optional
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts.base import BasePromptTemplate
+from langchain_core.runnables import Runnable, RunnablePassthrough
+from .prompts import (
+    CYPHER_QA_PROMPT,
+    NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
+    NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
+)
+from langchain_aws.graphs import BaseNeptuneGraph
+
+INTERMEDIATE_STEPS_KEY = "intermediate_steps"
+
+
+def trim_query(query: str) -> str:
+    """Trim the query to only include Cypher keywords."""
+    keywords = (
+        "CALL",
+        "CREATE",
+        "DELETE",
+        "DETACH",
+        "LIMIT",
+        "MATCH",
+        "MERGE",
+        "OPTIONAL",
+        "ORDER",
+        "REMOVE",
+        "RETURN",
+        "SET",
+        "SKIP",
+        "UNWIND",
+        "WITH",
+        "WHERE",
+        "//",
+    )
+
+    lines = query.split("\n")
+    new_query = ""
+
+    for line in lines:
+        if line.strip().upper().startswith(keywords):
+            new_query += line + "\n"
+
+    return new_query
+
+
+def extract_cypher(text: str) -> str:
+    """Extract Cypher code from text using Regex."""
+    # The pattern to find Cypher code enclosed in triple backticks
+    pattern = r"```(.*?)```"
+
+    # Find all matches in the input text
+    matches = re.findall(pattern, text, re.DOTALL)
+
+    return matches[0] if matches else text
+
+
+def use_simple_prompt(llm: BaseLanguageModel) -> bool:
+    """Decides whether to use the simple prompt"""
+    if llm._llm_type and "anthropic" in llm._llm_type:  # type: ignore
+        return True
+
+    # Bedrock anthropic
+    if hasattr(llm, "model_id") and "anthropic" in llm.model_id:  # type: ignore
+        return True
+
+    return False
+
+
+def get_prompt(llm: BaseLanguageModel) -> BasePromptTemplate:
+    """Selects the final prompt"""
+    if use_simple_prompt(llm):
+        return NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT
+    else:
+        return NEPTUNE_OPENCYPHER_GENERATION_PROMPT
+
+
+def create_neptune_opencypher_qa_chain(
+    llm: BaseLanguageModel,
+    graph: BaseNeptuneGraph,
+    qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
+    cypher_prompt: Optional[BasePromptTemplate] = None,
+    return_intermediate_steps: bool = False,
+    return_direct: bool = False,
+    extra_instructions: Optional[str] = None,
+    allow_dangerous_requests: bool = False
+) -> Runnable[dict[str, Any], dict]:
+    """Chain for question-answering against a Neptune graph
+    by generating openCypher statements.
+
+    *Security note*: Make sure that the database connection uses credentials
+    that are narrowly-scoped to only include necessary permissions.
+    Failure to do so may result in data corruption or loss, since the calling
+    code may attempt commands that would result in deletion, mutation
+    of data if appropriately prompted or reading sensitive data if such
+    data is present in the database.
+    The best way to guard against such negative outcomes is to (as appropriate)
+    limit the permissions granted to the credentials used with this tool.
+
+    See https://python.langchain.com/docs/security for more information.
+
+    Example:
+        .. code-block:: python
+
+        chain = create_neptune_opencypher_qa_chain(
+            llm=llm,
+            graph=graph
+        )
+        response = chain.invoke({"query": "your_query_here"})
+    """
+
+    if allow_dangerous_requests is not True:
+        raise ValueError(
+            "In order to use this chain, you must acknowledge that it can make "
+            "dangerous requests by setting `allow_dangerous_requests` to `True`. "
+            "You must narrowly scope the permissions of the database connection "
+            "to only include necessary permissions. Failure to do so may result "
+            "in data corruption or loss or reading sensitive data if such data is "
+            "present in the database. "
+            "Only use this chain if you understand the risks and have taken the "
+            "necessary precautions. "
+            "See https://python.langchain.com/docs/security for more information."
+        )
+
+    qa_chain = qa_prompt | llm
+
+    _cypher_prompt = cypher_prompt or get_prompt(llm)
+    cypher_generation_chain = _cypher_prompt | llm
+
+    def execute_graph_query(cypher_query: str) -> dict:
+        return graph.query(cypher_query)
+
+    def get_cypher_inputs(inputs: dict) -> dict:
+        return {
+            "question": inputs["query"],
+            "schema": graph.get_schema,
+            "extra_instructions": extra_instructions or "",
+        }
+
+    def get_qa_inputs(inputs: dict) -> dict:
+        return {
+            "question": inputs["query"],
+            "context": inputs["context"],
+        }
+
+    def format_response(inputs: dict) -> dict:
+        intermediate_steps = [
+            {"query": inputs["cypher"]}
+        ]
+
+        if return_direct:
+            final_response = {"result": inputs["context"]}
+        else:
+            final_response = {"result": inputs["qa_result"]}
+            intermediate_steps.append({"context": inputs["context"]})
+
+        if return_intermediate_steps:
+            final_response[INTERMEDIATE_STEPS_KEY] = intermediate_steps
+
+        return final_response
+
+    chain_result = (
+        RunnablePassthrough.assign(
+            cypher_generation_inputs=get_cypher_inputs
+        )
+        | {
+            "query": lambda x: x["query"],
+            "cypher": (lambda x: x["cypher_generation_inputs"])
+            | cypher_generation_chain
+            | (lambda x: extract_cypher(x.content))
+            | trim_query
+        }
+        | RunnablePassthrough.assign(
+            context=lambda x: execute_graph_query(x["cypher"])
+        )
+        | RunnablePassthrough.assign(
+            qa_result=(lambda x: get_qa_inputs(x))
+            | qa_chain
+        )
+        | format_response
+    )
+
+    return chain_result
diff --git a/libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py b/libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py
new file mode 100644
index 00000000..618c0bb6
--- /dev/null
+++ b/libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py
@@ -0,0 +1,161 @@
+"""
+Question answering over an RDF or OWL graph using SPARQL.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts.base import BasePromptTemplate
+from langchain_core.prompts.prompt import PromptTemplate
+from langchain_core.runnables import Runnable, RunnablePassthrough
+from .prompts import (
+    SPARQL_QA_PROMPT,
+    NEPTUNE_SPARQL_GENERATION_TEMPLATE,
+    NEPTUNE_SPARQL_GENERATION_PROMPT,
+)
+from langchain_aws.graphs import NeptuneRdfGraph
+
+INTERMEDIATE_STEPS_KEY = "intermediate_steps"
+
+
+def extract_sparql(query: str) -> str:
+    """Extract SPARQL code from a text.
+
+    Args:
+        query: Text to extract SPARQL code from.
+
+    Returns:
+        SPARQL code extracted from the text.
+ """ + query = query.strip() + querytoks = query.split("```") + if len(querytoks) == 3: + query = querytoks[1] + + if query.startswith("sparql"): + query = query[6:] + elif query.startswith("") and query.endswith(""): + query = query[8:-9] + return query + + +def get_prompt(examples: str) -> BasePromptTemplate: + """Selects the final prompt.""" + template_to_use = NEPTUNE_SPARQL_GENERATION_TEMPLATE + if examples: + template_to_use = template_to_use.replace( + "Examples:", "Examples: " + examples + ) + return PromptTemplate( + input_variables=["schema", "prompt"], template=template_to_use + ) + return NEPTUNE_SPARQL_GENERATION_PROMPT + + +def create_neptune_sparql_qa_chain( + llm: BaseLanguageModel, + graph: NeptuneRdfGraph, + qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT, + sparql_prompt: Optional[BasePromptTemplate] = None, + return_intermediate_steps: bool = False, + return_direct: bool = False, + extra_instructions: Optional[str] = None, + allow_dangerous_requests: bool = False, + examples: Optional[str] = None, +) -> Runnable[dict[str, Any], dict]: + """Chain for question-answering against a Neptune graph + by generating SPARQL statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + + Example: + .. code-block:: python + + chain = create_neptune_sparql_qa_chain( + llm=llm, + graph=graph + ) + response = chain.invoke({"query": "your_query_here"}) + """ + if allow_dangerous_requests is not True: + raise ValueError( + "In order to use this chain, you must acknowledge that it can make " + "dangerous requests by setting `allow_dangerous_requests` to `True`. " + "You must narrowly scope the permissions of the database connection " + "to only include necessary permissions. Failure to do so may result " + "in data corruption or loss or reading sensitive data if such data is " + "present in the database. " + "Only use this chain if you understand the risks and have taken the " + "necessary precautions. " + "See https://python.langchain.com/docs/security for more information." 
+        )
+
+    qa_chain = qa_prompt | llm
+
+    _sparql_prompt = sparql_prompt or get_prompt(examples)
+    sparql_generation_chain = _sparql_prompt | llm
+
+    def execute_graph_query(sparql_query: str) -> dict:
+        return graph.query(sparql_query)
+
+    def get_sparql_inputs(inputs: dict) -> dict:
+        return {
+            "prompt": inputs["query"],
+            "schema": graph.get_schema,
+            "extra_instructions": extra_instructions or "",
+        }
+
+    def get_qa_inputs(inputs: dict) -> dict:
+        return {
+            "prompt": inputs["query"],
+            "context": inputs["context"],
+        }
+
+    def format_response(inputs: dict) -> dict:
+        intermediate_steps = [
+            {"query": inputs["sparql"]}
+        ]
+
+        if return_direct:
+            final_response = {"result": inputs["context"]}
+        else:
+            final_response = {"result": inputs["qa_result"]}
+            intermediate_steps.append({"context": inputs["context"]})
+
+        if return_intermediate_steps:
+            final_response[INTERMEDIATE_STEPS_KEY] = intermediate_steps
+
+        return final_response
+
+    chain_result = (
+        RunnablePassthrough.assign(
+            sparql_generation_inputs=get_sparql_inputs
+        )
+        | {
+            "query": lambda x: x["query"],
+            "sparql": (lambda x: x["sparql_generation_inputs"])
+            | sparql_generation_chain
+            | (lambda x: extract_sparql(x.content))
+        }
+        | RunnablePassthrough.assign(
+            context=lambda x: execute_graph_query(x["sparql"])
+        )
+        | RunnablePassthrough.assign(
+            qa_result=(lambda x: get_qa_inputs(x))
+            | qa_chain
+        )
+        | format_response
+    )
+
+    return chain_result
diff --git a/libs/aws/langchain_aws/chains/graph_qa/prompts.py b/libs/aws/langchain_aws/chains/graph_qa/prompts.py
new file mode 100644
index 00000000..e1bfdfde
--- /dev/null
+++ b/libs/aws/langchain_aws/chains/graph_qa/prompts.py
@@ -0,0 +1,116 @@
+# flake8: noqa
+from langchain_core.prompts.prompt import PromptTemplate
+
+CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.
+Instructions:
+Use only the provided relationship types and properties in the schema.
+Do not use any other relationship types or properties that are not provided.
+Schema:
+{schema}
+Note: Do not include any explanations or apologies in your responses.
+Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
+Do not include any text except the generated Cypher statement.
+
+The question is:
+{question}"""
+CYPHER_GENERATION_PROMPT = PromptTemplate(
+    input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
+)
+
+CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
+The information part contains the provided information that you must use to construct an answer.
+The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
+Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
+Here is an example:
+
+Question: Which managers own Neo4j stocks?
+Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
+Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
+
+Follow this example when generating answers.
+If the provided information is empty, say that you don't know the answer.
+Information:
+{context}
+
+Question: {question}
+Helpful Answer:"""
+CYPHER_QA_PROMPT = PromptTemplate(
+    input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
+)
+
+SPARQL_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
+You are an assistant that creates well-written and human understandable answers.
+The information part contains the information provided, which you can use to construct an answer.
+The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
+Make your response sound like the information is coming from an AI assistant, but don't add any information.
+Information:
+{context}
+
+Question: {prompt}
+Helpful Answer:"""
+SPARQL_QA_PROMPT = PromptTemplate(
+    input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE
+)
+
+NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """
+Instructions:
+Generate the query in openCypher format and follow these rules:
+Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions.
+Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results.
+Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions}
+\n"""
+
+NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
+    "Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS
+)
+
+NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
+    input_variables=["schema", "question", "extra_instructions"],
+    template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
+)
+
+NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """
+Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions}
+Question: "{question}".
+Here is the property graph schema:
+{schema}
+\n"""
+
+NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
+    input_variables=["schema", "question", "extra_instructions"],
+    template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
+)
+
+NEPTUNE_SPARQL_GENERATION_TEMPLATE = """
+Task: Generate a SPARQL SELECT statement for querying a graph database.
+For instance, to find all email addresses of John Doe, the following
+query in backticks would be suitable:
+```
+PREFIX foaf: <http://xmlns.com/foaf/0.1/>
+SELECT ?email
+WHERE {{
+    ?person foaf:name "John Doe" .
+    ?person foaf:mbox ?email .
+}}
+```
+Instructions:
+Use only the node types and properties provided in the schema.
+Do not use any node types and properties that are not explicitly provided.
+Include all necessary prefixes.
+
+Examples:
+
+Schema:
+{schema}
+Note: Be as concise as possible.
+Do not include any explanations or apologies in your responses.
+Do not respond to any questions that ask for anything else than
+for you to construct a SPARQL query.
+Do not include any text except the SPARQL query generated.
+
+The question is:
+{prompt}"""
+
+NEPTUNE_SPARQL_GENERATION_PROMPT = PromptTemplate(
+    input_variables=["schema", "prompt"], template=NEPTUNE_SPARQL_GENERATION_TEMPLATE
+)
diff --git a/libs/aws/langchain_aws/graphs/__init__.py b/libs/aws/langchain_aws/graphs/__init__.py
index 1aa136a4..174ba471 100644
--- a/libs/aws/langchain_aws/graphs/__init__.py
+++ b/libs/aws/langchain_aws/graphs/__init__.py
@@ -3,5 +3,6 @@
     NeptuneAnalyticsGraph,
     NeptuneGraph,
 )
+from langchain_aws.graphs.neptune_rdf_graph import NeptuneRdfGraph
 
-__all__ = ["BaseNeptuneGraph", "NeptuneAnalyticsGraph", "NeptuneGraph"]
+__all__ = ["BaseNeptuneGraph", "NeptuneAnalyticsGraph", "NeptuneGraph", "NeptuneRdfGraph"]
diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml
index 34a9f692..6dc2deb4 100644
--- a/libs/aws/pyproject.toml
+++ b/libs/aws/pyproject.toml
@@ -13,6 +13,7 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
 langchain-core = "^0.3.27"
+langchain = ">=0.3.11,<0.4"
 boto3 = ">=1.35.74"
 pydantic = ">=2,<3"
 numpy = [
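
Usage note (not part of the patch): a minimal end-to-end sketch of the two new factory functions introduced by this diff. The cluster endpoint and Bedrock model ID below are hypothetical placeholders, and the snippet assumes AWS credentials with Neptune and Bedrock permissions are already configured in the environment.

```python
from langchain_aws import (
    ChatBedrockConverse,
    NeptuneGraph,
    create_neptune_opencypher_qa_chain,
    create_neptune_sparql_qa_chain,
)
from langchain_aws.graphs import NeptuneRdfGraph

# Placeholder model ID and endpoint; substitute your own.
llm = ChatBedrockConverse(model="anthropic.claude-3-sonnet-20240229-v1:0")
host = "my-cluster.cluster-abc123.us-east-1.neptune.amazonaws.com"

# openCypher QA over a Neptune property graph.
cypher_chain = create_neptune_opencypher_qa_chain(
    llm=llm,
    graph=NeptuneGraph(host=host, port=8182),
    # Explicit opt-in: the chain executes model-generated queries.
    allow_dangerous_requests=True,
)
print(cypher_chain.invoke({"query": "How many nodes are in the graph?"})["result"])

# SPARQL QA over a Neptune RDF graph, also surfacing the generated query.
sparql_chain = create_neptune_sparql_qa_chain(
    llm=llm,
    graph=NeptuneRdfGraph(host=host, port=8182),
    allow_dangerous_requests=True,
    return_intermediate_steps=True,
)
result = sparql_chain.invoke({"query": "Which classes appear in the graph?"})
print(result["result"])
print(result["intermediate_steps"])  # [{"query": <generated SPARQL>}, {"context": ...}]
```

Both factories return plain LCEL runnables, so the standard `Runnable` methods such as `.invoke()` and `.batch()` apply; `allow_dangerous_requests=True` is mandatory, mirroring the security check in the diff above.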