diff --git a/mltb2/arangodb.py b/mltb2/arangodb.py new file mode 100644 index 0000000..88b6c87 --- /dev/null +++ b/mltb2/arangodb.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023 Philip May +# This software is distributed under the terms of the MIT license +# which is available at https://opensource.org/licenses/MIT + +"""ArangoDB utils module. + +Hint: + Use pip to install the necessary dependencies for this module: + ``pip install mltb2[arangodb]`` +""" + + +from contextlib import closing +from dataclasses import dataclass +from typing import Optional, Sequence, Union + +from arango import ArangoClient +from dotenv import dotenv_values + +from mltb2.db import BatchDataManager + + +@dataclass +class ArangoBatchDataManager(BatchDataManager): + """TODO: add docstring.""" + + hosts: Union[str, Sequence[str]] + db_name: str + username: str + password: str + collection_name: str + attribute_name: str + batch_size: int = 20 + aql_overwrite: Optional[str] = None + + @classmethod + def from_config_file(cls, config_file_name, aql_overwrite=None): + """Construct ``ArangoDataManager`` from config file.""" + arango_config = dotenv_values(config_file_name) + return cls( + hosts=arango_config["hosts"], + db_name=arango_config["db_name"], + username=arango_config["username"], + password=arango_config["password"], + collection_name=arango_config["collection_name"], + attribute_name=arango_config["attribute_name"], + batch_size=int(arango_config["batch_size"]), + aql_overwrite=aql_overwrite, + ) + + def _get_arango_client(self): + arango_client = ArangoClient(hosts=self.hosts) + return arango_client + + def _get_connection(self, arango_client): + connection = arango_client.db(self.db_name, username=self.username, password=self.password) + return connection + + def load_batch(self) -> Sequence: + """TODO: add docstring.""" + with closing(self._get_arango_client()) as arango_client: + connection = self._get_connection(arango_client) + bind_vars = { + "@coll": self.collection_name, + "attribute": self.attribute_name, + "batch_size": self.batch_size, + } + if self.aql_overwrite is None: + aql = "FOR doc IN @@coll FILTER !HAS(doc, @attribute) LIMIT @batch_size RETURN doc" + else: + aql = self.aql_overwrite + cursor = connection.aql.execute( + aql, + bind_vars=bind_vars, + batch_size=self.batch_size, + ) + with closing(cursor) as closing_cursor: + batch = closing_cursor.batch() + return batch + + def save_batch(self, batch: Sequence): + """TODO: add docstring.""" + with closing(self._get_arango_client()) as arango_client: + connection = self._get_connection(arango_client) + collection = connection.collection(self.collection_name) + collection.import_bulk(batch, on_duplicate="update") diff --git a/mltb2/db.py b/mltb2/db.py new file mode 100644 index 0000000..a75e3fb --- /dev/null +++ b/mltb2/db.py @@ -0,0 +1,49 @@ +# Copyright (c) 2023 Philip May +# This software is distributed under the terms of the MIT license +# which is available at https://opensource.org/licenses/MIT + +"""Database utils module. + +This module provides utility functions for other modules. +It is not meant to be used directly. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Sequence + + +class BatchDataManager(ABC): + """TODO: add docstring.""" + + @abstractmethod + def load_batch(self) -> Sequence: + """TODO: add docstring.""" + pass + + @abstractmethod + def save_batch(self, batch: Sequence): + """TODO: add docstring.""" + pass + + +@dataclass +class BatchDataProcessor(ABC): + """TODO: add docstring.""" + + data_manager: BatchDataManager + + @abstractmethod + def process_batch(self, batch: Sequence): + """TODO: add docstring.""" + pass + + def run(self): + """TODO: add docstring.""" + while True: + batch = self.data_manager.load_batch() + if len(batch) == 0: + break + new_batch = self.process_batch(batch) + if len(new_batch) > 0: + self.data_manager.save_batch(new_batch) diff --git a/pyproject.toml b/pyproject.toml index 4420746..9a13cf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,8 @@ pyyaml = {version = "*", optional = true} pandas = {version = "*", optional = true} beautifulsoup4 = {version = "*", optional = true} joblib = {version = "*", optional = true} +python-dotenv = {version = "*", optional = true} +python-arango = {version = "*", optional = true} [tool.poetry.extras] files = ["platformdirs", "scikit-learn"] @@ -83,6 +85,7 @@ transformers = ["scikit-learn", "torch", "transformers", "safetensors"] md = ["scikit-learn", "torch", "transformers", "safetensors"] somajo_transformers = ["SoMaJo", "scikit-learn", "torch", "transformers", "safetensors"] openai = ["tiktoken", "openai", "pyyaml"] +arangodb = ["python-dotenv", "python-arango"] [tool.poetry.group.lint.dependencies] black = "*"