DB query optimization and reducing SQLAlchemy logs #575

Merged
12 changes: 10 additions & 2 deletions pebblo/app/models/sqltables.py
@@ -1,10 +1,15 @@
+import logging
+
 from sqlalchemy import JSON, Column, Integer, create_engine
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base

 from pebblo.app.config.config import var_server_config_dict
 from pebblo.app.enums.common import StorageTypes
 from pebblo.app.enums.enums import CacheDir, SQLiteTables
 from pebblo.app.utils.utils import get_full_path
 from pebblo.log import get_logger

+logger = get_logger(__name__)
+
 Base = declarative_base()

@@ -66,7 +71,10 @@ class AiUser(Base):
 # Create an engine that stores data in the local directory's my_database.db file.
 full_path = get_full_path(CacheDir.HOME_DIR.value)
 sqlite_db_path = CacheDir.SQLITE_ENGINE.value.format(full_path)
-engine = create_engine(sqlite_db_path, echo=True)
+if logger.isEnabledFor(logging.DEBUG):
+    engine = create_engine(sqlite_db_path, echo=True)
+else:
+    engine = create_engine(sqlite_db_path)

 # Create all tables in the engine. This is equivalent to "Create Table" statements in raw SQL.
 Base.metadata.create_all(engine)
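Note on the logging change: SQLAlchemy's echo=True makes the engine log every SQL statement it emits, which is what flooded the server logs; gating it on the application's effective log level keeps statement logging available only when debugging. (The import switch from sqlalchemy.ext.declarative to sqlalchemy.orm tracks SQLAlchemy 1.4+, where declarative_base was relocated.) As a rough sketch of an equivalent alternative, not what this PR does, the same effect can be reached through the standard logging hierarchy, since echo=True is shorthand for raising the sqlalchemy.engine logger to INFO:

import logging

from sqlalchemy import create_engine

# Sketch only: the URL is a placeholder, not Pebblo's actual path.
engine = create_engine("sqlite:///example.db")

# echo=True is equivalent to setting the "sqlalchemy.engine" logger to INFO;
# tuning that logger directly avoids recreating the engine with echo.
if logging.getLogger(__name__).isEnabledFor(logging.DEBUG):
    logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)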
20 changes: 15 additions & 5 deletions pebblo/app/service/local_ui/loader_apps.py
@@ -25,6 +25,7 @@
     get_current_time,
     get_full_path,
     get_pebblo_server_version,
+    timeit,
 )
 from pebblo.log import get_logger

@@ -58,13 +59,20 @@ def _get_snippet_details(
         This function finds snippet details based on labels
         """
         response = []
-        for snippet_id in snippet_ids:
+        result, output = self.db.query_by_list(
+            AiSnippetsTable,
+            filter_key="id",
+            filter_values=snippet_ids[: ReportConstants.SNIPPET_LIMIT.value],
+        )
+
+        if not result or len(output) == 0:
+            return response
+
+        for row in output:
             if len(response) >= ReportConstants.SNIPPET_LIMIT.value:
                 break
-            result, output = self.db.query(AiSnippetsTable, {"id": snippet_id})
-            if not result or len(output) == 0:
-                continue
-            snippet_details = output[0].data
+
+            snippet_details = row.data
             entity_details = {}
             topic_details = {}
             if snippet_details.get("topicDetails") and snippet_details[
@@ -351,6 +359,7 @@ def _create_loader_app_model(self, app_list: list) -> LoaderAppModel:
         )
         return loader_response

+    @timeit
     def get_all_loader_apps(self):
         """
         Returns all necessary loader app details required for the get-all-apps functionality.
@@ -402,6 +411,7 @@ def get_all_loader_apps(self):
             # Closing the session
             self.db.session.close()

+    @timeit
     def get_loader_app_details(self, db: SQLiteClient, app_name: str) -> str:
         """
         This function is used by the loader_doc_service to get the data needed to generate a PDF.
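Both methods above are now wrapped with @timeit from pebblo.app.utils.utils, which times the query paths this PR optimizes. The decorator's implementation is outside this diff; a minimal sketch of what such a timing decorator typically looks like (an assumption, not the project's actual code):

import functools
import time

from pebblo.log import get_logger

logger = get_logger(__name__)


def timeit(func):
    """Log the wall-clock time spent in the wrapped function."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            elapsed = time.perf_counter() - start
            logger.debug(f"{func.__name__} took {elapsed:.4f}s")

    return wrapper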
3 changes: 3 additions & 0 deletions pebblo/app/service/local_ui/retriever_apps.py
@@ -20,6 +20,7 @@
     AiUser,
 )
 from pebblo.app.storage.sqlite_db import SQLiteClient
+from pebblo.app.utils.utils import timeit
 from pebblo.log import get_logger

 config_details = var_server_config_dict.get()
@@ -385,6 +386,7 @@ def prepare_retrieval_app_response(self, app_data, retrieval_data):
         )
         return json.dumps(response.model_dump(), default=str, indent=4)

+    @timeit
     def get_all_retriever_apps(self):
         try:
             self.db = SQLiteClient()
@@ -462,6 +464,7 @@ def get_all_retriever_apps(self):
             # Closing the session
             self.db.session.close()

+    @timeit
     def get_retriever_app_details(self, app_name):
         try:
             retrieval_data = []
72 changes: 70 additions & 2 deletions pebblo/app/storage/sqlite_db.py
@@ -1,6 +1,11 @@
-from sqlalchemy import and_, create_engine, text
+import logging
+from math import ceil
+from typing import List, Type
+
+from sqlalchemy import and_, create_engine, func, text
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm.attributes import flag_modified
+from sqlalchemy.orm.decl_api import DeclarativeMeta

 from pebblo.app.enums.enums import CacheDir
 from pebblo.app.storage.database import Database
@@ -21,7 +26,10 @@ def _create_engine():
         # Create an engine that stores data in the local directory's db file.
         full_path = get_full_path(CacheDir.HOME_DIR.value)
         sqlite_db_path = CacheDir.SQLITE_ENGINE.value.format(full_path)
-        engine = create_engine(sqlite_db_path, echo=True)
+        if logger.isEnabledFor(logging.DEBUG):
+            engine = create_engine(sqlite_db_path, echo=True)
+        else:
+            engine = create_engine(sqlite_db_path)
         return engine

     def create_session(self):
@@ -104,6 +112,66 @@ def query_by_id(self, table_obj, id):
             )
             return False, err

+    @timeit
+    def query_by_list(
+        self,
+        table_obj: Type[DeclarativeMeta],
+        filter_key: str,
+        filter_values: List[str],
+        page_size: int = 100,
+    ):
+        """
+        Query a table with a list-valued filter, e.g. fetch all snippets
+        whose ids are in [<id1>, <id2>].
+        :param table_obj: Table object on which the query is to be performed
+        :param filter_key: Search key
+        :param filter_values: List of strings to be added to the filter criteria.
+        :param page_size: Number of filter values queried per batch; all items
+            from filter_values are searched in batches of this size.
+        """
+        table_name = table_obj.__tablename__
+        try:
+            logger.debug(f"Fetching data from table {table_name}")
+            total_records = len(filter_values)
+            total_pages = ceil(total_records / page_size)
+            results = []
+            for page in range(total_pages):
+                try:
+                    # Calculate start and end indices for the current batch
+                    start_idx = page * page_size
+                    end_idx = start_idx + page_size
+
+                    # Slice filter_values to match the current batch
+                    current_batch = filter_values[start_idx:end_idx]
+
+                    logger.debug(
+                        f"Processing batch {page + 1}/{total_pages}, filter values: {current_batch}"
+                    )
+
+                    # Execute the query for the current batch
+                    batch_result = (
+                        self.session.query(table_obj)
+                        .filter(
+                            func.json_extract(table_obj.data, f"$.{filter_key}").in_(
+                                current_batch
+                            )
+                        )
+                        .all()
+                    )
+                    results.extend(batch_result)
+                except Exception as err:
+                    logger.error(
+                        f"Failed in fetching data from table {table_name}, Error: {err}"
+                    )
+                    continue
+
+            return True, results
+
+        except Exception as err:
+            logger.error(
+                f"Failed in fetching data from table {table_name}, Error: {err}"
+            )
+            return False, []
+
     @timeit
     def update_data(self, table_obj, data):
         table_name = table_obj.__tablename__
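Net effect of the storage change: _get_snippet_details previously issued one SELECT per snippet id (an N+1 pattern), whereas query_by_list batches ids into IN (...) filters over SQLite's json_extract, so e.g. 250 ids with the default page_size of 100 need only three queries (100, 100, and 50 values). A hypothetical usage sketch, reusing names from this diff and assuming create_session() wires up self.session as the surrounding methods imply:

from pebblo.app.models.sqltables import AiSnippetsTable
from pebblo.app.storage.sqlite_db import SQLiteClient

db = SQLiteClient()
db.create_session()

snippet_ids = ["snippet_id1", "snippet_id2", "snippet_id3"]

# One round trip per batch instead of one per id; each batch compiles to
# roughly: SELECT ... WHERE json_extract(data, '$.id') IN (?, ?, ?)
result, rows = db.query_by_list(AiSnippetsTable, "id", snippet_ids)
if result:
    for row in rows:
        print(row.data.get("id"))

db.session.close()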
Empty file added tests/app/storage/__init__.py
98 changes: 98 additions & 0 deletions tests/app/storage/test_sqlite_db.py
@@ -0,0 +1,98 @@
from unittest.mock import MagicMock

import pytest
from sqlalchemy.orm import Session

from pebblo.app.models.sqltables import AiSnippetsTable

# Assume table_obj is imported from the actual module where the table is defined


@pytest.fixture
def sqlite_client():
    """Fixture for creating an SQLiteClient instance."""
    from pebblo.app.storage.sqlite_db import SQLiteClient

    client = SQLiteClient()
    client.session = MagicMock(spec=Session)
    return client


def test_query_by_list_success(sqlite_client):
    """Test successful query with query_by_list."""
    mock_session = sqlite_client.session
    table_obj = AiSnippetsTable
    filter_key = "id"
    filter_values = ["snippet_id1", "snippet_id2"]

    # Mocking query result
    mock_query = mock_session.query.return_value
    mock_filter = mock_query.filter.return_value
    mock_filter.all.return_value = ["result1", "result2"]  # Mocked results

    # Call the method
    success, result = sqlite_client.query_by_list(table_obj, filter_key, filter_values)

    # Assertions
    assert success is True
    assert result == ["result1", "result2"]

    # Ensure the query was called only once (no pagination)
    assert mock_session.query().filter().all.call_count == 1


def test_query_by_list_page_size(sqlite_client):
    """Test successful query with query_by_list to verify pagination via page_size."""
    mock_session = sqlite_client.session
    table_obj = AiSnippetsTable
    filter_key = "id"
    filter_values = [
        "snippet_id1",
        "snippet_id2",
        "snippet_id3",
        "snippet_id4",
        "snippet_id5",
    ]
    page_size = 2

    # Mocking query result
    mock_result_page_1 = ["result1", "result2"]
    mock_result_page_2 = ["result3", "result4"]
    mock_result_page_3 = ["result5"]
    mock_query = mock_session.query().filter().all
    mock_query.side_effect = [
        mock_result_page_1,
        mock_result_page_2,
        mock_result_page_3,
    ]

    # Call the method
    success, result = sqlite_client.query_by_list(
        table_obj, filter_key, filter_values, page_size
    )

    # Assertions
    assert success is True
    assert result == ["result1", "result2", "result3", "result4", "result5"]


def test_query_by_list_failure(sqlite_client):
    mock_session = sqlite_client.session
    mock_table_obj = AiSnippetsTable
    filter_key = "id"
    filter_values = ["value1", "value2"]

    # Create mock data
    page_size = "abcd"  # invalid page size

    # Call the query_by_list function
    success, results = sqlite_client.query_by_list(
        table_obj=mock_table_obj,
        filter_key=filter_key,
        filter_values=filter_values,
        page_size=page_size,
    )

    assert success is False
    assert results == []
    mock_session.query.assert_not_called()
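Assuming a development install of Pebblo, the new tests need only pytest (the mocks come from the standard library's unittest.mock), e.g.:

pytest tests/app/storage/test_sqlite_db.py -v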