Skip to content

Commit

Permalink
Add dual bm25 retriever
Browse files (browse the repository at this point in the history)
Signed-off-by: ChungYujoyce <[email protected]>
  • Loading branch information
ChungYujoyce committed Mar 9, 2024
1 parent 1692424 commit 3d2ade6
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions libs/langchain/langchain/chains/retrieval_qa/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,6 +20,7 @@
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import Extra, Field, root_validator
from langchain.schema import BaseRetriever, Document
from langchain_core.retrievers import BaseRetriever as BaseRetriever_core
from langchain.schema.language_model import BaseLanguageModel
from langchain.vectorstores.base import VectorStore
from rank_bm25 import BM25Okapi
Expand Down Expand Up @@ -212,6 +213,7 @@ class RetrievalQA(BaseRetrievalQA):
"""

retriever: BaseRetriever = Field(exclude=True)
retriever_bm25: BaseRetriever_core = Field(exclude=True)


def clean_text(self, text: str) -> str:
Expand All @@ -230,7 +232,7 @@ def clean_text(self, text: str) -> str:
text = re.sub(r'\s+', ' ', text)

text = text.strip()

text = text.split(" ")
return text


Expand All @@ -241,10 +243,10 @@ def _get_docs_bm25(
) -> List[Document]:

# Tokenize documents
tokenized_documents = [self.clean_text(doc.page_content).split(" ") for doc in documents]
tokenized_documents = [self.clean_text(doc.page_content) for doc in documents]

# Tokenize query
tokenized_query = self.clean_text(question).split(" ")
tokenized_query = self.clean_text(question)

# Create BM25 object
bm25 = BM25Okapi(tokenized_documents)
Expand All @@ -267,11 +269,16 @@ def _get_docs(
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
documents = self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
)

documents = self._get_docs_bm25(question, documents)
if '-' in question:
documents = self.retriever_bm25.get_relevant_documents(
question, callbacks=run_manager.get_child()
)
else:
documents = self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
)

documents = self._get_docs_bm25(question, documents)
return documents


Expand Down

0 comments on commit 3d2ade6

Please sign in to comment.