Skip to content

Commit

Permalink
improve doc and typing
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipMay committed Dec 31, 2023
1 parent 1976723 commit 742bc17
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 25 deletions.
31 changes: 16 additions & 15 deletions mltb2/arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Optional, Sequence, Union

from arango import ArangoClient
from arango.database import StandardDatabase
from dotenv import dotenv_values

from mltb2.db import BatchDataManager
Expand All @@ -34,25 +35,25 @@ class ArangoBatchDataManager(BatchDataManager):
aql_overwrite: Optional[str] = None

@classmethod
def from_config_file(cls, config_file_name, aql_overwrite=None):
def from_config_file(cls, config_file_name, aql_overwrite: Optional[str] = None):
"""Construct ``ArangoBatchDataManager`` from config file."""
arango_config = dotenv_values(config_file_name)
return cls(
hosts=arango_config["hosts"],
db_name=arango_config["db_name"],
username=arango_config["username"],
password=arango_config["password"],
collection_name=arango_config["collection_name"],
attribute_name=arango_config["attribute_name"],
batch_size=int(arango_config["batch_size"]),
hosts=arango_config["hosts"], # type: ignore
db_name=arango_config["db_name"], # type: ignore
username=arango_config["username"], # type: ignore
password=arango_config["password"], # type: ignore
collection_name=arango_config["collection_name"], # type: ignore
attribute_name=arango_config["attribute_name"], # type: ignore
batch_size=int(arango_config["batch_size"]), # type: ignore
aql_overwrite=aql_overwrite,
)

def _get_arango_client(self):
def _get_arango_client(self) -> ArangoClient:
arango_client = ArangoClient(hosts=self.hosts)
return arango_client

def _get_connection(self, arango_client):
def _get_connection(self, arango_client: ArangoClient) -> StandardDatabase:
connection = arango_client.db(self.db_name, username=self.username, password=self.password)
return connection

Expand All @@ -71,14 +72,14 @@ def load_batch(self) -> Sequence:
aql = self.aql_overwrite
cursor = connection.aql.execute(
aql,
bind_vars=bind_vars,
bind_vars=bind_vars, # type: ignore
batch_size=self.batch_size,
)
with closing(cursor) as closing_cursor:
batch = closing_cursor.batch()
return batch
with closing(cursor) as closing_cursor: # type: ignore
batch = closing_cursor.batch() # type: ignore
return batch # type: ignore

def save_batch(self, batch: Sequence):
def save_batch(self, batch: Sequence) -> None:
"""TODO: add docstring."""
with closing(self._get_arango_client()) as arango_client:
connection = self._get_connection(arango_client)
Expand Down
16 changes: 6 additions & 10 deletions mltb2/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Sequence
from typing import Callable, Sequence


class BatchDataManager(ABC):
Expand All @@ -22,28 +22,24 @@ def load_batch(self) -> Sequence:
pass

@abstractmethod
def save_batch(self, batch: Sequence):
def save_batch(self, batch: Sequence) -> None:
"""TODO: add docstring."""
pass


@dataclass
class BatchDataProcessor(ABC):
class BatchDataProcessor:
"""Process batches loaded from a data manager with a callback and save the results."""

data_manager: BatchDataManager
process_batch_callback: Callable[[Sequence], Sequence]

@abstractmethod
def process_batch(self, batch: Sequence):
"""TODO: add docstring."""
pass

def run(self):
def run(self) -> None:
"""Load, process and save batches until the data manager returns an empty batch."""
while True:
batch = self.data_manager.load_batch()
if len(batch) == 0:
break
new_batch = self.process_batch(batch)
new_batch = self.process_batch_callback(batch)
if len(new_batch) > 0:
self.data_manager.save_batch(new_batch)

0 comments on commit 742bc17

Please sign in to comment.