Skip to content

Commit

Permalink
Add ArangoDB utility module and database utility module
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipMay committed Dec 31, 2023
1 parent e8b46a7 commit 1976723
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 0 deletions.
86 changes: 86 additions & 0 deletions mltb2/arangodb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2023 Philip May
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

"""ArangoDB utils module.
Hint:
Use pip to install the necessary dependencies for this module:
``pip install mltb2[arangodb]``
"""


from contextlib import closing
from dataclasses import dataclass
from typing import Optional, Sequence, Union

from arango import ArangoClient
from dotenv import dotenv_values

from mltb2.db import BatchDataManager


@dataclass
class ArangoBatchDataManager(BatchDataManager):
"""TODO: add docstring."""

hosts: Union[str, Sequence[str]]
db_name: str
username: str
password: str
collection_name: str
attribute_name: str
batch_size: int = 20
aql_overwrite: Optional[str] = None

@classmethod
def from_config_file(cls, config_file_name, aql_overwrite=None):
"""Construct ``ArangoDataManager`` from config file."""
arango_config = dotenv_values(config_file_name)
return cls(
hosts=arango_config["hosts"],
db_name=arango_config["db_name"],
username=arango_config["username"],
password=arango_config["password"],
collection_name=arango_config["collection_name"],
attribute_name=arango_config["attribute_name"],
batch_size=int(arango_config["batch_size"]),
aql_overwrite=aql_overwrite,
)

def _get_arango_client(self):
arango_client = ArangoClient(hosts=self.hosts)
return arango_client

def _get_connection(self, arango_client):
connection = arango_client.db(self.db_name, username=self.username, password=self.password)
return connection

def load_batch(self) -> Sequence:
"""TODO: add docstring."""
with closing(self._get_arango_client()) as arango_client:
connection = self._get_connection(arango_client)
bind_vars = {
"@coll": self.collection_name,
"attribute": self.attribute_name,
"batch_size": self.batch_size,
}
if self.aql_overwrite is None:
aql = "FOR doc IN @@coll FILTER !HAS(doc, @attribute) LIMIT @batch_size RETURN doc"
else:
aql = self.aql_overwrite
cursor = connection.aql.execute(
aql,
bind_vars=bind_vars,
batch_size=self.batch_size,
)
with closing(cursor) as closing_cursor:
batch = closing_cursor.batch()
return batch

def save_batch(self, batch: Sequence):
"""TODO: add docstring."""
with closing(self._get_arango_client()) as arango_client:
connection = self._get_connection(arango_client)
collection = connection.collection(self.collection_name)
collection.import_bulk(batch, on_duplicate="update")
49 changes: 49 additions & 0 deletions mltb2/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2023 Philip May
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

"""Database utils module.
This module provides utility functions for other modules.
It is not meant to be used directly.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Sequence


class BatchDataManager(ABC):
"""TODO: add docstring."""

@abstractmethod
def load_batch(self) -> Sequence:
"""TODO: add docstring."""
pass

@abstractmethod
def save_batch(self, batch: Sequence):
"""TODO: add docstring."""
pass


@dataclass
class BatchDataProcessor(ABC):
"""TODO: add docstring."""

data_manager: BatchDataManager

@abstractmethod
def process_batch(self, batch: Sequence):
"""TODO: add docstring."""
pass

def run(self):
"""TODO: add docstring."""
while True:
batch = self.data_manager.load_batch()
if len(batch) == 0:
break
new_batch = self.process_batch(batch)
if len(new_batch) > 0:
self.data_manager.save_batch(new_batch)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ pyyaml = {version = "*", optional = true}
pandas = {version = "*", optional = true}
beautifulsoup4 = {version = "*", optional = true}
joblib = {version = "*", optional = true}
python-dotenv = {version = "*", optional = true}
python-arango = {version = "*", optional = true}

[tool.poetry.extras]
files = ["platformdirs", "scikit-learn"]
Expand All @@ -83,6 +85,7 @@ transformers = ["scikit-learn", "torch", "transformers", "safetensors"]
md = ["scikit-learn", "torch", "transformers", "safetensors"]
somajo_transformers = ["SoMaJo", "scikit-learn", "torch", "transformers", "safetensors"]
openai = ["tiktoken", "openai", "pyyaml"]
arangodb = ["python-dotenv", "python-arango"]

[tool.poetry.group.lint.dependencies]
black = "*"
Expand Down

0 comments on commit 1976723

Please sign in to comment.