Skip to content

Commit

Permalink
Merge pull request #704
Browse files Browse the repository at this point in the history
Fix/output parser validation
  • Loading branch information
drazvan authored Aug 29, 2024
2 parents 52800e0 + 6638c95 commit 9c741bd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
19 changes: 14 additions & 5 deletions nemoguardrails/rails/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,11 +901,20 @@ def check_output_parser_exists(cls, values):
]
prompts = values.get("prompts", [])
for prompt in prompts:
task = prompt.get("task")
if any(
task.startswith(task_prefix)
for task_prefix in tasks_requiring_output_parser
) and not prompt.get("output_parser"):
task = prompt.task if hasattr(prompt, "task") else prompt.get("task")
output_parser = (
prompt.output_parser
if hasattr(prompt, "output_parser")
else prompt.get("output_parser")
)

if (
any(
task.startswith(task_prefix)
for task_prefix in tasks_requiring_output_parser
)
and not output_parser
):
log.info(
f"Deprecation Warning: Output parser is not registered for the task. "
f"The correct way is to register the 'output_parser' in the prompts.yml for '{task}' task. "
Expand Down
27 changes: 21 additions & 6 deletions tests/test_rails_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,32 @@
import pytest

from nemoguardrails import RailsConfig
from nemoguardrails.llm.prompts import TaskPrompt


def test_check_output_parser_exists(caplog):
caplog.set_level(logging.INFO)
values = {
"prompts": [
@pytest.fixture(
params=[
[
TaskPrompt(task="self_check_input", output_parser=None, content="..."),
TaskPrompt(task="self_check_facts", output_parser="parser1", content="..."),
TaskPrompt(
task="self_check_output", output_parser="parser2", content="..."
),
],
[
{"task": "self_check_input", "output_parser": None},
{"task": "self_check_facts", "output_parser": "parser1"},
{"task": "self_check_output", "output_parser": "parser2"},
]
}
],
]
)
def prompts(request):
return request.param


def test_check_output_parser_exists(caplog, prompts):
caplog.set_level(logging.INFO)
values = {"prompts": prompts}

result = RailsConfig.check_output_parser_exists(values)

Expand Down

0 comments on commit 9c741bd

Please sign in to comment.