Skip to content

Commit

Permalink
Merge pull request #385 from 20001LastOrder/reranking
Browse files Browse the repository at this point in the history
Reranking
  • Loading branch information
amirfz authored Jun 12, 2024
2 parents 0644e4c + 4a99a47 commit ba9e631
Show file tree
Hide file tree
Showing 18 changed files with 347 additions and 190 deletions.
15 changes: 8 additions & 7 deletions src/sherpa_ai/actions/arxiv_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,30 @@ class ArxivSearch(BaseRetrievalAction):
task: str
llm: Any # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
description: str = SEARCH_SUMMARY_DESCRIPTION
max_results: int = 5
_search_tool: Any

# Override the name and args from BaseAction
name: str = "ArxivSearch"
args: dict = {"query": "string"}
usage: str = "Search paper on the Arxiv website"
perform_refinement: bool = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._search_tool = SearchArxivTool()

def execute(self, query) -> str:
result, resources = self._search_tool._run(query, return_resources=True)
def search(self, query) -> list[dict]:
resources = self._search_tool._run(query, return_resources=True)
self.add_resources(resources)

return resources

def refine(self, result: str) -> str:
prompt = self.description.format(
task=self.task,
paper_title_summary=result,
n=self.max_results,
n=self.num_documents,
role_description=self.role_description,
)

result = self.llm.predict(prompt)

return result
return self.llm.predict(prompt)
48 changes: 46 additions & 2 deletions src/sherpa_ai/actions/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property

from loguru import logger
from pydantic import BaseModel, Field

from sherpa_ai.actions.utils.reranking import BaseReranking


class ActionResource(BaseModel):
"""
Expand Down Expand Up @@ -36,6 +39,12 @@ def __str__(self):

class BaseRetrievalAction(BaseAction, ABC):
resources: list[ActionResource] = Field(default_factory=list)
num_documents: int = 5 # Number of documents to retrieve
reranker: BaseReranking = None
current_task: str = ""

perform_reranking: bool = False
perform_refinement: bool = False

def add_resources(self, resources: list[dict]):
action_resources = self.resources
Expand All @@ -45,3 +54,38 @@ def add_resources(self, resources: list[dict]):
action_resources.append(
ActionResource(source=resource["Source"], content=resource["Document"])
)

def execute(self, query: str) -> str:
    """
    Run the retrieval pipeline for a query and return the result text.

    The hits returned by ``search`` are reduced to their document text,
    optionally reordered with ``reranking`` (when ``perform_reranking`` is
    set), joined into one string separated by blank lines, and optionally
    passed through ``refine`` (when ``perform_refinement`` is set).

    Args:
        query: The search query forwarded to ``search``.

    Returns:
        The joined document text, refined if refinement is enabled.
    """
    # Each search hit carries its text under the "Document" key.
    documents = [hit["Document"] for hit in self.search(query)]

    if self.perform_reranking:
        documents = self.reranking(documents)

    combined = "\n\n".join(documents)
    logger.debug("Action Results: {}", combined)

    if not self.perform_refinement:
        return combined

    return self.refine(combined)

@abstractmethod
def search(self, query: str) -> list[dict]:
    """
    Search for relevant documents based on the query.

    Args:
        query: The search query.

    Returns:
        A list of result dictionaries. ``execute`` reads the text of each
        result from its ``"Document"`` key, so every entry must provide it.
    """
    pass

def reranking(self, documents: list[str]) -> list[str]:
    """
    Rerank the documents against the current task.

    Args:
        documents: The document texts to reorder.

    Returns:
        The documents as ordered by ``self.reranker``.

    Raises:
        ValueError: If no reranker has been configured. ``reranker`` defaults
            to ``None``, so enabling reranking without setting one would
            otherwise fail with an opaque ``AttributeError``.
    """
    if self.reranker is None:
        raise ValueError(
            "Reranking was requested but no reranker is configured; "
            "set `reranker` before enabling `perform_reranking`."
        )

    # The current task, not the raw search query, is the ranking reference.
    return self.reranker.rerank(documents, self.current_task)

def refine(self, results: str) -> str:
    """
    Post-process the joined search results.

    The base implementation is a pass-through hook: it hands the text back
    untouched. Subclasses override it to transform the results.

    Args:
        results: The joined document text produced by ``execute``.

    Returns:
        The same text, unmodified.
    """
    refined = results
    return refined
18 changes: 9 additions & 9 deletions src/sherpa_ai/actions/context_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,33 @@ class ContextSearch(BaseRetrievalAction):
task: str
llm: Any # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
description: str = SEARCH_SUMMARY_DESCRIPTION
n: int = 5
_context: Any

# Override the name and args from BaseAction
name: str = "Context Search"
args: dict = {"query": "string"}
usage: str = "Search the conversation history with the user"
perform_refinement: bool = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._context = ContextTool(memory=get_vectordb())

def execute(self, query) -> str:
result, resources = self._context._run(query, return_resources=True)
def search(self, query) -> str:
resources = self._context._run(query, return_resources=True)

self.add_resources(resources)

# result = "Context Search"
logger.debug("Context Search Result: {}", result)
# logger.debug("Context Search Result: {}", result)

return resources

def refine(self, result: str) -> str:
prompt = self.description.format(
task=self.task,
documents=result,
n=self.n,
n=self.num_documents,
role_description=self.role_description,
)

result = self.llm.predict(prompt)

return result
return self.llm.predict(prompt)
11 changes: 4 additions & 7 deletions src/sherpa_ai/actions/google_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class GoogleSearch(BaseRetrievalAction):
task: str
llm: Any # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
description: str = SEARCH_SUMMARY_DESCRIPTION
n: int = 5
config: AgentConfig = AgentConfig()
_search_tool: Any

Expand All @@ -47,12 +46,10 @@ class GoogleSearch(BaseRetrievalAction):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._search_tool = SearchTool(config=self.config)
self._search_tool = SearchTool(config=self.config, top_k=self.num_documents)

def execute(self, query) -> str:
result, resources = self._search_tool._run(query, return_resources=True)
def search(self, query) -> list[dict]:
resources = self._search_tool._run(query, return_resources=True)
self.add_resources(resources)

logger.debug("Search Result: {}", result)

return result
return resources
Empty file.
42 changes: 42 additions & 0 deletions src/sherpa_ai/actions/utils/reranking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Different methods for reranking the results of a search query.
"""

