Skip to content

Commit

Permalink
Merge pull request #357 from Eyobyb/regration-validation-logic
Browse files Browse the repository at this point in the history
Replace the validation logic with regression-style validation logic.
  • Loading branch information
20001LastOrder authored Jun 3, 2024
2 parents 699c2a9 + 8c17cf4 commit c17a7ae
Show file tree
Hide file tree
Showing 18 changed files with 323 additions and 105 deletions.
131 changes: 109 additions & 22 deletions src/sherpa_ai/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sherpa_ai.policies.base import BasePolicy
from sherpa_ai.verbose_loggers.base import BaseVerboseLogger
from sherpa_ai.verbose_loggers.verbose_loggers import DummyVerboseLogger

from langchain.base_language import BaseLanguageModel

# Avoid circular import
if TYPE_CHECKING:
Expand All @@ -32,7 +32,10 @@ def __init__(
validation_steps: int = 1,
validations: List[BaseOutputProcessor] = [],
feedback_agent_name: str = "critic",
global_regen_max: int = 12,
llm: BaseLanguageModel = None,
):
self.llm = llm
self.name = name
self.description = description
self.shared_memory = shared_memory
Expand All @@ -45,6 +48,7 @@ def __init__(
self.verbose_logger = verbose_logger
self.actions = actions
self.validation_steps = validation_steps
self.global_regen_max = global_regen_max
self.validations = validations
self.feedback_agent_name = feedback_agent_name

Expand Down Expand Up @@ -91,13 +95,67 @@ def run(self):
EventType.action_output, self.name, action_output
)

result = self.validate_output()
result = (
self.validate_output()
if len(self.validations) > 0
else self.synthesize_output()
)

logger.debug(f"```🤖{self.name} wrote: {result}```")

self.shared_memory.add(EventType.result, self.name, result)
return result

def validation_iterator(
    self,
    validations,
    global_regen_count,
    all_pass,
    validation_is_scaped,
    result,
):
    """Run one pass over ``validations`` against the current ``result``.

    Each validation is applied in order. On the first failure (for a
    validation that still has budget under ``self.validation_steps``) the
    output is regenerated via ``self.synthesize_output()`` and the pass
    stops early so the caller can retry from the top.

    Args:
        validations: Ordered list of output validations to apply.
        global_regen_count (int): Running total of regenerations so far.
        all_pass (bool): Set to True when every validation passed or was
            skipped after exhausting its step budget.
        validation_is_scaped (bool): Set to True when at least one
            validation was skipped ("escaped") because it ran out of steps.
        result (str): Current candidate output text.

    Returns:
        tuple: ``(global_regen_count, all_pass, validation_is_scaped, result)``.
    """

    for i in range(len(validations)):
        validation = validations[i]
        logger.info(f"validation_running: {validation.__class__.__name__}")
        logger.info(f"validation_count: {validation.count}")
        # Only run this validation while it has step budget left
        # (validation.count tracks its failures so far — see BaseOutputProcessor).
        if validation.count < self.validation_steps:
            self.belief.update_internal(EventType.result, self.name, result)
            validation_result = validation.process_output(
                text=result, belief=self.belief, llm=self.llm
            )
            logger.info(f"validation_result: {validation_result}")
            if not validation_result.is_valid:
                # Feed the failure back into the belief, regenerate the
                # output, and abort this pass so validation restarts.
                self.belief.update_internal(
                    EventType.feedback,
                    self.feedback_agent_name,
                    validation_result.feedback,
                )
                result = self.synthesize_output()
                global_regen_count += 1
                break

            # The last validation in the list passed -> all of them passed.
            elif i == len(validations) - 1:
                result = validation_result.result
                all_pass = True
            else:
                result = validation_result.result
        # This validation exhausted its step budget; if it is the last one,
        # end the pass as passing but record that a validation was skipped.
        elif i == len(validations) - 1:
            validation_is_scaped = True
            all_pass = True

        else:
            # Budget exhausted on a non-final validation: mark it skipped
            # ("escaped") and continue with the next validation.
            validation_is_scaped = True
    return global_regen_count, all_pass, validation_is_scaped, result

