From 98d657742adb1688c0d1d66b5bf5f1633eeadc47 Mon Sep 17 00:00:00 2001 From: Philip May Date: Sun, 31 Dec 2023 23:14:52 +0100 Subject: [PATCH] Add ArangoDB collection backup functionality. (#124) * Add ArangoDB collection backup functionality * add doc * add return type info * fix linting * add prints --- mltb2/arangodb.py | 69 ++++++++++++++++++++++++++++++++++++++++++++--- pyproject.toml | 6 ++++- 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/mltb2/arangodb.py b/mltb2/arangodb.py index 4b0715e..3be32d0 100644 --- a/mltb2/arangodb.py +++ b/mltb2/arangodb.py @@ -10,17 +10,31 @@ """ +import gzip +from argparse import ArgumentParser from contextlib import closing from dataclasses import dataclass -from typing import Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Union +import jsonlines from arango import ArangoClient from arango.database import StandardDatabase from dotenv import dotenv_values +from tqdm import tqdm from mltb2.db import AbstractBatchDataManager +def _check_config_keys(config: Dict[str, Optional[str]], expected_config_keys: Sequence[str]) -> None: + """Check if all expected keys are in config. + + This is useful to check if a config file contains all necessary keys. + """ + for expected_config_key in expected_config_keys: + if expected_config_key not in config: + raise ValueError(f"Config file must contain '{expected_config_key}'!") + + @dataclass class ArangoBatchDataManager(AbstractBatchDataManager): """ArangoDB implementation of the ``AbstractBatchDataManager``. @@ -90,9 +104,7 @@ def from_config_file(cls, config_file_name, aql_overwrite: Optional[str] = None) "attribute_name", "batch_size", ] - for expected_config_file_key in expected_config_file_keys: - if expected_config_file_key not in arango_config: - raise ValueError(f"Config file must contain '{expected_config_file_key}'!") + _check_config_keys(arango_config, expected_config_file_keys) return cls( hosts=arango_config["hosts"], # type: ignore @@ -155,3 +167,52 @@ def save_batch(self, batch: Sequence) -> None: connection = self._connection_factory(arango_client) collection = connection.collection(self.collection_name) collection.import_bulk(batch, on_duplicate="update") + + +def arango_collection_backup() -> None: + """Commandline tool to do an ArangoDB backup of a collection. + + The backup is written to a gzip compressed JSONL file in the current working directory. + Run ``arango-col-backup -h`` to get command line help. + """ + # argument parsing + description = ( + "ArangoDB backup of a collection. " + "The backup is written to a gzip compressed JSONL file in the current working directory." + ) + argument_parser = ArgumentParser(description=description) + argument_parser.add_argument( + "--conf", type=str, required=True, help="Config file containing 'hosts', 'db_name', 'username' and 'password'." + ) + argument_parser.add_argument("--col", type=str, required=True, help="Collection name to backup.") + args = argument_parser.parse_args() + + # load and check config file + arango_config = dotenv_values(args.conf) + expected_config_file_keys = ["hosts", "db_name", "username", "password"] + _check_config_keys(arango_config, expected_config_file_keys) + + output_file_name = f"./{args.col}_backup.jsonl.gz" + print(f"Writing backup to '{output_file_name}'...") + + with closing(ArangoClient(hosts=arango_config["hosts"])) as arango_client, gzip.open( # type: ignore + output_file_name, "w" + ) as gzip_out: + connection = arango_client.db( + arango_config["db_name"], # type: ignore + arango_config["username"], # type: ignore + arango_config["password"], # type: ignore + ) + jsonlines_writer = jsonlines.Writer(gzip_out) # type: ignore + try: + cursor = connection.aql.execute( + "FOR doc IN @@coll RETURN doc", + bind_vars={"@coll": args.col}, + batch_size=100, + max_runtime=60 * 60, # type: ignore # 1 hour + stream=True, + ) + for doc in tqdm(cursor): + jsonlines_writer.write(doc) + finally: + cursor.close(ignore_missing=True) # type: ignore diff --git a/pyproject.toml b/pyproject.toml index 9a13cf9..1a8a0ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,9 @@ classifiers = [ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" +[tool.poetry.scripts] +arango-col-backup = 'mltb2.arangodb:arango_collection_backup' + [tool.poetry.urls] "Bug Tracker" = "https://github.com/telekom/mltb2/issues" @@ -73,6 +76,7 @@ beautifulsoup4 = {version = "*", optional = true} joblib = {version = "*", optional = true} python-dotenv = {version = "*", optional = true} python-arango = {version = "*", optional = true} +jsonlines = {version = "*", optional = true} [tool.poetry.extras] files = ["platformdirs", "scikit-learn"] @@ -85,7 +89,7 @@ transformers = ["scikit-learn", "torch", "transformers", "safetensors"] md = ["scikit-learn", "torch", "transformers", "safetensors"] somajo_transformers = ["SoMaJo", "scikit-learn", "torch", "transformers", "safetensors"] openai = ["tiktoken", "openai", "pyyaml"] -arangodb = ["python-dotenv", "python-arango"] +arangodb = ["python-dotenv", "python-arango", "jsonlines"] [tool.poetry.group.lint.dependencies] black = "*"