from abc import ABC, abstractmethod
from typing import Any, Callable

import numpy as np
from numpy.typing import ArrayLike
from pydantic import BaseModel


def cosine_similarity(v1: ArrayLike, v2: ArrayLike) -> float:
    """
    Compute the cosine similarity between two vectors.

    Args:
        v1: First vector (any array-like of numbers).
        v2: Second vector, with the same length as ``v1``.

    Returns:
        The cosine of the angle between the two vectors, or ``0.0`` when
        either vector has zero norm. The similarity is undefined in that
        case; the plain formula would divide by zero and produce ``nan``,
        which cannot be ordered reliably when ranking documents.
    """
    denominator = np.linalg.norm(v1) * np.linalg.norm(v2)
    if denominator == 0:
        return 0.0

    return np.dot(v1, v2) / denominator


class BaseReranking(ABC, BaseModel):
    """
    Abstract interface for strategies that reorder retrieved documents.
    """

    @abstractmethod
    def rerank(self, documents: list[str], **kwargs) -> list[str]:
        """
        Reorder the given documents.

        Args:
            documents: The document texts to rerank.
            **kwargs: Strategy-specific inputs; concrete implementations
                declare the arguments they need (for example a query).

        Returns:
            The documents in their new order.
        """
        pass


class RerankingByQuery(BaseReranking):
    """
    Rerank documents by the similarity of their embeddings to a query.
    """

    embeddings: Any  # takes an Embedding Object from LangChain, use Any since it is not compatible with Pydantic 2 yet
    # Scores a (query embedding, document embedding) pair; higher ranks first.
    distance_metric: Callable[[ArrayLike, ArrayLike], float] = cosine_similarity

    def rerank(self, documents: list[str], query: str) -> list[str]:
        """
        Order documents from most to least similar to the query.

        Args:
            documents: The document texts to rerank.
            query: The text every document is compared against.

        Returns:
            A new list with the documents sorted by descending similarity.
            Documents with equal scores keep their original relative order.
        """
        # Nothing to rank, so skip the embedding calls entirely.
        if not documents:
            return []

        query_embedding = self.embeddings.embed_query(query)
        document_embeddings = self.embeddings.embed_documents(documents)

        # Calculate the similarity between the query and each document
        similarities = [
            self.distance_metric(query_embedding, doc_embedding)
            for doc_embedding in document_embeddings
        ]

        # Sort on the score alone: comparing whole (score, document) tuples
        # would break ties by document text instead of retrieval order.
        ranked = sorted(
            zip(similarities, documents),
            key=lambda scored: scored[0],
            reverse=True,
        )

        return [doc for _, doc in ranked]
