Skip to content

Commit

Permalink
IL-167 Add ids to RetrieverBasedQa answers (#412)
Browse files Browse the repository at this point in the history
* IL-167 Add ids to RetrieverBasedQa answers

* IL-167 add documentation and expose new classes
  • Loading branch information
NiklasKoehneckeAA authored Jan 24, 2024
1 parent b0229f9 commit 37272dc
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/intelligence_layer/use_cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
from .qa.long_context_qa import LongContextQaInput as LongContextQaInput
from .qa.multiple_chunk_qa import MultipleChunkQa as MultipleChunkQa
from .qa.multiple_chunk_qa import MultipleChunkQaInput as MultipleChunkQaInput
from .qa.retriever_based_qa import EnrichedSubanswer as EnrichedSubanswer
from .qa.retriever_based_qa import RetrieverBasedQa as RetrieverBasedQa
from .qa.retriever_based_qa import RetrieverBasedQaInput as RetrieverBasedQaInput
from .qa.retriever_based_qa import RetrieverBasedQaOutput as RetrieverBasedQaOutput
from .qa.single_chunk_qa import SingleChunkQa as SingleChunkQa
from .qa.single_chunk_qa import SingleChunkQaInput as SingleChunkQaInput
from .qa.single_chunk_qa import SingleChunkQaOutput as SingleChunkQaOutput
Expand Down
59 changes: 55 additions & 4 deletions src/intelligence_layer/use_cases/qa/retriever_based_qa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Generic, Optional, Sequence

from pydantic import BaseModel

from intelligence_layer.connectors.limited_concurrency_client import (
Expand All @@ -11,7 +13,7 @@
from intelligence_layer.use_cases.qa.multiple_chunk_qa import (
MultipleChunkQa,
MultipleChunkQaInput,
MultipleChunkQaOutput,
Subanswer,
)
from intelligence_layer.use_cases.search.search import Search, SearchInput

Expand All @@ -29,7 +31,35 @@ class RetrieverBasedQaInput(BaseModel):
language: Language = Language("en")


class RetrieverBasedQa(Task[RetrieverBasedQaInput, MultipleChunkQaOutput]):
class EnrichedSubanswer(Subanswer, Generic[ID]):
"""Individual answer for a chunk that also contains the origin of the chunk.
Attributes:
answer: The answer generated by the task. Can be a string or None (if no answer was found).
chunk: Piece of the original text that answer is based on.
highlights: The specific sentences that explain the answer the most.
These are generated by the `TextHighlight` Task.
id: The id of the document where the chunk came from.
"""

id: ID


class RetrieverBasedQaOutput(BaseModel, Generic[ID]):
"""The output of a `RetrieverBasedQa` task.
Attributes:
answer: The answer generated by the task. Can be a string or None (if no answer was found).
subanswers: All the subanswers used to generate the answer.
"""

answer: Optional[str]
subanswers: Sequence[EnrichedSubanswer[ID]]


class RetrieverBasedQa(
Task[RetrieverBasedQaInput, RetrieverBasedQaOutput[ID]], Generic[ID]
):
"""Answer a question based on documents found by a retriever.
`RetrieverBasedQa` is a task that answers a question based on a set of documents.
Expand Down Expand Up @@ -78,7 +108,7 @@ def __init__(

def do_run(
self, input: RetrieverBasedQaInput, task_span: TaskSpan
) -> MultipleChunkQaOutput:
) -> RetrieverBasedQaOutput[ID]:
search_output = self._search.run(SearchInput(query=input.question), task_span)

multi_chunk_qa_input = MultipleChunkQaInput(
Expand All @@ -88,4 +118,25 @@ def do_run(
question=input.question,
language=input.language,
)
return self._multi_chunk_qa.run(multi_chunk_qa_input, task_span)

result = self._multi_chunk_qa.run(multi_chunk_qa_input, task_span)

# multi_chunk_qa does not known IDs so we need to rematch them
text_to_id = {
document.document_chunk.text: document.id
for document in search_output.results
}
enriched_answers = [
EnrichedSubanswer(
answer=answer.answer,
chunk=answer.chunk,
highlights=answer.highlights,
id=text_to_id[answer.chunk],
)
for answer in result.subanswers
]
correctly_formatted_output = RetrieverBasedQaOutput(
answer=result.answer,
subanswers=enriched_answers,
)
return correctly_formatted_output
11 changes: 7 additions & 4 deletions tests/use_cases/qa/test_retriever_based_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pytest import fixture

from intelligence_layer.connectors.document_index.document_index import DocumentPath
from intelligence_layer.connectors.limited_concurrency_client import (
AlephAlphaClientProtocol,
)
Expand Down Expand Up @@ -41,7 +42,7 @@ def in_memory_retriever_documents() -> Sequence[Document]:
def retriever_based_qa_with_in_memory_retriever(
client: AlephAlphaClientProtocol,
asymmetric_in_memory_retriever: QdrantInMemoryRetriever,
) -> RetrieverBasedQa:
) -> RetrieverBasedQa[int]:
return RetrieverBasedQa(
client, asymmetric_in_memory_retriever, model="luminous-base-control"
)
Expand All @@ -50,29 +51,31 @@ def retriever_based_qa_with_in_memory_retriever(
@fixture
def retriever_based_qa_with_document_index(
client: AlephAlphaClientProtocol, document_index_retriever: DocumentIndexRetriever
) -> RetrieverBasedQa:
) -> RetrieverBasedQa[DocumentPath]:
return RetrieverBasedQa(
client, document_index_retriever, model="luminous-base-control"
)


def test_retriever_based_qa_using_in_memory_retriever(
retriever_based_qa_with_in_memory_retriever: RetrieverBasedQa,
retriever_based_qa_with_in_memory_retriever: RetrieverBasedQa[int],
no_op_tracer: NoOpTracer,
) -> None:
question = "When was Robert Moses born?"
input = RetrieverBasedQaInput(question=question)
output = retriever_based_qa_with_in_memory_retriever.run(input, no_op_tracer)
assert output.answer
assert "1888" in output.answer
assert output.subanswers[0].id == 3


def test_retriever_based_qa_with_document_index(
retriever_based_qa_with_document_index: RetrieverBasedQa,
retriever_based_qa_with_document_index: RetrieverBasedQa[DocumentPath],
no_op_tracer: NoOpTracer,
) -> None:
question = "When was Robert Moses born?"
input = RetrieverBasedQaInput(question=question)
output = retriever_based_qa_with_document_index.run(input, no_op_tracer)
assert output.answer
assert "1888" in output.answer
assert output.subanswers[0].id.document_name == "Robert Moses (Begriffsklärung)"

0 comments on commit 37272dc

Please sign in to comment.