Skip to content

Commit

Permalink
Neighbouring Paragraphs RAG Strategy (#2410)
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran authored Aug 27, 2024
1 parent eda92ca commit 9171b54
Show file tree
Hide file tree
Showing 35 changed files with 988 additions and 598 deletions.
31 changes: 18 additions & 13 deletions nucliadb/src/nucliadb/common/external_index_providers/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,21 @@ def iter_matching_text_blocks(self) -> Iterator[TextBlockMatch]:
for order, matching_vector in enumerate(self.results.matches):
try:
vector_id = VectorId.from_string(matching_vector.id)
paragraph_id = ParagraphId.from_vector_id(vector_id)
except ValueError: # pragma: no cover
logger.error(f"Invalid Pinecone vector id: {matching_vector.id}")
continue
vector_metadata = VectorMetadata.model_validate(matching_vector.metadata) # noqa
yield TextBlockMatch(
text=None, # To be filled by the results hydrator
id=matching_vector.id,
resource_id=vector_id.field_id.rid,
field_id=vector_id.field_id.full(),
id=paragraph_id.full(),
resource_id=paragraph_id.field_id.rid,
field_id=paragraph_id.field_id.full(),
score=matching_vector.score,
order=order,
position_start=vector_id.vector_start,
position_end=vector_id.vector_end,
subfield_id=vector_id.field_id.subfield_id,
position_start=paragraph_id.paragraph_start,
position_end=paragraph_id.paragraph_end,
subfield_id=paragraph_id.field_id.subfield_id,
index=vector_id.index,
position_start_seconds=list(map(int, vector_metadata.position_start_seconds or [])),
position_end_seconds=list(map(int, vector_metadata.position_end_seconds or [])),
Expand Down Expand Up @@ -438,13 +439,17 @@ def get_index_host(self, vectorset_id: str, rollover: bool = False) -> str:

def get_prefixes_to_delete(self, index_data: Resource) -> set[str]:
prefixes_to_delete = set()
for sentence_id in index_data.sentences_to_delete:
for field_id in index_data.sentences_to_delete:
try:
delete_field = FieldId.from_string(sentence_id)
prefixes_to_delete.add(delete_field.full())
delete_vid = VectorId.from_string(field_id)
prefixes_to_delete.add(delete_vid.field_id.full())
except ValueError: # pragma: no cover
logger.warning(f"Invalid id to delete: {sentence_id}. VectorId expected.")
continue
try:
delete_field = FieldId.from_string(field_id)
prefixes_to_delete.add(delete_field.full())
except ValueError:
logger.warning(f"Invalid id to delete sentences from: {field_id}.")
continue
for paragraph_id in index_data.paragraphs_to_delete:
try:
delete_pid = ParagraphId.from_string(paragraph_id)
Expand Down Expand Up @@ -581,8 +586,8 @@ def _compute_base_vector_metadatas(
fid = ParagraphId.from_string(paragraph_id).field_id
vector_metadata = VectorMetadata(
rid=resource_uuid,
field_type=fid.field_type,
field_id=fid.field_id,
field_type=fid.type,
field_id=fid.key,
date_created=date_created,
date_modified=date_modified,
security_public=security_public,
Expand Down
119 changes: 105 additions & 14 deletions nucliadb/src/nucliadb/common/ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,45 +26,93 @@
from dataclasses import dataclass
from typing import Optional

# Index
from nucliadb_protos.resources_pb2 import FieldType

# Mapping from the one-letter field-type codes used in serialized ids
# (e.g. "rid/t/title") to the protobuf FieldType enum values.
FIELD_TYPE_STR_TO_PB: dict[str, FieldType.ValueType] = {
"t": FieldType.TEXT,
"f": FieldType.FILE,
"u": FieldType.LINK,
"a": FieldType.GENERIC,
"c": FieldType.CONVERSATION,
}

# Reverse mapping: protobuf FieldType enum value -> one-letter code.
FIELD_TYPE_PB_TO_STR = {v: k for k, v in FIELD_TYPE_STR_TO_PB.items()}


@dataclass
class FieldId:
"""
Field ids are used to identify fields in resources. They usually have the following format:
`rid/field_type/field_key`
where field type is one of: `t`, `f`, `u`, `a`, `c` (text, file, link, generic, conversation)
and field_key is an identifier for that field type on the resource, usually chosen by the user.
In some cases, fields can have subfields, for example, in conversations, where each part of the
conversation is a subfield. In those cases, the id has the following format:
`rid/field_type/field_key/subfield_id`
Examples:
>>> FieldId(rid="rid", type="u", key="/my-link")
FieldID("rid/u/my-link")
>>> FieldId.from_string("rid/u/my-link")
FieldID("rid/u/my-link")
"""

# resource id this field belongs to
rid: str
# one-letter field type code (see FIELD_TYPE_STR_TO_PB)
type: str
# user-chosen identifier of the field within the resource
key: str
# Also known as `split`: identifies a part of a field, e.g. one message of a conversation.
subfield_id: Optional[str] = None

def __repr__(self) -> str:
    """Debug representation: FieldId(<full id>)."""
    return "FieldId({})".format(self.full())

def full(self) -> str:
    """Serialize to "rid/type/key", plus "/subfield_id" when present."""
    if self.subfield_id is None:
        return f"{self.rid}/{self.type}/{self.key}"
    else:
        return f"{self.rid}/{self.type}/{self.key}/{self.subfield_id}"

def __hash__(self) -> int:
    """Hash on the serialized id, so ids that render the same hash the same."""
    serialized = self.full()
    return hash(serialized)

@property
def pb_type(self) -> FieldType.ValueType:
    """Protobuf FieldType enum value for this field's one-letter type code.

    Raises:
        KeyError: if ``type`` is not one of the known field-type letters.
    """
    return FIELD_TYPE_STR_TO_PB[self.type]

@classmethod
def from_string(cls, value: str) -> "FieldId":
"""
Parse a FieldId from a string
Example:
>>> FieldId.from_string("rid/u/field_id")
FieldId(rid="rid", field_id="u/field_id")
>>> FieldId.from_string("rid/u/field_id/subfield_id")
FieldId(rid="rid", field_id="u/field_id", subfield_id="subfield_id")
>>> fid = FieldId.from_string("rid/u/foo")
>>> fid
FieldId("rid/u/foo")
>>> fid.type
'u'
>>> fid.key
'foo'
>>> FieldId.from_string("rid/u/foo/subfield_id").subfield_id
'subfield_id'
"""
parts = value.split("/")
if len(parts) == 3:
rid, field_type, field_id = parts
return cls(rid=rid, field_id=f"{field_type}/{field_id}")
rid, _type, key = parts
if _type not in FIELD_TYPE_STR_TO_PB:
raise ValueError(f"Invalid FieldId: {value}")
return cls(rid=rid, type=_type, key=key)
elif len(parts) == 4:
rid, field_type, field_id, subfield_id = parts
rid, _type, key, subfield_id = parts
if _type not in FIELD_TYPE_STR_TO_PB:
raise ValueError(f"Invalid FieldId: {value}")
return cls(
rid=rid,
field_id=f"{field_type}/{field_id}",
type=_type,
key=key,
subfield_id=subfield_id,
)
else:
Expand All @@ -77,9 +125,15 @@ class ParagraphId:
paragraph_start: int
paragraph_end: int

def __repr__(self) -> str:
    """Debug representation: ParagraphId(<full id>)."""
    return "ParagraphId(" + self.full() + ")"

def full(self) -> str:
    """Serialize as "<field id>/<start>-<end>"."""
    prefix = self.field_id.full()
    return "{}/{}-{}".format(prefix, self.paragraph_start, self.paragraph_end)

def __hash__(self) -> int:
    """Hash consistently with the serialized form of the id."""
    return self.full().__hash__()

@classmethod
def from_string(cls, value: str) -> "ParagraphId":
parts = value.split("/")
Expand All @@ -88,17 +142,54 @@ def from_string(cls, value: str) -> "ParagraphId":
field_id = FieldId.from_string("/".join(parts[:-1]))
return cls(field_id=field_id, paragraph_start=start, paragraph_end=end)

@classmethod
def from_vector_id(cls, vid: "VectorId") -> "ParagraphId":
    """
    Build a ParagraphId from a VectorId; the index part of the vector id is ignored.
    >>> vid = VectorId.from_string("rid/u/field_id/0/0-1")
    >>> ParagraphId.from_vector_id(vid)
    ParagraphId("rid/u/field_id/0-1")
    """
    start, end = vid.vector_start, vid.vector_end
    return cls(field_id=vid.field_id, paragraph_start=start, paragraph_end=end)


@dataclass
class VectorId:
"""
Ids of vectors are very similar to ParagraphIds, but for legacy reasons, they have an index
indicating the position of the corresponding text block in the list of text blocks for the field.
Examples:
>>> VectorId.from_string("rid/u/field_id/0/0-10")
VectorId("rid/u/field_id/0/0-10")
>>> VectorId(
... field_id=FieldId.from_string("rid/u/field_id"),
... index=0,
... vector_start=0,
... vector_end=10,
... )
VectorId("rid/u/field_id/0/0-10")
"""

# the field this vector belongs to
field_id: FieldId
# position of the corresponding text block in the field's list of text blocks (legacy)
index: int
# start/end bounds of the vector's text block; reused as paragraph bounds in
# ParagraphId.from_vector_id
vector_start: int
vector_end: int

def __repr__(self) -> str:
    """Debug representation: VectorId(<full id>)."""
    return "VectorId({})".format(self.full())

def full(self) -> str:
    """Serialize as "<field id>/<index>/<start>-<end>"."""
    parts = (
        self.field_id.full(),
        str(self.index),
        f"{self.vector_start}-{self.vector_end}",
    )
    return "/".join(parts)

def __hash__(self) -> int:
    """Hash derived from the serialized id string."""
    full_id = self.full()
    return hash(full_id)

@classmethod
def from_string(cls, value: str) -> "VectorId":
parts = value.split("/")
Expand Down
31 changes: 20 additions & 11 deletions nucliadb/src/nucliadb/ingest/orm/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,10 @@ def apply_field_metadata(
self.brain.paragraphs_to_delete.append(f"{self.rid}/{field_key}/{paragraph_to_delete}")

def delete_metadata(self, field_key: str, metadata: FieldComputedMetadata):
ftype, fkey = field_key.split("/")
for subfield, metadata_split in metadata.split_metadata.items():
self.brain.paragraphs_to_delete.append(
ids.FieldId(rid=self.rid, field_id=field_key, subfield_id=subfield).full()
ids.FieldId(rid=self.rid, type=ftype, key=fkey, subfield_id=subfield).full()
)
# TODO: Bw/c, remove this when paragraph deletion by field_id gets
# promoted
Expand All @@ -240,7 +241,9 @@ def delete_metadata(self, field_key: str, metadata: FieldComputedMetadata):
)

for paragraph in metadata.metadata.paragraphs:
self.brain.paragraphs_to_delete.append(ids.FieldId(rid=self.rid, field_id=field_key).full())
self.brain.paragraphs_to_delete.append(
ids.FieldId(rid=self.rid, type=ftype, key=fkey).full()
)
# TODO: Bw/c, remove this when paragraph deletion by field_id gets
# promoted
self.brain.paragraphs_to_delete.append(
Expand All @@ -258,11 +261,12 @@ def apply_field_vectors(
matryoshka_vector_dimension: Optional[int] = None,
):
replace_splits = replace_splits or []

fid = ids.FieldId.from_string(f"{self.rid}/{field_id}")
for subfield, vectors in vo.split_vectors.items():
_field_id = ids.FieldId(
rid=self.rid,
field_id=field_id,
rid=fid.rid,
type=fid.type,
key=fid.key,
subfield_id=subfield,
)
# For each split of this field
Expand All @@ -288,8 +292,9 @@ def apply_field_vectors(
)

_field_id = ids.FieldId(
rid=self.rid,
field_id=field_id,
rid=fid.rid,
type=fid.type,
key=fid.key,
)
for index, vector in enumerate(vo.vectors.vectors):
paragraph_key = ids.ParagraphId(
Expand All @@ -314,15 +319,19 @@ def apply_field_vectors(

for split in replace_splits:
self.brain.sentences_to_delete.append(
ids.FieldId(rid=self.rid, field_id=field_id, subfield_id=split).full()
ids.FieldId(rid=self.rid, type=fid.type, key=fid.key, subfield_id=split).full()
)
self.brain.paragraphs_to_delete.append(
ids.FieldId(rid=self.rid, field_id=field_id, subfield_id=split).full()
ids.FieldId(rid=self.rid, type=fid.type, key=fid.key, subfield_id=split).full()
)

if replace_field:
self.brain.sentences_to_delete.append(ids.FieldId(rid=self.rid, field_id=field_id).full())
self.brain.paragraphs_to_delete.append(ids.FieldId(rid=self.rid, field_id=field_id).full())
self.brain.sentences_to_delete.append(
ids.FieldId(rid=self.rid, type=fid.type, key=fid.key).full()
)
self.brain.paragraphs_to_delete.append(
ids.FieldId(rid=self.rid, type=fid.type, key=fid.key).full()
)

def _apply_field_vector(
self,
Expand Down
15 changes: 3 additions & 12 deletions nucliadb/src/nucliadb/ingest/orm/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from nucliadb.common import datamanagers
from nucliadb.common.datamanagers.resources import KB_RESOURCE_FIELDS, KB_RESOURCE_SLUG
from nucliadb.common.ids import FIELD_TYPE_PB_TO_STR, FIELD_TYPE_STR_TO_PB
from nucliadb.common.maindb.driver import Transaction
from nucliadb.ingest.fields.base import Field
from nucliadb.ingest.fields.conversation import Conversation
Expand Down Expand Up @@ -92,16 +93,6 @@
FieldType.CONVERSATION: Conversation,
}

KB_REVERSE: dict[str, FieldType.ValueType] = {
"t": FieldType.TEXT,
"f": FieldType.FILE,
"u": FieldType.LINK,
"a": FieldType.GENERIC,
"c": FieldType.CONVERSATION,
}

FIELD_TYPE_TO_ID = {v: k for k, v in KB_REVERSE.items()}

_executor = ThreadPoolExecutor(10)


Expand Down Expand Up @@ -407,7 +398,7 @@ async def _deprecated_scan_fields_ids(
# The [6:8] `slicing purpose is to match exactly the two
# splitted parts corresponding to type and field, and nothing else!
type, field = key.split("/")[6:8]
type_id = KB_REVERSE.get(type)
type_id = FIELD_TYPE_STR_TO_PB.get(type)
if type_id is None:
raise AttributeError("Invalid field type")
result = (type_id, field)
Expand Down Expand Up @@ -823,7 +814,7 @@ async def _apply_field_large_metadata(self, field_large_metadata: LargeComputedM
await field_obj.set_large_field_metadata(field_large_metadata)

def generate_field_id(self, field: FieldID) -> str:
    """Return the "<type letter>/<field name>" storage key for a field, e.g. "t/title"."""
    return f"{FIELD_TYPE_PB_TO_STR[field.field_type]}/{field.field}"

async def compute_security(self, brain: ResourceBrain):
security = await self.get_security()
Expand Down
4 changes: 2 additions & 2 deletions nucliadb/src/nucliadb/reader/api/v1/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from starlette.datastructures import Headers
from starlette.responses import StreamingResponse

from nucliadb.ingest.orm.resource import FIELD_TYPE_TO_ID
from nucliadb.common.ids import FIELD_TYPE_PB_TO_STR
from nucliadb.ingest.serialize import get_resource_uuid_by_slug
from nucliadb.reader import SERVICE_NAME, logger
from nucliadb.reader.api.models import FIELD_NAMES_TO_PB_TYPE_MAP
Expand Down Expand Up @@ -98,7 +98,7 @@ async def _download_extract_file(
storage = await get_storage(service_name=SERVICE_NAME)

pb_field_type = FIELD_NAMES_TO_PB_TYPE_MAP[field_type]
field_type_letter = FIELD_TYPE_TO_ID[pb_field_type]
field_type_letter = FIELD_TYPE_PB_TO_STR[pb_field_type]

sf = storage.file_extracted(kbid, rid, field_type_letter, field_id, download_field)

Expand Down
Loading

0 comments on commit 9171b54

Please sign in to comment.