Skip to content

Commit

Permalink
Add ArangoDB utility module. (#123)
Browse files Browse the repository at this point in the history
* Add ArangoDB utility module and database utility module

* improve doc and typing

* improve doc

* add doc to db module

* add doc

* add doc

* Add demo_*.py to .gitignore

* Refactor ArangoDB and BatchDataManager classes

* improve doc

* Update .gitignore to exclude all demo files

* Add config file validation in ArangoBatchDataManager constructor

* Update ArangoDB configuration
  • Loading branch information
PhilipMay authored Dec 31, 2023
1 parent e8b46a7 commit bacdb0e
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
.python-version
poetry.lock
poetry.toml
demo_*

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
6 changes: 6 additions & 0 deletions docs/source/api-reference/arangodb.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
.. _aramgodb_code_doc:

:mod:`~mltb2.arangodb`
======================

.. automodule:: mltb2.arangodb
6 changes: 6 additions & 0 deletions docs/source/api-reference/db.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
.. _db_code_doc:

:mod:`~mltb2.db`
================

.. automodule:: mltb2.db
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"numpy": ("https://numpy.org/doc/stable/", None),
"git": ("https://gitpython.readthedocs.io/en/stable/", None),
"platformdirs": ("https://platformdirs.readthedocs.io/en/latest/", None),
"arango": ("https://docs.python-arango.com/en/main/", None),
# "matplotlib": ("https://matplotlib.org/stable/", None),
}

Expand Down
157 changes: 157 additions & 0 deletions mltb2/arangodb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright (c) 2023 Philip May
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

"""ArangoDB utils module.
Hint:
Use pip to install the necessary dependencies for this module:
``pip install mltb2[arangodb]``
"""


from contextlib import closing
from dataclasses import dataclass
from typing import Optional, Sequence, Union

from arango import ArangoClient
from arango.database import StandardDatabase
from dotenv import dotenv_values

from mltb2.db import AbstractBatchDataManager


@dataclass
class ArangoBatchDataManager(AbstractBatchDataManager):
"""ArangoDB implementation of the ``AbstractBatchDataManager``.
Args:
hosts: ArangoDB host or hosts.
db_name: ArangoDB database name.
username: ArangoDB username.
password: ArangoDB password.
collection_name: Documents from this collection are processed.
attribute_name: This attribute is used to check if a document is already processed.
If the attribute is not present in a document, the document is processed.
If it is available the document is considered as already processed.
batch_size: The batch size.
aql_overwrite: AQL string to overwrite the default.
"""

hosts: Union[str, Sequence[str]]
db_name: str
username: str
password: str
collection_name: str
attribute_name: str
batch_size: int = 20
aql_overwrite: Optional[str] = None

@classmethod
def from_config_file(cls, config_file_name, aql_overwrite: Optional[str] = None):
"""Construct this from config file.
The config file must contain these values:
- ``hosts``
- ``db_name``
- ``username``
- ``password``
- ``collection_name``
- ``attribute_name``
- ``batch_size``
Config file example:
.. code-block::
hosts="https://arangodb.com"
db_name="my_ml_database"
username="my_username"
password="secret"
collection_name="my_ml_data_collection"
attribute_name="processing_metadata"
batch_size=100
Args:
config_file_name: The config file name (path).
aql_overwrite: AQL string to overwrite the default.
"""
# load config file data
arango_config = dotenv_values(config_file_name)

# check if all necessary keys are in config file
expected_config_file_keys = [
"hosts",
"db_name",
"username",
"password",
"collection_name",
"attribute_name",
"batch_size",
]
for expected_config_file_key in expected_config_file_keys:
if expected_config_file_key not in arango_config:
raise ValueError(f"Config file must contain '{expected_config_file_key}'!")

return cls(
hosts=arango_config["hosts"], # type: ignore
db_name=arango_config["db_name"], # type: ignore
username=arango_config["username"], # type: ignore
password=arango_config["password"], # type: ignore
collection_name=arango_config["collection_name"], # type: ignore
attribute_name=arango_config["attribute_name"], # type: ignore
batch_size=int(arango_config["batch_size"]), # type: ignore
aql_overwrite=aql_overwrite,
)