def validate_output(self):
"""
Validate the synthesized output through a series of validation steps.
Expand All @@ -112,35 +170,64 @@ def validate_output(self):
Returns:
str: The synthesized output after validation.
"""
failed_validation = []
result = ""
# create array of instance of validation so that we can keep track of how many times regeneration happened.
all_pass = False
validation_is_scaped = False
iteration_count = 0
result = self.synthesize_output()
global_regen_count = 0

# reset the state of all the validation before starting the validation process.
for validation in self.validations:
for count in range(self.validation_steps):
self.belief.update_internal(EventType.result, self.name, result)
validation.reset_state()

validations = self.validations

# this loop will run until max regeneration reached or all validations have failed
while self.global_regen_max > global_regen_count and not all_pass:
logger.info(f"validations_size: {len(validations)}")
iteration_count += 1
logger.info(f"main_iteration: {iteration_count}")
logger.info(f"regen_count: {global_regen_count}")

global_regen_count, all_pass, validation_is_scaped, result = (
self.validation_iterator(
all_pass=all_pass,
global_regen_count=global_regen_count,
validation_is_scaped=validation_is_scaped,
validations=validations,
result=result,
)
)
# if all didn't pass or validation reached max regeneration run the validation one more time but no regeneration.
if validation_is_scaped or self.global_regen_max >= global_regen_count:
failed_validations = []

for validation in validations:
validation_result = validation.process_output(
text=result, belief=self.belief, iteration_count=count
text=result, belief=self.belief, llm=self.llm
)

if validation_result.is_valid:
result = validation_result.result
break
if not validation_result.is_valid:
failed_validations.append(validation)
else:
self.belief.update_internal(
EventType.feedback,
self.feedback_agent_name,
validation_result.feedback,
)
result = self.synthesize_output()

if count >= self.validation_steps:
failed_validation.append(validation)
result = validation_result.result

if len(failed_validation) > 0:
# if the validation failed after all steps, append the error messages to the result
result += "\n".join(
failed_validation.get_failure_message()
for failed_validation in failed_validation
for failed_validation in failed_validations
)

else:

# check if validation is not passed after all the attempts if so return the error message.
result += "\n".join(
(
inst_val.get_failure_message()
if inst_val.count == self.validation_steps
else ""
)
for inst_val in validations
)

self.belief.update_internal(EventType.result, self.name, result)
Expand Down
24 changes: 14 additions & 10 deletions src/sherpa_ai/agents/qa_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
actions: List[BaseAction] = [],
validation_steps: int = 1,
validations: List[BaseOutputProcessor] = [],
global_regen_max: int = 5,
):
"""
The QA agent handles a single question-answering task.
Expand All @@ -63,16 +64,18 @@ def __init__(
validations (List[BaseOutputProcessor], optional): The list of validations the agent will perform. Defaults to [].
"""
super().__init__(
name,
description + "\n\n" + f"Your name is {name}.",
shared_memory,
belief,
policy,
num_runs,
verbose_logger,
actions,
validation_steps,
validations,
llm=llm,
name=name,
description=description + "\n\n" + f"Your name is {name}.",
shared_memory=shared_memory,
belief=belief,
policy=policy,
num_runs=num_runs,
verbose_logger=verbose_logger,
actions=actions,
validation_steps=validation_steps,
validations=validations,
global_regen_max=global_regen_max,
)

