diff --git a/mltb2/arangodb.py b/mltb2/arangodb.py index 88b6c87..1770cb4 100644 --- a/mltb2/arangodb.py +++ b/mltb2/arangodb.py @@ -15,6 +15,7 @@ from typing import Optional, Sequence, Union from arango import ArangoClient +from arango.database import StandardDatabase from dotenv import dotenv_values from mltb2.db import BatchDataManager @@ -34,25 +35,25 @@ class ArangoBatchDataManager(BatchDataManager): aql_overwrite: Optional[str] = None @classmethod - def from_config_file(cls, config_file_name, aql_overwrite=None): + def from_config_file(cls, config_file_name, aql_overwrite: Optional[str] = None): """Construct ``ArangoDataManager`` from config file.""" arango_config = dotenv_values(config_file_name) return cls( - hosts=arango_config["hosts"], - db_name=arango_config["db_name"], - username=arango_config["username"], - password=arango_config["password"], - collection_name=arango_config["collection_name"], - attribute_name=arango_config["attribute_name"], - batch_size=int(arango_config["batch_size"]), + hosts=arango_config["hosts"], # type: ignore + db_name=arango_config["db_name"], # type: ignore + username=arango_config["username"], # type: ignore + password=arango_config["password"], # type: ignore + collection_name=arango_config["collection_name"], # type: ignore + attribute_name=arango_config["attribute_name"], # type: ignore + batch_size=int(arango_config["batch_size"]), # type: ignore aql_overwrite=aql_overwrite, ) - def _get_arango_client(self): + def _get_arango_client(self) -> ArangoClient: arango_client = ArangoClient(hosts=self.hosts) return arango_client - def _get_connection(self, arango_client): + def _get_connection(self, arango_client: ArangoClient) -> StandardDatabase: connection = arango_client.db(self.db_name, username=self.username, password=self.password) return connection @@ -71,14 +72,14 @@ def load_batch(self) -> Sequence: aql = self.aql_overwrite cursor = connection.aql.execute( aql, - bind_vars=bind_vars, + bind_vars=bind_vars, # type: ignore batch_size=self.batch_size, ) - with closing(cursor) as closing_cursor: - batch = closing_cursor.batch() - return batch + with closing(cursor) as closing_cursor: # type: ignore + batch = closing_cursor.batch() # type: ignore + return batch # type: ignore - def save_batch(self, batch: Sequence): + def save_batch(self, batch: Sequence) -> None: """TODO: add docstring.""" with closing(self._get_arango_client()) as arango_client: connection = self._get_connection(arango_client) diff --git a/mltb2/db.py b/mltb2/db.py index a75e3fb..4aef8f9 100644 --- a/mltb2/db.py +++ b/mltb2/db.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Sequence +from typing import Callable, Sequence class BatchDataManager(ABC): @@ -22,28 +22,24 @@ def load_batch(self) -> Sequence: pass @abstractmethod - def save_batch(self, batch: Sequence): + def save_batch(self, batch: Sequence) -> None: """TODO: add docstring.""" pass @dataclass -class BatchDataProcessor(ABC): +class BatchDataProcessor: """TODO: add docstring.""" data_manager: BatchDataManager + process_batch_callback: Callable[[Sequence], Sequence] - @abstractmethod - def process_batch(self, batch: Sequence): - """TODO: add docstring.""" - pass - - def run(self): + def run(self) -> None: """TODO: add docstring.""" while True: batch = self.data_manager.load_batch() if len(batch) == 0: break - new_batch = self.process_batch(batch) + new_batch = self.process_batch_callback(batch) if len(new_batch) > 0: self.data_manager.save_batch(new_batch)