Skip to content

Commit

Permalink
Merge pull request #317 from oshoma/refactor-citations
Browse files Browse the repository at this point in the history
Refactor citations to simplify code and improve docs (closes #303)
  • Loading branch information
20001LastOrder authored Mar 28, 2024
2 parents c49c8cf + 42546fb commit 6a20191
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 128 deletions.
246 changes: 123 additions & 123 deletions src/sherpa_ai/output_parsers/citation_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,11 @@ class CitationValidation(BaseOutputProcessor):
reference texts and links provided in the 'resources' list.
Attributes:
- seq_thresh (float): Threshold for common longest subsequence / text. Default is 0.7.
- jaccard_thresh (float): Jaccard similarity threshold. Default is 0.7.
- token_overlap (float): Token overlap threshold. Default is 0.7.
Methods:
- calculate_token_overlap(sentence1, sentence2): Calculate token overlap between two sentences.
- jaccard_index(sentence1, sentence2): Calculate Jaccard similarity index between two sentences.
- longestCommonSubsequence(text1, text2): Calculate the length of the longest common subsequence between two texts.
- unfoldList(nestedList): Flatten a nested list of strings.
- split_paragraph_into_sentences(paragraph): Tokenize a paragraph into sentences.
- parse_output(generated, resources): Add citation to each sentence in the generated text from resources based on fact-checking model.
Example Usage:
sequence_threshold (float): Threshold for common longest subsequence / text. Default is 0.7.
jaccard_threshold (float): Jaccard similarity threshold. Default is 0.7.
token_overlap (float): Token overlap threshold. Default is 0.7.
Typical usage example:
```python
citation_parser = CitationValidation(seq_thresh=0.7, jaccard_thresh=0.7, token_overlap=0.7)
result = citation_parser.parse_output(generated_text, list_of_resources)
Expand All @@ -41,36 +33,25 @@ class CitationValidation(BaseOutputProcessor):
def __init__(
self, sequence_threshold=0.7, jaccard_threshold=0.7, token_overlap=0.7
):
"""
Initialize the CitationValidation object.
Args:
- seq_thresh (float): Threshold for common longest subsequence / text. Default is 0.7.
- jaccard_thresh (float): Jaccard similarity threshold. Default is 0.7.
- token_overlap (float): Token overlap threshold. Default is 0.7.
"""
# threshold
self.sequence_threshold = (
sequence_threshold # threshold for common longest subsequece / text
)
self.sequence_threshold = sequence_threshold
self.jaccard_threshold = jaccard_threshold
self.token_overlap = token_overlap

def calculate_token_overlap(self, sentence1, sentence2) -> tuple:
"""
Calculate the percentage of token overlap between two sentences.
Calculates the percentage of token overlap between two sentences.
Tokenizes the input sentences and calculates the percentage of token overlap
by finding the intersection of the token sets and dividing it by the length
of each sentence's token set.
Args:
- sentence1 (str): The first sentence for token overlap calculation.
- sentence2 (str): The second sentence for token overlap calculation.
sentence1 (str): The first sentence for token overlap calculation.
sentence2 (str): The second sentence for token overlap calculation.
Returns:
- tuple: A tuple containing two float values representing the percentage
of token overlap for sentence1 and sentence2, respectively.
tuple: A tuple containing two float values representing the percentage
of token overlap for sentence1 and sentence2, respectively.
"""
# Tokenize the sentences
tokens1 = word_tokenize(sentence1)
Expand All @@ -91,17 +72,17 @@ def calculate_token_overlap(self, sentence1, sentence2) -> tuple:

def jaccard_index(sself, sentence1, sentence2) -> float:
"""
Calculate the Jaccard index between two sentences.
Calculates the Jaccard index between two sentences.
The Jaccard index is a measure of similarity between two sets, defined as the
size of the intersection divided by the size of the union of the sets.
Args:
- sentence1 (str): The first sentence for Jaccard index calculation.
- sentence2 (str): The second sentence for Jaccard index calculation.
sentence1 (str): The first sentence for Jaccard index calculation.
sentence2 (str): The second sentence for Jaccard index calculation.
Returns:
- float: The Jaccard index representing the similarity between the two sentences.
float: The Jaccard index representing the similarity between the two sentences.
"""
# Convert the sentences to sets of words
set1 = set(word_tokenize(sentence1))
Expand All @@ -115,9 +96,9 @@ def jaccard_index(sself, sentence1, sentence2) -> float:

return jaccard_index

def longestCommonSubsequence(self, text1: str, text2: str) -> int:
def longest_common_subsequence(self, text1: str, text2: str) -> int:
"""
Calculate the length of the longest common subsequence between two texts.
Calculates the length of the longest common subsequence between two texts.
A subsequence of a string is a new string generated from the original
string with some characters (can be none) deleted without changing
Expand All @@ -140,140 +121,159 @@ def longestCommonSubsequence(self, text1: str, text2: str) -> int:
dp[i][j] = max(diagnoal, dp[i - 1][j], dp[i][j - 1])
return dp[-1][-1]

def unfoldList(self, nestedList: list[list[str]]) -> list[str]:
def flatten_nested_list(self, nested_list: list[list[str]]) -> list[str]:
"""
Flatten a nested list of strings into a single list of strings.
Flattens a nested list of strings into a single list of strings.
Args:
- nestedList (list[list[str]]): The nested list of strings to be flattened.
nested_list (list[list[str]]): The nested list of strings to be flattened.
Returns:
- list[str]: A flat list containing all non-empty strings from the nested list.
list[str]: A flat list containing all non-empty strings from the nested list.
"""
sentences = []
for sublist in nestedList:
for sublist in nested_list:
for item in sublist:
if len(item) > 0:
sentences.append(item)
return sentences

def split_paragraph_into_sentences(self, paragraph: str) -> list[str]:
"""
Tokenize a paragraph into a list of sentences.
Uses NLTK's sent_tokenize to split a given paragraph into a list of sentences.
Uses NLTK's sent_tokenize to split the given paragraph into a list of sentences.
Args:
- paragraph (str): The input paragraph to be tokenized into sentences.
paragraph (str): The input paragraph to be tokenized into sentences.
Returns:
- list[str]: A list of sentences extracted from the input paragraph.
list[str]: A list of sentences extracted from the input paragraph.
"""
sentences = sent_tokenize(paragraph)
return sentences

def find_used_resources(self, belief: Belief) -> list[dict]:
def resources_from_belief(self, belief: Belief) -> list[dict]:
"""
Returns a list of all resources within belief.actions.
"""
resources = []
for action in belief.actions:
if hasattr(action, "meta") and action.meta is not None:
resources.extend(action.meta[-1])
return resources

# add citation to the generated text
def process_output(self, generated: str, belief: Belief) -> ValidationResult:
def process_output(self, text: str, belief: Belief) -> ValidationResult:
"""
Add citation to each sentence in the generated text from resources based on fact checking model.
Args:
generated (str): The generated content where we need to add citation/reference
agent (BaseAgent): Belief of the agents generated the content
Returns:
- ValidationResult: An object containing the result of citation addition and feedback.
The ValidationResult has attributes 'is_valid' indicating success, 'result' containing
the formatted text with citations, and 'feedback' providing additional information.
Note:
- The 'resources' list should contain dictionaries with "Document" and "Source" keys.
Example:
```python
resources = [{"Document": "Some reference text.", "Source": "http://example.com/source1"}]
citation_parser = CitationValidation()
result = citation_parser.parse_output("Generated text.", resources)
```
Add citations to sentences in the generated text using resources based on fact checking model.
Args:
text (str): The text which needs citations/references added
belief (Belief): Belief of the agent that generated `text`
Returns:
ValidationResult: The result of citation processing.
`is_valid` is True when citation processing succeeds or no citation resources are provided,
False otherwise.
`result` contains the formatted text with citations.
`feedback` providing additional optional information.
Note:
The 'resources' list should contain dictionaries with "Document" and "Source" keys.
Typical usage example:
```python
resources = [{"Document": "Some reference text.", "Source": "http://example.com/source1"}]
citation_parser = CitationValidation()
result = citation_parser.parse_output("Text needing citations.", resources)
```
"""
# resources type
# resources = [{"Document":, "Source":...}, {}]
resources = self.find_used_resources(belief)
resources = self.resources_from_belief(belief)

if len(resources) == 0:
# no resources used, return the original text
return ValidationResult(
is_valid=True,
result=generated,
result=text,
feedback="",
)

return self.add_citations(generated, resources)
return self.add_citations(text, resources)

def add_citation_to_sentence(self, sentence: str, resources: list[dict]):
"""
Uses a list of resources to add citations to a sentence
Returns:
citation_ids: a list of citation identifiers
citation_links: a list of citation links (URLs)
"""
citation_ids = []
citation_links = []

if len(sentence) <= 5:
return citation_ids, citation_links

for index, resource in enumerate(resources):
cited = False
resource_link = resource["Source"]
resource_text = resource["Document"]
resource_sentences = resource_text.split(".")
# TODO: verify that splitting each sentence on newlines improves citation results
nested_sentence_lines = [s.split("\n") for s in resource_sentences]
resource_lines = self.flatten_nested_list(nested_sentence_lines)

for resource_line in resource_lines:
if not cited and not (resource_link in citation_links):
seq = self.longest_common_subsequence(sentence, resource_line)
if (
(seq / len(sentence)) > self.sequence_threshold
or sentence in resource_line
or self.jaccard_index(sentence, resource_line)
> self.jaccard_threshold
):
citation_links.append(resource_link)
citation_ids.append(index + 1)
cited = True

return citation_ids, citation_links

def format_sentence_with_citations(self, sentence, ids, links):
"""
Appends citations to sentence
"""
if len(ids) == 0:
return sentence

citations = []
for id, url in zip(ids, links):
reference = f"[{id}]({url})"
citations.append(reference)

def add_citations(self, generated: str, resources: list[dict]) -> ValidationResult:
paragraph = generated.split("\n")
new_sentence = sentence[:-1] + " " + ", ".join(citations) + "."
return new_sentence

def add_citations(self, text: str, resources: list[dict]) -> ValidationResult:
paragraph = text.split("\n")
paragraph = [p for p in paragraph if len(p.strip()) > 0]

paragraphs = [
self.split_paragraph_into_sentences(s) for s in paragraph
] # nested list
paragraphs = [self.split_paragraph_into_sentences(s) for s in paragraph]

new_paragraph = []
for one_paragraph in paragraphs:
new_sentence = []
for _, sentence in enumerate(one_paragraph):
links = []
ids = []
for paragraph in paragraphs:
new_sentences = []

# for each sentence in each paragraph
for _, sentence in enumerate(paragraph):
sentence = sentence.strip()
if len(sentence) == 0:
continue

for index, source in enumerate(resources):
cited = False # if this resource is cited
text = source["Document"]
one_sentences = text.split(".")
sub_string = [s.split("\n") for s in one_sentences]
split_texts = self.unfoldList(sub_string)

link = source["Source"]

for j in split_texts:
if len(sentence) > 5 and not cited and not (link in links):
seq = self.longestCommonSubsequence(sentence, j)

contained = False
if sentence in j:
# print("contained", s, j)
contained = True
jaccard = self.jaccard_index(sentence, j)
# print(jaccard)

if (
(seq / len(sentence)) > self.sequence_threshold
or contained
or jaccard > self.jaccard_threshold
):
links.append(link)
ids.append(index + 1)
citations = []
for id, url in zip(ids, links):
reference = f"[{id}]({url})"
citations.append(reference)

if len(citations) > 0:
new_sentence.append(
sentence[:-1] + " " + ", ".join(citations) + "."
)
else:
new_sentence.append(sentence)

new_paragraph.append(" ".join(new_sentence) + "\n")
ids, links = self.add_citation_to_sentence(sentence, resources)
formatted_sentence = self.format_sentence_with_citations(
sentence, ids, links
)
new_sentences.append(formatted_sentence)

new_paragraph.append(" ".join(new_sentences) + "\n")

return ValidationResult(
is_valid=True,
Expand Down
24 changes: 19 additions & 5 deletions src/tests/unit_tests/output_parsers/test_citation_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import mock

from loguru import logger
from pytest import mark

from sherpa_ai.agents import QAAgent
from sherpa_ai.events import EventType
Expand All @@ -16,12 +17,25 @@ def test_citation_validation():
Senate in 1972. As a senator, Biden drafted and led the effort to pass the Violent Crime Control and Law Enforcement Act and the Violence Against Women Act. He also oversaw six U.S. Supreme Court confirmation hearings, including the contentious hearings for Robert Bork and Clarence Thomas.
Biden ran unsuccessfully for the Democratic presidential nomination in 1988 and 2008. In 2008, Obama chose Biden as his running mate, and he was a close counselor to Obama during his two terms as vice president. In the 2020 presidential election, Biden and his running mate, Kamala Harris, defeated incumbents Donald Trump and Mike Pence. He became the oldest president in U.S. history, and the first to have a female vice president.
"""
data = {"Document": text, "Source": "www.wiki_1.com"}
data_1 = {"Document": text, "Source": "www.wiki_1.com"}
data_2 = {"Document": text, "Source": "www.wiki_2.com"}
resource = [data, data_2]
resources = [data_1, data_2]
module = CitationValidation()
result = module.add_citations(text, resource)
assert data["Source"] in result.result
result = module.add_citations(text, resources)
assert result.is_valid
assert result.feedback == ""
assert data_1["Source"] in result.result
assert data_1["Source"] in result.result


@mark.skip("Placeholder for test we should implement")
def test_citation_succeeds_for_longest_common_subsequence():
pass


@mark.skip("Placeholder for test we should implement")
def test_citation_succeeds_for_jaccard_similarity():
pass


def test_task_agent_succeeds(get_llm): # noqa: F811
Expand All @@ -46,11 +60,11 @@ def test_task_agent_succeeds(get_llm): # noqa: F811
"What is AutoGPT?",
)

# get the last response from the LLM as the search mock to simulate the scenario while the LLM uses resources
GOOGLE_SEARCH_MOCK = {
"organic": [
{
"title": "AutoGPT ",
# use the last response from the mock LLM as the search mock to simulate the scenario
"snippet": llm.responses[-1],
"link": "https://www.google.com",
}
Expand Down

0 comments on commit 6a20191

Please sign in to comment.