Skip to content

Commit

Permalink
IL-167 Add ids to RetrieverBasedQa answers
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasKoehneckeAA committed Jan 23, 2024
1 parent 4580de8 commit 319c594
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 8 deletions.
49 changes: 45 additions & 4 deletions src/intelligence_layer/use_cases/qa/retriever_based_qa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Generic, Optional, Sequence

from pydantic import BaseModel

from intelligence_layer.connectors.limited_concurrency_client import (
Expand All @@ -11,7 +13,7 @@
from intelligence_layer.use_cases.qa.multiple_chunk_qa import (
MultipleChunkQa,
MultipleChunkQaInput,
MultipleChunkQaOutput,
Subanswer,
)
from intelligence_layer.use_cases.search.search import Search, SearchInput

Expand All @@ -29,7 +31,25 @@ class RetrieverBasedQaInput(BaseModel):
language: Language = Language("en")


class RetrieverBasedQa(Task[RetrieverBasedQaInput, MultipleChunkQaOutput]):
class EnrichedSubanswer(Subanswer, Generic[ID]):
id: ID


class RetrieverBasedQaOutput(BaseModel, Generic[ID]):
"""The output of a `RetrieverBasedQa` task.
Attributes:
answer: The answer generated by the task. Can be a string or None (if no answer was found).
subanswers: All the subanswers used to generate the answer.
"""

answer: Optional[str]
subanswers: Sequence[EnrichedSubanswer[ID]]


class RetrieverBasedQa(
Task[RetrieverBasedQaInput, RetrieverBasedQaOutput[ID]], Generic[ID]
):
"""Answer a question based on documents found by a retriever.
`RetrieverBasedQa` is a task that answers a question based on a set of documents.
Expand Down Expand Up @@ -78,7 +98,7 @@ def __init__(

def do_run(
self, input: RetrieverBasedQaInput, task_span: TaskSpan
) -> MultipleChunkQaOutput:
) -> RetrieverBasedQaOutput[ID]:
search_output = self._search.run(SearchInput(query=input.question), task_span)

multi_chunk_qa_input = MultipleChunkQaInput(
Expand All @@ -88,4 +108,25 @@ def do_run(
question=input.question,
language=input.language,
)
return self._multi_chunk_qa.run(multi_chunk_qa_input, task_span)

result = self._multi_chunk_qa.run(multi_chunk_qa_input, task_span)

# multi_chunk_qa does not known IDs so we need to rematch them
text_to_id = {
document.document_chunk.text: document.id
for document in search_output.results
}
enriched_answers = [
EnrichedSubanswer(
answer=answer.answer,
chunk=answer.chunk,
highlights=answer.highlights,
id=text_to_id[answer.chunk],
)
for answer in result.subanswers
]
correctly_formatted_output = RetrieverBasedQaOutput(
answer=result.answer,
subanswers=enriched_answers,
)
return correctly_formatted_output
11 changes: 7 additions & 4 deletions tests/use_cases/qa/test_retriever_based_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pytest import fixture

from intelligence_layer.connectors.document_index.document_index import DocumentPath
from intelligence_layer.connectors.limited_concurrency_client import (
AlephAlphaClientProtocol,
)
Expand Down Expand Up @@ -41,7 +42,7 @@ def in_memory_retriever_documents() -> Sequence[Document]:
def retriever_based_qa_with_in_memory_retriever(
client: AlephAlphaClientProtocol,
asymmetric_in_memory_retriever: QdrantInMemoryRetriever,
) -> RetrieverBasedQa:
) -> RetrieverBasedQa[int]:
return RetrieverBasedQa(
client, asymmetric_in_memory_retriever, model="luminous-base-control"
)
Expand All @@ -50,29 +51,31 @@ def retriever_based_qa_with_in_memory_retriever(
@fixture
def retriever_based_qa_with_document_index(
client: AlephAlphaClientProtocol, document_index_retriever: DocumentIndexRetriever
) -> RetrieverBasedQa:
) -> RetrieverBasedQa[DocumentPath]:
return RetrieverBasedQa(
client, document_index_retriever, model="luminous-base-control"
)


def test_retriever_based_qa_using_in_memory_retriever(
retriever_based_qa_with_in_memory_retriever: RetrieverBasedQa,
retriever_based_qa_with_in_memory_retriever: RetrieverBasedQa[int],
no_op_tracer: NoOpTracer,
) -> None:
question = "When was Robert Moses born?"
input = RetrieverBasedQaInput(question=question)
output = retriever_based_qa_with_in_memory_retriever.run(input, no_op_tracer)
assert output.answer
assert "1888" in output.answer
assert output.subanswers[0].id == 3


def test_retriever_based_qa_with_document_index(
retriever_based_qa_with_document_index: RetrieverBasedQa,
retriever_based_qa_with_document_index: RetrieverBasedQa[DocumentPath],
no_op_tracer: NoOpTracer,
) -> None:
question = "When was Robert Moses born?"
input = RetrieverBasedQaInput(question=question)
output = retriever_based_qa_with_document_index.run(input, no_op_tracer)
assert output.answer
assert "1888" in output.answer
assert output.subanswers[0].id.document_name == "Robert Moses (Begriffsklärung)"

0 comments on commit 319c594

Please sign in to comment.