if self.policy is None:
Expand All @@ -89,6 +92,7 @@ def __init__(
belief = Belief()
self.belief = belief
self.citation_enabled = False

for validation in self.validations:
if isinstance(validation, CitationValidation):
self.citation_enabled = True
Expand Down
7 changes: 6 additions & 1 deletion src/sherpa_ai/output_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class BaseOutputProcessor(ABC):
Defines the interface for processing output text.
Attributes:
None
count (int): Abstract global variable representing the count of failed validations.
Methods:
process_output(text: str) -> Tuple[bool, str]:
Expand All @@ -52,6 +52,11 @@ class BaseOutputProcessor(ABC):
"""

count: int = 0

def reset_state(self):
    """Reset the failed-validation counter so this processor can be reused."""
    self.count = 0

@abstractmethod
def process_output(self, text: str, **kwargs) -> ValidationResult:
"""
Expand Down
18 changes: 14 additions & 4 deletions src/sherpa_ai/output_parsers/entity_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sherpa_ai.memory import Belief
from sherpa_ai.output_parsers.base import BaseOutputProcessor
from sherpa_ai.output_parsers.validation_result import ValidationResult
from langchain.base_language import BaseLanguageModel
from sherpa_ai.utils import (
extract_entities,
text_similarity,
Expand Down Expand Up @@ -36,7 +37,7 @@ class EntityValidation(BaseOutputProcessor):
"""

def process_output(
self, text: str, belief: Belief, iteration_count: int = 1
self, text: str, belief: Belief, llm: BaseLanguageModel = None, **kwargs
) -> ValidationResult:
"""
Verifies that entities within `text` exist in the `belief` source text.
Expand All @@ -58,7 +59,7 @@ def process_output(
exclude_types=[EventType.feedback, EventType.result],
)
entity_exist_in_source, error_message = self.check_entities_match(
text, source, self.similarity_picker(iteration_count)
text, source, self.similarity_picker(self.count), llm
)
if entity_exist_in_source:
return ValidationResult(
Expand All @@ -67,6 +68,7 @@ def process_output(
feedback="",
)
else:
self.count += 1
return ValidationResult(
is_valid=False,
result=text,
Expand All @@ -93,7 +95,11 @@ def get_failure_message(self) -> str:
return "Some enitities from the source might not be mentioned."

def check_entities_match(
self, result: str, source: str, stage: TextSimilarityMethod
self,
result: str,
source: str,
stage: TextSimilarityMethod,
llm: BaseLanguageModel,
):
"""
Check if entities extracted from a question are present in an answer.
Expand All @@ -118,9 +124,13 @@ def check_entities_match(
return text_similarity_by_metrics(
check_entity=check_entity, source_entity=source_entity
)
else:
elif stage > 1 and llm is not None:
return text_similarity_by_llm(
llm=llm,
source_entity=source_entity,
result=result,
source=source,
)
return text_similarity_by_metrics(
check_entity=check_entity, source_entity=source_entity
)
1 change: 1 addition & 0 deletions src/sherpa_ai/output_parsers/number_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def process_output(self, text: str, belief: Belief, **kwargs) -> ValidationResul
feedback="",
)
else:
self.count += 1
return ValidationResult(
is_valid=False,
result=text,
Expand Down
9 changes: 2 additions & 7 deletions src/sherpa_ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from nltk.metrics import edit_distance, jaccard_distance
from pypdf import PdfReader
from word2number import w2n
from langchain.base_language import BaseLanguageModel

import sherpa_ai.config as cfg
from sherpa_ai.database.user_usage_tracker import UserUsageTracker
Expand Down Expand Up @@ -507,6 +508,7 @@ def json_from_text(text: str):


def text_similarity_by_llm(
llm: BaseLanguageModel,
source_entity: List[str],
source,
result,
Expand All @@ -527,13 +529,6 @@ def text_similarity_by_llm(
dict: Result of the check containing 'entity_exist' and 'messages'.
"""

llm = SherpaOpenAI(
temperature=cfg.TEMPERATURE,
openai_api_key=cfg.OPENAI_API_KEY,
user_id=user_id,
team_id=team_id,
)

instruction = f"""
I have a question and an answer. I want you to confirm whether the entities from the question are all mentioned in some form within the answer.
Expand Down
Loading

0 comments on commit c17a7ae

Please sign in to comment.