2 changes: 1 addition & 1 deletion src/sherpa_ai/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain.base_language import BaseLanguageModel
from loguru import logger

from sherpa_ai.actions.base import BaseAction
from sherpa_ai.actions.base import BaseAction, BaseRetrievalAction
from sherpa_ai.events import EventType
from sherpa_ai.output_parsers.base import BaseOutputProcessor
from sherpa_ai.policies.base import BasePolicy
Expand Down
8 changes: 7 additions & 1 deletion src/sherpa_ai/memory/belief.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, List, Optional

from sherpa_ai.actions.base import BaseAction
from sherpa_ai.actions.base import BaseAction, BaseRetrievalAction
from sherpa_ai.events import Event, EventType


Expand Down Expand Up @@ -134,6 +134,12 @@ def token_counter(x):
def set_actions(self, actions: List[BaseAction]):
self.actions = actions

# TODO: This is a quick and dirty way to set the current task
# in actions; need to find a better way
for action in actions:
if isinstance(action, BaseRetrievalAction):
action.current_task = self.current_task.content

@property
def action_description(self):
return "\n".join([str(action) for action in self.actions])
Expand Down
2 changes: 1 addition & 1 deletion src/sherpa_ai/output_parsers/entity_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def process_output(
"""

source = belief.get_histories_excluding_types(
exclude_types=[EventType.feedback, EventType.result],
exclude_types=[EventType.feedback, EventType.result, EventType.action],
)
entity_exist_in_source, error_message = self.check_entities_match(
text, source, self.similarity_picker(self.count), llm
Expand Down
40 changes: 18 additions & 22 deletions src/sherpa_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _run(
logger.debug(f"Arxiv Search Result: {result_list}")

if return_resources:
return result, resources
return resources
else:
return result

Expand All @@ -100,9 +100,7 @@ class SearchTool(BaseTool):
"you cannot find the information using internal search."
)

def _run(
self, query: str, return_resources=False
) -> Union[str, Tuple[str, List[dict]]]:
def _run(self, query: str, return_resources=False) -> Union[str, List[dict]]:
result = ""
if self.config.search_domains:
query_list = [
Expand All @@ -111,18 +109,14 @@ def _run(
]
if len(query_list) >= 5:
query_list = query_list[:5]
result = (
result
+ "Warning: Only the first 5 URLs are taken into consideration.\n"
) # noqa: E501
logger.warning("Only the first 5 URLs are taken into consideration.")
else:
query_list = [query]
if self.config.invalid_domains:
invalid_domain_string = ", ".join(self.config.invalid_domains)
result = (
result
+ f"Warning: The doman {invalid_domain_string} is invalid and is not taken into consideration.\n" # noqa: E501
) # noqa: E501
logger.warning(
f"The domain {invalid_domain_string} is invalid and is not taken into consideration." # noqa: E501
)

top_k = int(self.top_k / len(query_list))
if return_resources:
Expand All @@ -132,22 +126,21 @@ def _run(
cur_result = self._run_single_query(query, top_k, return_resources)

if return_resources:
result += "\n" + cur_result[0]
resources.extend(cur_result[1])
resources += cur_result
else:
result += "\n" + cur_result

if return_resources:
result = (result, resources)

return result
return resources
else:
return result

def formulate_site_search(self, query: str, site: str) -> str:
return query + " site:" + site

def _run_single_query(
self, query: str, top_k: int, return_resources=False
) -> Union[str, Tuple[str, List[dict]]]:
) -> Union[str, List[dict]]:
logger.debug(f"Search query: {query}")
google_serper = GoogleSerperAPIWrapper()
search_results = google_serper._google_serper_api_results(query)
Expand All @@ -168,7 +161,7 @@ def _run_single_query(
response = "Answer: " + answer
meta = [{"Document": answer, "Source": link}]
if return_resources:
return response, meta
return meta
else:
return response + "\nLink:" + link

Expand Down Expand Up @@ -202,7 +195,10 @@ def _run_single_query(
snippets.append(f"{attribute}: {value}.")

if len(snippets) == 0:
return ["No good Google Search Result was found"]
if return_resources:
return []
else:
return "No good Google Search Result was found"

result = []

Expand Down Expand Up @@ -240,7 +236,7 @@ def _run_single_query(
)
full_result = answer + "\n\n" + full_result
if return_resources:
return full_result, resources
return resources
else:
return full_result

Expand Down Expand Up @@ -280,7 +276,7 @@ def _run(
)

if return_resources:
return result, resources
return resources
else:
return result

Expand Down
Loading

0 comments on commit ba9e631

Please sign in to comment.