def _arango_client_factory(self) -> ArangoClient:
"""Create an ArangoDB client."""
arango_client = ArangoClient(hosts=self.hosts)
return arango_client

def _connection_factory(self, arango_client: ArangoClient) -> StandardDatabase:
"""Create an ArangoDB connection.
Args:
arango_client: ArangoDB client.
"""
connection = arango_client.db(self.db_name, username=self.username, password=self.password)
return connection

def load_batch(self) -> Sequence:
"""Load a batch of data from the ArangoDB database.
Returns:
The loaded batch of data.
"""
with closing(self._arango_client_factory()) as arango_client:
connection = self._connection_factory(arango_client)
bind_vars = {
"@coll": self.collection_name,
"attribute": self.attribute_name,
"batch_size": self.batch_size,
}
if self.aql_overwrite is None:
aql = "FOR doc IN @@coll FILTER !HAS(doc, @attribute) LIMIT @batch_size RETURN doc"
else:
aql = self.aql_overwrite
cursor = connection.aql.execute(
aql,
bind_vars=bind_vars, # type: ignore
batch_size=self.batch_size,
)
with closing(cursor) as closing_cursor: # type: ignore
batch = closing_cursor.batch() # type: ignore
return batch # type: ignore

def save_batch(self, batch: Sequence) -> None:
"""Save a batch of data to the ArangoDB database.
Args:
batch: The batch of data to save.
"""
with closing(self._arango_client_factory()) as arango_client:
connection = self._connection_factory(arango_client)
collection = connection.collection(self.collection_name)
collection.import_bulk(batch, on_duplicate="update")
55 changes: 55 additions & 0 deletions mltb2/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2023 Philip May
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

"""Database utils module."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Sequence


class AbstractBatchDataManager(ABC):
"""Abstract base class for batch processing of database data.
This class (respectively an implementation of it) is intended to be
used in conjunction with the :class:`BatchDataProcessor`.
"""

@abstractmethod
def load_batch(self) -> Sequence:
"""Load a batch of data from the database."""
pass

@abstractmethod
def save_batch(self, batch: Sequence) -> None:
"""Save a batch of data to the database."""
pass


@dataclass
class BatchDataProcessor:
"""Process batches of data from a database.
Args:
data_manager: The data manager to load and save batches of data.
process_batch_callback: A callback function that processes one batch of data.
"""

data_manager: AbstractBatchDataManager
process_batch_callback: Callable[[Sequence], Sequence]

def run(self) -> None:
"""Run the batch data processing.
This is done until the data manager returns an empty batch.
For each batch the ``process_batch_callback`` is called.
Data is loaded by using an implementation of the :class:`AbstractBatchDataManager`.
"""
while True:
batch = self.data_manager.load_batch()
if len(batch) == 0:
break
new_batch = self.process_batch_callback(batch)
if len(new_batch) > 0:
self.data_manager.save_batch(new_batch)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ pyyaml = {version = "*", optional = true}
pandas = {version = "*", optional = true}
beautifulsoup4 = {version = "*", optional = true}
joblib = {version = "*", optional = true}
python-dotenv = {version = "*", optional = true}
python-arango = {version = "*", optional = true}

[tool.poetry.extras]
files = ["platformdirs", "scikit-learn"]
Expand All @@ -83,6 +85,7 @@ transformers = ["scikit-learn", "torch", "transformers", "safetensors"]
md = ["scikit-learn", "torch", "transformers", "safetensors"]
somajo_transformers = ["SoMaJo", "scikit-learn", "torch", "transformers", "safetensors"]
openai = ["tiktoken", "openai", "pyyaml"]
arangodb = ["python-dotenv", "python-arango"]

[tool.poetry.group.lint.dependencies]
black = "*"
Expand Down

0 comments on commit bacdb0e

Please sign in to comment.