From eaa1cf2b6ff45c0a0177ccbef345b27bf6e2c46c Mon Sep 17 00:00:00 2001 From: Varun Joshi Date: Thu, 4 Jul 2024 17:22:46 -0400 Subject: [PATCH] Patronus Lynx Integration --- README.md | 4 +- .../advanced/patronus-lynx-deployment.md | 80 ++++++ docs/user_guides/guardrails-library.md | 71 +++++ docs/user_guides/llm-support.md | 1 + examples/configs/patronusai/config.yml | 15 + examples/configs/patronusai/prompts.yml | 32 +++ nemoguardrails/library/patronusai/__init__.py | 14 + nemoguardrails/library/patronusai/actions.py | 110 +++++++ nemoguardrails/library/patronusai/flows.co | 12 + .../library/patronusai/requirements.txt | 2 + nemoguardrails/llm/types.py | 3 + nemoguardrails/rails/llm/config.py | 7 + tests/test_patronus_lynx.py | 269 ++++++++++++++++++ 13 files changed, 618 insertions(+), 2 deletions(-) create mode 100644 docs/user_guides/advanced/patronus-lynx-deployment.md create mode 100644 examples/configs/patronusai/config.yml create mode 100644 examples/configs/patronusai/prompts.yml create mode 100644 nemoguardrails/library/patronusai/__init__.py create mode 100644 nemoguardrails/library/patronusai/actions.py create mode 100644 nemoguardrails/library/patronusai/flows.co create mode 100644 nemoguardrails/library/patronusai/requirements.txt create mode 100644 tests/test_patronus_lynx.py diff --git a/README.md b/README.md index a069c8dd3..ec596f779 100644 --- a/README.md +++ b/README.md @@ -220,7 +220,7 @@ NeMo Guardrails comes with a set of [built-in guardrails](https://docs.nvidia.co > **NOTE**: The built-in guardrails are only intended to enable you to get started quickly with NeMo Guardrails. For production use cases, further development and testing of the rails are needed. -Currently, the guardrails library includes guardrails for: [jailbreak detection](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#jailbreak-detection-heuristics), [output moderation](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#self-check-output), [fact-checking](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#fact-checking), [sensitive data detection](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#presidio-based-sensitive-data-detection), [hallucination detection](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#hallucination-detection) and [input moderation using ActiveFence](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#activefence) and [hallucination detection for RAG applications using Got It AI's TruthChecker API](docs/user_guides/guardrails-library.md#got-it-ai). 
+Currently, the guardrails library includes guardrails for: [jailbreak detection](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#jailbreak-detection-heuristics), [output moderation](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#self-check-output), [fact-checking](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#fact-checking), [sensitive data detection](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#presidio-based-sensitive-data-detection), [hallucination detection](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#hallucination-detection), [input moderation using ActiveFence](https://docs.nvidia.com/nemo/guardrails/user_guides/guardrails-library.html#activefence), [hallucination detection for RAG applications using Got It AI's TruthChecker API](docs/user_guides/guardrails-library.md#got-it-ai), and [RAG hallucination detection using Patronus Lynx](docs/user_guides/guardrails-library.md#patronus-lynx-based-rag-hallucination-detection).

 ## CLI

@@ -283,7 +283,7 @@ Evaluating the safety of a LLM-based conversational application is a complex tas

 ## How is this different?

-There are many ways guardrails can be added to an LLM-based conversational application. For example: explicit moderation endpoints (e.g., OpenAI, ActiveFence), critique chains (e.g. constitutional chain), parsing the output (e.g. guardrails.ai), individual guardrails (e.g., LLM-Guard), hallucination detection for RAG applications (e.g., Got It AI).
+There are many ways guardrails can be added to an LLM-based conversational application. For example: explicit moderation endpoints (e.g., OpenAI, ActiveFence), critique chains (e.g. constitutional chain), parsing the output (e.g. guardrails.ai), individual guardrails (e.g., LLM-Guard), hallucination detection for RAG applications (e.g., Got It AI, Patronus Lynx).

 NeMo Guardrails aims to provide a flexible toolkit that can integrate all these complementary approaches into a cohesive LLM guardrails layer. For example, the toolkit provides out-of-the-box integration with ActiveFence, AlignScore and LangChain chains.

diff --git a/docs/user_guides/advanced/patronus-lynx-deployment.md b/docs/user_guides/advanced/patronus-lynx-deployment.md
new file mode 100644
index 000000000..9563cc239
--- /dev/null
+++ b/docs/user_guides/advanced/patronus-lynx-deployment.md
@@ -0,0 +1,80 @@
+# Host Patronus Lynx
+
+## vLLM
+
+Lynx is fully open source, so you can host it however you like. One simple way is using vLLM.
+
+1. Get access to Patronus Lynx on Hugging Face. See [here](https://huggingface.co/PatronusAI/Patronus-Lynx-70B-Instruct) for the 70B-parameter variant and [here](https://huggingface.co/PatronusAI/Patronus-Lynx-8B-Instruct) for the 8B-parameter variant. The examples below use the `70B` model, but no additional configuration is needed to deploy the smaller model, so you can simply swap the model name references to `8B`.
+
+2. Log in to Hugging Face:
+
+```bash
+huggingface-cli login
+```
+
+3. Install vLLM and spin up a server hosting Patronus Lynx:
+
+```bash
+pip install vllm
+python -m vllm.entrypoints.openai.api_server --port 5000 --model PatronusAI/Patronus-Lynx-70B-Instruct
+```
+
+This will launch the vLLM inference server on `http://localhost:5000/`.
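+
+One quick way to confirm the server is up is to query it from Python with the OpenAI client. The snippet below is a minimal sketch: it assumes the `openai` package is installed (`pip install openai`) and the server command above; the `api_key` value is just a placeholder, since the local vLLM server does not require one by default.
+
+```python
+from openai import OpenAI
+
+# Point the OpenAI client at the local vLLM server started above.
+client = OpenAI(base_url="http://localhost:5000/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+    model="PatronusAI/Patronus-Lynx-70B-Instruct",
+    messages=[{"role": "user", "content": "What is a hallucination?"}],
+)
+print(response.choices[0].message.content)
+```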
+
+You can also use the OpenAI API spec to send it a cURL request to make sure it works:
+
+```bash
+curl http://localhost:5000/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+    "model": "PatronusAI/Patronus-Lynx-70B-Instruct",
+    "messages": [
+        {"role": "user", "content": "What is a hallucination?"}
+    ]
+}'
+```
+
+4. Create a model called `patronus_lynx` in your `config.yml` file, setting the host and port to the values used above. If vLLM is running on a different server from `nemoguardrails`, replace `localhost` with the vLLM server's address. Check out the guide [here](../guardrails-library.md#patronus-lynx-based-rag-hallucination-detection) for more information.
+
+## Ollama
+
+You can also run Patronus Lynx 8B on your personal computer using Ollama!
+
+1. Install Ollama: https://ollama.com/download.
+
+2. Get access to a GGUF quantized version of Lynx 8B on Hugging Face. Check it out [here](https://huggingface.co/PatronusAI/Lynx-8B-Instruct-Q4_K_M-GGUF).
+
+3. Download the GGUF model from the repository [here](https://huggingface.co/PatronusAI/Lynx-8B-Instruct-Q4_K_M-GGUF/blob/main/patronus-lynx-8b-instruct-q4_k_m.gguf). This may take a few minutes.
+
+4. Create a file called `Modelfile` with the following contents:
+
+```bash
+FROM "./patronus-lynx-8b-instruct-q4_k_m.gguf"
+PARAMETER stop "<|im_start|>"
+PARAMETER stop "<|im_end|>"
+TEMPLATE """
+<|im_start|>system
+{{ .System }}<|im_end|>
+<|im_start|>user
+{{ .Prompt }}<|im_end|>
+<|im_start|>assistant
+"""
+```
+
+Ensure that the `FROM` field correctly points to the `patronus-lynx-8b-instruct-q4_k_m.gguf` file you downloaded in Step 3.
+
+5. Run `ollama create patronus-lynx-8b -f Modelfile`.
+
+6. Run `ollama run patronus-lynx-8b`. You should now be able to chat with `patronus-lynx-8b`!
+
+7. Create a model called `patronus_lynx` in your `config.yml` file, like this:
+
+```yaml
+models:
+  ...
+
+  - type: patronus_lynx
+    engine: ollama
+    model: patronus-lynx-8b
+    parameters:
+      base_url: "http://localhost:11434"
+```
+
+Check out the guide [here](../guardrails-library.md#patronus-lynx-based-rag-hallucination-detection) for more information.
diff --git a/docs/user_guides/guardrails-library.md b/docs/user_guides/guardrails-library.md
index 35cdc2691..513e25b2c 100644
--- a/docs/user_guides/guardrails-library.md
+++ b/docs/user_guides/guardrails-library.md
@@ -12,6 +12,7 @@ NeMo Guardrails comes with a library of built-in guardrails that you can easily
   - [AlignScore-based Fact Checking](#alignscore-based-fact-checking)
   - [LlamaGuard-based Content Moderation](#llama-guard-based-content-moderation)
   - [Presidio-based Sensitive data detection](#presidio-based-sensitive-data-detection)
+  - [Patronus Lynx-based RAG Hallucination Detection](#patronus-lynx-based-rag-hallucination-detection)
   - BERT-score Hallucination Checking - *[COMING SOON]*

 3. Third-Party APIs

@@ -638,6 +639,76 @@ rails:

 If you want to implement a completely different sensitive data detection mechanism, you can override the default actions [`detect_sensitive_data`](https://github.com/NVIDIA/NeMo-Guardrails/tree/develop/nemoguardrails/library/sensitive_data_detection/actions.py) and [`mask_sensitive_data`](https://github.com/NVIDIA/NeMo-Guardrails/tree/develop/nemoguardrails/library/sensitive_data_detection/actions.py).

+### Patronus Lynx-based RAG Hallucination Detection
+
+NeMo Guardrails supports hallucination detection in RAG systems using [Patronus AI](https://www.patronus.ai)'s Lynx model.
+The model is hosted on Hugging Face and comes in both a 70B-parameter variant (see [here](https://huggingface.co/PatronusAI/Patronus-Lynx-70B-Instruct)) and an 8B-parameter variant (see [here](https://huggingface.co/PatronusAI/Patronus-Lynx-8B-Instruct)).
+
+Lynx checks that the bot response satisfies three conditions:
+
+- Information in the `bot_message` is contained in the `relevant_chunks`
+- There is no extra information in the `bot_message` that is not in the `relevant_chunks`
+- The `bot_message` does not contradict any information in the `relevant_chunks`
+
+#### Setup
+
+Since Patronus Lynx is fully open source, you can host it however you like. You can find a guide to hosting Lynx using vLLM [here](./advanced/patronus-lynx-deployment.md).
+
+#### Usage
+
+Here is how to configure your bot to use Patronus Lynx to check for RAG hallucinations in your bot output:
+
+1. Add a model of type `patronus_lynx` in `config.yml`. The example below uses vLLM to run Lynx:
+
+```yaml
+models:
+  ...
+
+  - type: patronus_lynx
+    engine: vllm_openai
+    parameters:
+      openai_api_base: "http://localhost:5000/v1"
+      model_name: "PatronusAI/Patronus-Lynx-70B-Instruct" # "PatronusAI/Patronus-Lynx-8B-Instruct"
+```
+
+2. Add the `patronus lynx check output hallucination` guardrail to your output rails in `config.yml`:
+
+```yaml
+rails:
+  output:
+    flows:
+      - patronus lynx check output hallucination
+```
+
+3. Add a prompt for `patronus_lynx_check_output_hallucination` in the `prompts.yml` file:
+
+```yaml
+prompts:
+  - task: patronus_lynx_check_output_hallucination
+    content: |
+      Given the following QUESTION, DOCUMENT and ANSWER you must analyze ...
+      ...
+```
+
+We recommend basing your Lynx hallucination detection prompt on the provided example [here](https://github.com/NVIDIA/NeMo-Guardrails/tree/develop/examples/configs/patronusai/prompts.yml).
+
+Under the hood, the `patronus lynx check output hallucination` rail runs the `patronus_lynx_check_output_hallucination` action, which you can find [here](https://github.com/NVIDIA/NeMo-Guardrails/tree/develop/nemoguardrails/library/patronusai/actions.py). It returns whether a hallucination is detected (`True` or `False`) and, when available, a reasoning trace explaining the decision. The bot's response is blocked if hallucination is `True`. Note: if Lynx's output cannot be parsed or no hallucination decision is found, the action defaults to returning `True` for hallucination.
+
+Here's the `patronus lynx check output hallucination` flow, showing how the action is executed:
+
+```colang
+define bot inform answer unknown
+    "I don't know the answer to that."
+ +define flow patronus lynx check output hallucination + $patronus_lynx_response = execute patronus_lynx_check_output_hallucination + $hallucination = $patronus_lynx_response["hallucination"] + # The Reasoning trace is currently unused, but can be used to modify the bot output + $reasoning = $patronus_lynx_response["reasoning"] + + if $hallucination + bot inform answer unknown + stop +``` ## Third-Party APIs diff --git a/docs/user_guides/llm-support.md b/docs/user_guides/llm-support.md index 52bf51c1b..f3d038693 100644 --- a/docs/user_guides/llm-support.md +++ b/docs/user_guides/llm-support.md @@ -35,6 +35,7 @@ If you want to use an LLM and you cannot see a prompt in the [prompts folder](ht | ActiveFence moderation _(LLM independent)_ | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | Llama Guard moderation _(LLM independent)_ | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | Got It AI RAG TruthChecker _(LLM independent)_ | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| Patronus Lynx RAG Hallucination detection _(LLM independent)_ | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | Table legend: - :heavy_check_mark: - Supported (_The feature is fully supported by the LLM based on our experiments and tests_) diff --git a/examples/configs/patronusai/config.yml b/examples/configs/patronusai/config.yml new file mode 100644 index 000000000..413d7340b --- /dev/null +++ b/examples/configs/patronusai/config.yml @@ -0,0 +1,15 @@ +models: + - type: main + engine: openai + model: gpt-3.5-turbo-instruct + + - type: patronus_lynx + engine: vllm_openai + parameters: + openai_api_base: "http://localhost:5000/v1" + model_name: "PatronusAI/Patronus-Lynx-70B-Instruct" # "PatronusAI/Patronus-Lynx-8B-Instruct" + +rails: + output: + flows: + - patronus lynx check output hallucination diff --git a/examples/configs/patronusai/prompts.yml b/examples/configs/patronusai/prompts.yml new file mode 100644 index 000000000..7d23467d7 --- /dev/null +++ b/examples/configs/patronusai/prompts.yml @@ -0,0 +1,32 @@ +prompts: + - task: patronus_lynx_check_output_hallucination + content: | + Given the following QUESTION, DOCUMENT and ANSWER you must analyze the provided answer and determine whether it is faithful to the contents of the DOCUMENT. + + The ANSWER must not offer new information beyond the context provided in the DOCUMENT. + + The ANSWER also must not contradict information provided in the DOCUMENT. + + Output your final score by strictly following this format: "PASS" if the answer is faithful to the DOCUMENT and "FAIL" if the answer is not faithful to the DOCUMENT. + + Show your reasoning. 
+ + -- + QUESTION (THIS DOES NOT COUNT AS BACKGROUND INFORMATION): + {{ user_input }} + + -- + DOCUMENT: + {{ provided_context }} + + -- + ANSWER: + {{ bot_response }} + + -- + + Your output should be in JSON FORMAT with the keys "REASONING" and "SCORE". + + Ensure that the JSON is valid and properly formatted. + + {"REASONING": [""], "SCORE": ""} diff --git a/nemoguardrails/library/patronusai/__init__.py b/nemoguardrails/library/patronusai/__init__.py new file mode 100644 index 000000000..9ba9d4310 --- /dev/null +++ b/nemoguardrails/library/patronusai/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemoguardrails/library/patronusai/actions.py b/nemoguardrails/library/patronusai/actions.py new file mode 100644 index 000000000..3171de73a --- /dev/null +++ b/nemoguardrails/library/patronusai/actions.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from typing import List, Optional, Tuple + +from langchain.llms.base import BaseLLM + +from nemoguardrails.actions import action +from nemoguardrails.actions.llm.utils import llm_call +from nemoguardrails.context import llm_call_info_var +from nemoguardrails.llm.params import llm_params +from nemoguardrails.llm.taskmanager import LLMTaskManager +from nemoguardrails.llm.types import Task +from nemoguardrails.logging.explain import LLMCallInfo + +log = logging.getLogger(__name__) + + +def parse_patronus_lynx_response( + response: str, +) -> Tuple[bool, List[str] | None]: + """ + Parses the response from the Patronus Lynx LLM and returns a tuple of: + - Whether the response is hallucinated or not. + - A reasoning trace explaining the decision. + """ + log.info(f"Patronus Lynx response: {response}.") + # Default to hallucinated + hallucination, reasoning = True, None + reasoning_pattern = r'"REASONING":\s*\[(.*?)\]' + score_pattern = r'"SCORE":\s*"?\b(PASS|FAIL)\b"?' 
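+    # Note: the model may emit the score with or without surrounding quotes
+    # (e.g. "SCORE": "PASS" or "SCORE": PASS), so the pattern accepts both forms.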
+ + reasoning_match = re.search(reasoning_pattern, response, re.DOTALL) + score_match = re.search(score_pattern, response) + + if score_match: + score = score_match.group(1) + if score == "PASS": + hallucination = False + if reasoning_match: + reasoning_content = reasoning_match.group(1) + reasoning = re.split(r"['\"],\s*['\"]", reasoning_content) + + return hallucination, reasoning + + +@action() +async def patronus_lynx_check_output_hallucination( + llm_task_manager: LLMTaskManager, + context: Optional[dict] = None, + patronus_lynx_llm: Optional[BaseLLM] = None, +) -> dict: + """ + Check the bot response for hallucinations based on the given chunks + using the configured Patronus Lynx model. + """ + user_input = context.get("user_message") + bot_response = context.get("bot_message") + provided_context = context.get("relevant_chunks") + + if ( + not provided_context + or not isinstance(provided_context, str) + or not provided_context.strip() + ): + log.error( + "Could not run Patronus Lynx. `relevant_chunks` must be passed as a non-empty string." + ) + return {"hallucination": False, "reasoning": None} + + check_output_hallucination_prompt = llm_task_manager.render_task_prompt( + task=Task.PATRONUS_LYNX_CHECK_OUTPUT_HALLUCINATION, + context={ + "user_input": user_input, + "bot_response": bot_response, + "provided_context": provided_context, + }, + ) + + stop = llm_task_manager.get_stop_tokens( + task=Task.PATRONUS_LYNX_CHECK_OUTPUT_HALLUCINATION + ) + + # Initialize the LLMCallInfo object + llm_call_info_var.set( + LLMCallInfo(task=Task.PATRONUS_LYNX_CHECK_OUTPUT_HALLUCINATION.value) + ) + + with llm_params(patronus_lynx_llm, temperature=0.0): + result = await llm_call( + patronus_lynx_llm, check_output_hallucination_prompt, stop=stop + ) + + hallucination, reasoning = parse_patronus_lynx_response(result) + print(f"Hallucination: {hallucination}, Reasoning: {reasoning}") + return {"hallucination": hallucination, "reasoning": reasoning} diff --git a/nemoguardrails/library/patronusai/flows.co b/nemoguardrails/library/patronusai/flows.co new file mode 100644 index 000000000..4903fa607 --- /dev/null +++ b/nemoguardrails/library/patronusai/flows.co @@ -0,0 +1,12 @@ +define bot inform answer unknown + "I don't know the answer to that." + +define flow patronus lynx check output hallucination + $patronus_lynx_response = execute patronus_lynx_check_output_hallucination + $hallucination = $patronus_lynx_response["hallucination"] + # The Reasoning trace is currently unused, but can be used to modify the bot output + $reasoning = $patronus_lynx_response["reasoning"] + + if $hallucination + bot inform answer unknown + stop diff --git a/nemoguardrails/library/patronusai/requirements.txt b/nemoguardrails/library/patronusai/requirements.txt new file mode 100644 index 000000000..b6ba7d750 --- /dev/null +++ b/nemoguardrails/library/patronusai/requirements.txt @@ -0,0 +1,2 @@ +# The minimal set of requirements to run Patronus Lynx on vLLM. 
+vllm==0.2.7 diff --git a/nemoguardrails/llm/types.py b/nemoguardrails/llm/types.py index ec07ee4b4..e34fbfeb6 100644 --- a/nemoguardrails/llm/types.py +++ b/nemoguardrails/llm/types.py @@ -41,6 +41,9 @@ class Task(Enum): SELF_CHECK_OUTPUT = "self_check_output" LLAMA_GUARD_CHECK_INPUT = "llama_guard_check_input" LLAMA_GUARD_CHECK_OUTPUT = "llama_guard_check_output" + PATRONUS_LYNX_CHECK_OUTPUT_HALLUCINATION = ( + "patronus_lynx_check_output_hallucination" + ) SELF_CHECK_FACTS = "fact_checking" CHECK_HALLUCINATION = "check_hallucination" diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index a056868b0..77c236fca 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -818,6 +818,13 @@ def check_prompt_exist_for_self_check_rails(cls, values): raise ValueError( "You must provide a `llama_guard_check_output` prompt template." ) + if ( + "patronus lynx check output hallucination" in enabled_output_rails + and "patronus_lynx_check_output_hallucination" not in provided_task_prompts + ): + raise ValueError( + "You must provide a `patronus_lynx_check_output_hallucination` prompt template." + ) if ( "self check facts" in enabled_output_rails diff --git a/tests/test_patronus_lynx.py b/tests/test_patronus_lynx.py new file mode 100644 index 000000000..9fccfdbd7 --- /dev/null +++ b/tests/test_patronus_lynx.py @@ -0,0 +1,269 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails.actions.actions import ActionResult, action +from tests.utils import FakeLLM, TestChat + +COLANG_CONFIG = """ +define user express greeting + "hi" +define bot refuse to respond + "I'm sorry, I can't respond to that." +""" + +YAML_CONFIG = """ +models: + - type: main + engine: openai + model: gpt-3.5-turbo-instruct + - type: patronus_lynx + engine: vllm_openai + parameters: + openai_api_base: "http://localhost:5000/v1" + model_name: "PatronusAI/Patronus-Lynx-70B-Instruct" +rails: + output: + flows: + - patronus lynx check output hallucination +prompts: + - task: patronus_lynx_check_output_hallucination + content: | + Given the following QUESTION, DOCUMENT and ANSWER you must analyze the provided answer and determine whether it is faithful to the contents of the DOCUMENT. + + The ANSWER must not offer new information beyond the context provided in the DOCUMENT. + + The ANSWER also must not contradict information provided in the DOCUMENT. + + Output your final score by strictly following this format: "PASS" if the answer is faithful to the DOCUMENT and "FAIL" if the answer is not faithful to the DOCUMENT. + + Show your reasoning. 
+
+      --
+      QUESTION (THIS DOES NOT COUNT AS BACKGROUND INFORMATION):
+      {{ user_input }}
+
+      --
+      DOCUMENT:
+      {{ provided_context }}
+
+      --
+      ANSWER:
+      {{ bot_response }}
+
+      --
+
+      Your output should be in JSON FORMAT with the keys "REASONING" and "SCORE".
+
+      Ensure that the JSON is valid and properly formatted.
+
+      {"REASONING": [""], "SCORE": ""}
+"""
+
+
+@action()
+def retrieve_relevant_chunks():
+    context_updates = {"relevant_chunks": "Mock retrieved context."}
+
+    return ActionResult(
+        return_value=context_updates["relevant_chunks"],
+        context_updates=context_updates,
+    )
+
+
+@pytest.mark.asyncio
+def test_patronus_lynx_returns_no_hallucination():
+    """
+    Test that the chat flow completes successfully when
+    Patronus Lynx returns "PASS" for the hallucination check
+    """
+    config = RailsConfig.from_content(
+        colang_content=COLANG_CONFIG, yaml_content=YAML_CONFIG
+    )
+    chat = TestChat(
+        config,
+        llm_completions=[
+            "Mock generated user intent", # mock response for the generate_user_intent action
+            "Mock generated next step", # mock response for the generate_next_step action
+            " Hi there! How are you doing?", # mock response for the generate_bot_message action
+        ],
+    )
+
+    chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks")
+
+    patronus_lynx_llm = FakeLLM(
+        responses=[
+            '{"REASONING": ["There is no hallucination."], "SCORE": "PASS"}',
+        ]
+    )
+    chat.app.register_action_param("patronus_lynx_llm", patronus_lynx_llm)
+
+    chat >> "Hi"
+    chat << "Hi there! How are you doing?"
+
+
+@pytest.mark.asyncio
+def test_patronus_lynx_returns_hallucination():
+    """
+    Test that the bot output is successfully guarded against when
+    Patronus Lynx returns "FAIL" for the hallucination check
+    """
+    config = RailsConfig.from_content(
+        colang_content=COLANG_CONFIG, yaml_content=YAML_CONFIG
+    )
+    chat = TestChat(
+        config,
+        llm_completions=[
+            "Mock generated user intent", # mock response for the generate_user_intent action
+            "Mock generated next step", # mock response for the generate_next_step action
+            " Hi there! How are you doing?", # mock response for the generate_bot_message action
+        ],
+    )
+
+    chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks")
+
+    patronus_lynx_llm = FakeLLM(
+        responses=[
+            '{"REASONING": ["There is a hallucination."], "SCORE": "FAIL"}',
+        ]
+    )
+    chat.app.register_action_param("patronus_lynx_llm", patronus_lynx_llm)
+
+    chat >> "Hi"
+    chat << "I don't know the answer to that."
+
+
+@pytest.mark.asyncio
+def test_patronus_lynx_parses_score_when_no_double_quote():
+    """
+    Test that the chat flow completes successfully when
+    Patronus Lynx returns a "PASS" score without surrounding double quotes
+    """
+    config = RailsConfig.from_content(
+        colang_content=COLANG_CONFIG, yaml_content=YAML_CONFIG
+    )
+    chat = TestChat(
+        config,
+        llm_completions=[
+            "Mock generated user intent", # mock response for the generate_user_intent action
+            "Mock generated next step", # mock response for the generate_next_step action
+            " Hi there! How are you doing?", # mock response for the generate_bot_message action
+        ],
+    )
+
+    chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks")
+
+    patronus_lynx_llm = FakeLLM(
+        responses=[
+            '{"REASONING": ["There is no hallucination."], "SCORE": PASS}',
+        ]
+    )
+    chat.app.register_action_param("patronus_lynx_llm", patronus_lynx_llm)
+
+    chat >> "Hi"
+    chat << "Hi there! How are you doing?"
+
+
+@pytest.mark.asyncio
+def test_patronus_lynx_returns_no_hallucination_when_no_retrieved_context():
+    """
+    Test that Patronus Lynx does not block the bot output
+    when no relevant context is given
+    """
+    config = RailsConfig.from_content(
+        colang_content=COLANG_CONFIG, yaml_content=YAML_CONFIG
+    )
+    chat = TestChat(
+        config,
+        llm_completions=[
+            "Mock generated user intent", # mock response for the generate_user_intent action
+            "Mock generated next step", # mock response for the generate_next_step action
+            " Hi there! How are you doing?", # mock response for the generate_bot_message action
+        ],
+    )
+
+    patronus_lynx_llm = FakeLLM(
+        responses=[
+            '{"REASONING": ["There is a hallucination."], "SCORE": "FAIL"}',
+        ]
+    )
+    chat.app.register_action_param("patronus_lynx_llm", patronus_lynx_llm)
+
+    chat >> "Hi"
+    chat << "Hi there! How are you doing?"
+
+
+@pytest.mark.asyncio
+def test_patronus_lynx_returns_hallucination_when_no_score_in_llm_output():
+    """
+    Test that Patronus Lynx defaults to blocking the bot output
+    when no score is returned in its response.
+    """
+    config = RailsConfig.from_content(
+        colang_content=COLANG_CONFIG, yaml_content=YAML_CONFIG
+    )
+    chat = TestChat(
+        config,
+        llm_completions=[
+            "Mock generated user intent", # mock response for the generate_user_intent action
+            "Mock generated next step", # mock response for the generate_next_step action
+            " Hi there! How are you doing?", # mock response for the generate_bot_message action
+        ],
+    )
+
+    chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks")
+
+    patronus_lynx_llm = FakeLLM(
+        responses=[
+            '{"REASONING": ["Mock reasoning."]}',
+        ]
+    )
+    chat.app.register_action_param("patronus_lynx_llm", patronus_lynx_llm)
+
+    chat >> "Hi"
+    chat << "I don't know the answer to that."
+
+
+@pytest.mark.asyncio
+def test_patronus_lynx_returns_no_hallucination_when_no_reasoning_in_llm_output():
+    """
+    Test that Patronus Lynx's hallucination check does not
+    depend on the reasoning provided in its response.
+    """
+    config = RailsConfig.from_content(
+        colang_content=COLANG_CONFIG, yaml_content=YAML_CONFIG
+    )
+    chat = TestChat(
+        config,
+        llm_completions=[
+            "Mock generated user intent", # mock response for the generate_user_intent action
+            "Mock generated next step", # mock response for the generate_next_step action
+            " Hi there! How are you doing?", # mock response for the generate_bot_message action
+        ],
+    )
+
+    chat.app.register_action(retrieve_relevant_chunks, "retrieve_relevant_chunks")
+
+    patronus_lynx_llm = FakeLLM(
+        responses=[
+            '{"SCORE": "PASS"}',
+        ]
+    )
+    chat.app.register_action_param("patronus_lynx_llm", patronus_lynx_llm)
+
+    chat >> "Hi"
+    chat << "Hi there! How are you doing?"
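+
+
+def test_parse_patronus_lynx_response_score_handling():
+    """
+    A minimal direct check of the response parser (a sketch that exercises
+    `parse_patronus_lynx_response` without going through the chat flow):
+    a "PASS" score maps to no hallucination, a "FAIL" score maps to a
+    hallucination, and a missing score defaults to a hallucination.
+    """
+    from nemoguardrails.library.patronusai.actions import (
+        parse_patronus_lynx_response,
+    )
+
+    hallucination, _ = parse_patronus_lynx_response(
+        '{"REASONING": ["The answer is faithful."], "SCORE": "PASS"}'
+    )
+    assert hallucination is False
+
+    hallucination, _ = parse_patronus_lynx_response(
+        '{"REASONING": ["The answer is not faithful."], "SCORE": "FAIL"}'
+    )
+    assert hallucination is True
+
+    hallucination, _ = parse_patronus_lynx_response(
+        '{"REASONING": ["Mock reasoning."]}'
+    )
+    assert hallucination is True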