From d1a1099371af4ed1aa62813e10a517de4782871a Mon Sep 17 00:00:00 2001 From: prezakhani <13303554+Pouyanpi@users.noreply.github.com> Date: Tue, 27 Aug 2024 14:57:25 +0200 Subject: [PATCH 1/2] fix(config): handle TaskPrompt object in check_output_parser_exists remove extra msg --- nemoguardrails/rails/llm/config.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index d3cf3f767..c6294cedc 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -901,11 +901,20 @@ def check_output_parser_exists(cls, values): ] prompts = values.get("prompts", []) for prompt in prompts: - task = prompt.get("task") - if any( - task.startswith(task_prefix) - for task_prefix in tasks_requiring_output_parser - ) and not prompt.get("output_parser"): + task = prompt.task if hasattr(prompt, "task") else prompt.get("task") + output_parser = ( + prompt.output_parser + if hasattr(prompt, "output_parser") + else prompt.get("output_parser") + ) + + if ( + any( + task.startswith(task_prefix) + for task_prefix in tasks_requiring_output_parser + ) + and not output_parser + ): log.info( f"Deprecation Warning: Output parser is not registered for the task. " f"The correct way is to register the 'output_parser' in the prompts.yml for '{task}' task. " From 6638c95912d51c4e4069e03d7ec239ec58705793 Mon Sep 17 00:00:00 2001 From: prezakhani <13303554+Pouyanpi@users.noreply.github.com> Date: Tue, 27 Aug 2024 14:57:58 +0200 Subject: [PATCH 2/2] test(config): update test to use TaskPrompt instances test(config): update test --- tests/test_rails_config.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/tests/test_rails_config.py b/tests/test_rails_config.py index fdaa3e49d..3bfb591d2 100644 --- a/tests/test_rails_config.py +++ b/tests/test_rails_config.py @@ -18,17 +18,32 @@ import pytest from nemoguardrails import RailsConfig +from nemoguardrails.llm.prompts import TaskPrompt -def test_check_output_parser_exists(caplog): - caplog.set_level(logging.INFO) - values = { - "prompts": [ +@pytest.fixture( + params=[ + [ + TaskPrompt(task="self_check_input", output_parser=None, content="..."), + TaskPrompt(task="self_check_facts", output_parser="parser1", content="..."), + TaskPrompt( + task="self_check_output", output_parser="parser2", content="..." + ), + ], + [ {"task": "self_check_input", "output_parser": None}, {"task": "self_check_facts", "output_parser": "parser1"}, {"task": "self_check_output", "output_parser": "parser2"}, - ] - } + ], + ] +) +def prompts(request): + return request.param + + +def test_check_output_parser_exists(caplog, prompts): + caplog.set_level(logging.INFO) + values = {"prompts": prompts} result = RailsConfig.check_output_parser_exists(values)