Merge pull request #12 from decodingml/feat/webinar

tweak RAG retrieval module

alexandruvesa authored May 3, 2024
2 parents 9fa8a5d + ca68f47 commit 77a8273

Showing 26 changed files with 629 additions and 445 deletions.

2 changes: 0 additions & 2 deletions .gitignore
@@ -149,8 +149,6 @@ dmypy.json
 # pytype static type analyzer
 .pytype/
 
-.idea
-
 # Cython debug symbols
 cython_debug/
 
1 change: 0 additions & 1 deletion course/module-3/.env.example
@@ -11,7 +11,6 @@ MONGO_DATABASE_NAME="scrabble"
 # QdrantDB config
 QDRANT_DATABASE_HOST="localhost"
 QDRANT_DATABASE_PORT=6333
-CLEANED_DATA_OUTPUT_COLLECTION_NAME="cleaned_posts"
 QDRANT_APIKEY=
 
 # MQ config
7 changes: 5 additions & 2 deletions course/module-3/.gitignore
@@ -26,7 +26,6 @@ share/python-wheels/
 *.egg
 MANIFEST
 
-data/
 # PyInstaller
 # Usually these files are written by a python script from a template
 # before PyInstaller builds the exe, so as to inject date/other infos into it.
@@ -158,4 +157,8 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-.idea/
+.idea/
+
+# Data Folders
+data/
+dataset/
2 changes: 1 addition & 1 deletion course/module-3/Makefile
@@ -14,7 +14,7 @@ local-insert-data-mongo: #Insert data to mongodb
 	poetry run python insert_data_mongo.py
 
 local-bytewax: # Run bytewax pipeline
-	poetry run python -m bytewax.run data_flow/bytewax_pipeline
+	RUST_BACKTRACE=full poetry run python -m bytewax.run data_flow/bytewax_pipeline
 
 local-test-retriever: # Test retriever
 	poetry run python retriever.py
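
A note on the RUST_BACKTRACE change (my gloss, not part of the commit): bytewax's dataflow engine is implemented in Rust, so this standard Rust environment variable makes a crashing pipeline print a full backtrace instead of a one-line panic, presumably to ease debugging. A hedged sketch of the equivalent from Python, mirroring the Makefile command:

import os

# Must be in the environment before the bytewax engine starts.
os.environ["RUST_BACKTRACE"] = "full"
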
7 changes: 3 additions & 4 deletions course/module-3/cdc.py
@@ -1,12 +1,11 @@
 import json
 
 from bson import json_util
-
 from data_flow.mq import RabbitMQConnection
 from db.mongo import MongoDatabaseConnector
 
 
-def stream_process():
+def stream_process() -> None:
     mq_connection = RabbitMQConnection()
     mq_connection.connect()
 
@@ -19,12 +18,12 @@ def stream_process():
     )  # Filter for inserts only
     for change in changes:
         data_type = change["ns"]["coll"]
-        entry_id = str(change["fullDocument"]["_id"])  # Convert ObjectId to string
+        entry_id = str(change["fullDocument"]["_id"])
 
         change["fullDocument"].pop("_id")
         change["fullDocument"]["type"] = data_type
         change["fullDocument"]["entry_id"] = entry_id
 
-        # Use json_util to serialize the document
         data = json.dumps(change["fullDocument"], default=json_util.default)
         mq_connection.publish_message(data=data, queue="mongo_data")
+
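
The hunk above keeps json.dumps(..., default=json_util.default) while dropping the redundant inline comments. For context, a minimal standalone sketch (not part of the commit) of why the bson helper is needed: change-stream documents carry BSON values such as ObjectId and datetime that the stdlib encoder rejects.

import datetime
import json

from bson import ObjectId, json_util

doc = {"_id": ObjectId(), "created_at": datetime.datetime.utcnow()}

try:
    json.dumps(doc)
except TypeError as exc:
    # stdlib encoder: "Object of type ObjectId is not JSON serializable"
    print(exc)

# json_util.default encodes BSON types as MongoDB Extended JSON instead.
print(json.dumps(doc, default=json_util.default))
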
40 changes: 26 additions & 14 deletions course/module-3/data_flow/mq.py
@@ -1,29 +1,35 @@
 import pika
 
+import logger_utils
+
 from settings import settings
 
+
+logger = logger_utils.get_logger(__name__)
+
+
 class RabbitMQConnection:
     _instance = None
 
     def __new__(
         cls,
-        host: str = None,
-        port: int = None,
-        username: str = None,
-        password: str = None,
+        host: str | None = None,
+        port: int | None = None,
+        username: str | None = None,
+        password: str | None = None,
         virtual_host: str = "/",
     ):
         if not cls._instance:
             cls._instance = super().__new__(cls)
 
         return cls._instance
 
     def __init__(
         self,
-        host: str = None,
-        port: int = None,
-        username: str = None,
-        password: str = None,
+        host: str | None = None,
+        port: int | None = None,
+        username: str | None = None,
+        password: str | None = None,
         virtual_host: str = "/",
         fail_silently: bool = False,
         **kwargs,
@@ -55,11 +61,12 @@ def connect(self):
                 )
             )
         except pika.exceptions.AMQPConnectionError as e:
-            print("Failed to connect to RabbitMQ:", e)
+            logger.exception("Failed to connect to RabbitMQ.")
+
            if not self.fail_silently:
                 raise e
 
-    def publish_message(self, data, queue):
+    def publish_message(self, data: str, queue: str):
         channel = self.get_channel()
         channel.queue_declare(
             queue=queue, durable=True, exclusive=False, auto_delete=False
@@ -68,11 +75,15 @@ def publish_message(self, data, queue):
 
         try:
             channel.basic_publish(
-                exchange="", routing_key="mongo_data", body=data, mandatory=True
+                exchange="", routing_key=queue, body=data, mandatory=True
             )
-            print("sent changes to RabbitMQ:", data)
+            logger.info(
+                "Sent message successfully.", queue_type="RabbitMQ", queue_name=queue
+            )
         except pika.exceptions.UnroutableError:
-            print("Message could not be confirmed")
+            logger.info(
+                "Failed to send the message.", queue_type="RabbitMQ", queue_name=queue
+            )
 
     def is_connected(self) -> bool:
         return self._connection is not None and self._connection.is_open
@@ -85,4 +96,5 @@ def close(self):
         if self.is_connected():
             self._connection.close()
             self._connection = None
-            print("Closed RabbitMQ connection")
+
+            logger.info("Closed RabbitMQ connection.")
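
RabbitMQConnection implements a singleton through __new__, which is why the CDC process and the streaming pipeline can assume one shared connection per process. A minimal sketch of the pattern (illustrative, not from the repo):

class Singleton:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Allocate the object once; later calls return the cached instance.
        if not cls._instance:
            cls._instance = super().__new__(cls)

        return cls._instance


a = Singleton()
b = Singleton()
assert a is b

One subtlety of this variant: __init__ still runs on every instantiation, so repeated calls re-run any setup it performs on the shared object.
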
21 changes: 8 additions & 13 deletions course/module-3/data_flow/stream_input.py
@@ -3,31 +3,26 @@
 from typing import Generic, Iterable, List, Optional, TypeVar
 
 from bytewax.inputs import FixedPartitionedSource, StatefulSourcePartition
-
 from data_flow.mq import RabbitMQConnection
 
-DATA = TypeVar(
-    "DATA"
-)  # The type of the items being produced in this case the data from the queue.
-MESSAGE_ID = TypeVar(
-    "MESSAGE_ID"
-)  # The type of the state being saved and resumed in this case last message id from Rabbitmq.
+DataT = TypeVar("DataT")
+MessageT = TypeVar("MessageT")
 
 
-class RabbitMQPartition(StatefulSourcePartition, Generic[DATA, MESSAGE_ID]):
+class RabbitMQPartition(StatefulSourcePartition, Generic[DataT, MessageT]):
     """
     Class responsible for creating a connection between bytewax and rabbitmq that facilitates the transfer of data from mq to the bytewax streaming pipeline.
     Inherits StatefulSourcePartition for snapshot functionality that enables saving the state of the queue
     """
 
-    def __init__(self, queue_name, resume_state=None):
+    def __init__(self, queue_name: str, resume_state: MessageT | None = None) -> None:
         self._in_flight_msg_ids = resume_state or set()
         self.queue_name = queue_name
         self.connection = RabbitMQConnection()
         self.connection.connect()
         self.channel = self.connection.get_channel()
 
-    def next_batch(self, sched: Optional[datetime]) -> Iterable[DATA]:
+    def next_batch(self, sched: Optional[datetime]) -> Iterable[DataT]:
         method_frame, header_frame, body = self.channel.basic_get(
             queue=self.queue_name, auto_ack=False
         )
@@ -39,7 +34,7 @@ def next_batch(self, sched: Optional[datetime]) -> Iterable[DATA]:
         else:
             return []
 
-    def snapshot(self) -> MESSAGE_ID:
+    def snapshot(self) -> MessageT:
         return self._in_flight_msg_ids
 
     def garbage_collect(self, state):
@@ -57,6 +52,6 @@ def list_parts(self) -> List[str]:
         return ["single partition"]
 
     def build_part(
-        self, now: datetime, for_part: str, resume_state: Optional[MESSAGE_ID]
-    ) -> StatefulSourcePartition[DATA, MESSAGE_ID]:
+        self, now: datetime, for_part: str, resume_state: MessageT | None = None
+    ) -> StatefulSourcePartition[DataT, MessageT]:
         return RabbitMQPartition(queue_name="mongo_data")
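
The enclosing class here is a FixedPartitionedSource: bytewax calls list_parts() and then build_part() per partition, and polls each RabbitMQPartition through next_batch(). A hedged sketch of how such a source is typically wired into the dataflow that `python -m bytewax.run` executes (assuming the bytewax 0.18+ operators API and the source class name; the repo's actual data_flow/bytewax_pipeline module is not shown in this diff):

import bytewax.operators as op
from bytewax.dataflow import Dataflow

from data_flow.stream_input import RabbitMQSource  # assumed class name

flow = Dataflow("streaming_ingestion")
# op.input drives the source: list_parts() -> build_part() -> next_batch().
stream = op.input("rabbitmq_input", flow, RabbitMQSource())
op.inspect("debug", stream)  # print each raw queue message while developing
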
73 changes: 56 additions & 17 deletions course/module-3/data_flow/stream_output.py
@@ -1,9 +1,12 @@
+import logger_utils
 from bytewax.outputs import DynamicSink, StatelessSinkPartition
 from db.qdrant import QdrantDatabaseConnector
-from models.base import DBDataModel
+from models.base import VectorDBDataModel
 from qdrant_client.http.api_client import UnexpectedResponse
 from qdrant_client.models import Batch
 
+logger = logger_utils.get_logger(__name__)
+
 
 class QdrantOutput(DynamicSink):
     """
@@ -17,44 +20,68 @@ def __init__(self, connection: QdrantDatabaseConnector, sink_type: str):
 
         try:
             self._connection.get_collection(collection_name="cleaned_posts")
-        except UnexpectedResponse as e:
-            print(f"Error when accessing the collection: {e}")
+        except UnexpectedResponse:
+            logger.exception(
+                "Couldn't access the collection. Creating a new one...",
+                collection_name="cleaned_posts",
+            )
+
             self._connection.create_non_vector_collection(
                 collection_name="cleaned_posts"
             )
 
         try:
             self._connection.get_collection(collection_name="cleaned_articles")
-        except UnexpectedResponse as e:
-            print(f"Error when accessing the collection: {e}")
+        except UnexpectedResponse:
+            logger.exception(
+                "Couldn't access the collection. Creating a new one...",
+                collection_name="cleaned_articles",
+            )
+
             self._connection.create_non_vector_collection(
                 collection_name="cleaned_articles"
             )
 
         try:
             self._connection.get_collection(collection_name="cleaned_repositories")
-        except UnexpectedResponse as e:
-            print(f"Error when accessing the collection: {e}")
+        except UnexpectedResponse:
+            logger.exception(
+                "Couldn't access the collection. Creating a new one...",
+                collection_name="cleaned_repositories",
+            )
+
             self._connection.create_non_vector_collection(
                 collection_name="cleaned_repositories"
             )
 
         try:
             self._connection.get_collection(collection_name="vector_posts")
-        except UnexpectedResponse as e:
-            print(f"Error when accessing the collection: {e}")
+        except UnexpectedResponse:
+            logger.exception(
+                "Couldn't access the collection. Creating a new one...",
+                collection_name="vector_posts",
+            )
+
             self._connection.create_vector_collection(collection_name="vector_posts")
 
         try:
             self._connection.get_collection(collection_name="vector_articles")
-        except UnexpectedResponse as e:
-            print(f"Error when accessing the collection: {e}")
+        except UnexpectedResponse:
+            logger.exception(
+                "Couldn't access the collection. Creating a new one...",
+                collection_name="vector_articles",
+            )
+
             self._connection.create_vector_collection(collection_name="vector_articles")
 
         try:
             self._connection.get_collection(collection_name="vector_repositories")
-        except UnexpectedResponse as e:
-            print(f"Error when accessing the collection: {e}")
+        except UnexpectedResponse:
+            logger.exception(
+                "Couldn't access the collection. Creating a new one...",
+                collection_name="vector_repositories",
+            )
+
             self._connection.create_vector_collection(
                 collection_name="vector_repositories"
             )
@@ -72,29 +99,41 @@ class QdrantCleanedDataSink(StatelessSinkPartition):
     def __init__(self, connection: QdrantDatabaseConnector):
         self._client = connection
 
-    def write_batch(self, items: list[DBDataModel]) -> None:
-        payloads = [item.save() for item in items]
+    def write_batch(self, items: list[VectorDBDataModel]) -> None:
+        payloads = [item.to_payload() for item in items]
         ids, data = zip(*payloads)
         collection_name = get_clean_collection(data_type=data[0]["type"])
         self._client.write_data(
             collection_name=collection_name,
             points=Batch(ids=ids, vectors={}, payloads=data),
         )
 
+        logger.info(
+            "Successfully inserted requested cleaned point(s)",
+            collection_name=collection_name,
+            num=len(ids),
+        )
+
 
 class QdrantVectorDataSink(StatelessSinkPartition):
     def __init__(self, connection: QdrantDatabaseConnector):
         self._client = connection
 
-    def write_batch(self, items: list[DBDataModel]) -> None:
-        payloads = [item.save() for item in items]
+    def write_batch(self, items: list[VectorDBDataModel]) -> None:
+        payloads = [item.to_payload() for item in items]
         ids, vectors, meta_data = zip(*payloads)
         collection_name = get_vector_collection(data_type=meta_data[0]["type"])
         self._client.write_data(
             collection_name=collection_name,
            points=Batch(ids=ids, vectors=vectors, payloads=meta_data),
         )
 
+        logger.info(
+            "Successfully inserted requested vector point(s)",
+            collection_name=collection_name,
+            num=len(ids),
+        )
+
 
 def get_clean_collection(data_type: str) -> str:
     if data_type == "posts":
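
Both sinks rely on the same contract: each VectorDBDataModel.to_payload() returns one tuple per point, and zip(*payloads) transposes the list of row tuples into the per-column sequences that Qdrant's Batch expects. A small self-contained sketch (values are illustrative, not from the repo):

payloads = [
    ("id-1", [0.1, 0.2], {"type": "posts", "text": "first chunk"}),
    ("id-2", [0.3, 0.4], {"type": "posts", "text": "second chunk"}),
]

# Transpose row tuples into columns: ids, vectors, and payload metadata.
ids, vectors, meta_data = zip(*payloads)
assert ids == ("id-1", "id-2")
assert vectors == ([0.1, 0.2], [0.3, 0.4])
assert meta_data[0]["type"] == "posts"  # used to pick the target collection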