Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implements the backup manager #577

Merged
merged 5 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions lib/charms/vault_k8s/v0/vault_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 24
LIBPATCH = 25


RAFT_STATE_ENDPOINT = "v1/sys/storage/raft/autopilot/state"
Expand Down Expand Up @@ -483,13 +483,16 @@ def create_snapshot(self) -> requests.Response:
"""Create a snapshot of the Vault data."""
return self._client.sys.take_raft_snapshot()

def restore_snapshot(self, snapshot: IOBase) -> requests.Response:
def restore_snapshot(self, snapshot: IOBase) -> None:
"""Restore a snapshot of the Vault data.

Uses force_restore_raft_snapshot to restore the snapshot
even if the unseal key used at backup time is different from the current one.
"""
return self._client.sys.force_restore_raft_snapshot(snapshot)
response = self._client.sys.force_restore_raft_snapshot(snapshot)
if not 200 <= response.status_code < 300:
logger.warning("Error while restoring snapshot: %s", response.text)
raise VaultClientError(f"Error while restoring snapshot: {response.text}")

def get_raft_cluster_state(self) -> dict:
"""Get raft cluster state."""
Expand Down
175 changes: 173 additions & 2 deletions lib/charms/vault_k8s/v0/vault_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import timedelta
from datetime import datetime, timedelta
from enum import Enum, auto
from typing import FrozenSet, MutableMapping, TextIO

from charms.certificate_transfer_interface.v0.certificate_transfer import (
CertificateTransferProvides,
)
from charms.data_platform_libs.v0.s3 import S3Requirer
from charms.tls_certificates_interface.v4.tls_certificates import (
Certificate,
CertificateRequestAttributes,
Expand Down Expand Up @@ -73,6 +74,7 @@
VaultClientError,
)
from charms.vault_k8s.v0.vault_kv import VaultKvProvides
from charms.vault_k8s.v0.vault_s3 import S3, S3Error
from ops import CharmBase, EventBase, Object, Relation
from ops.pebble import PathError

Expand All @@ -84,7 +86,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 4
LIBPATCH = 5


SEND_CA_CERT_RELATION_NAME = "send-ca-cert"
Expand Down Expand Up @@ -599,6 +601,7 @@ class Naming:
autounseal_approle_prefix: str = "charm-autounseal-"
autounseal_key_prefix: str = ""
autounseal_policy_prefix: str = "charm-autounseal-"
backup_s3_key_prefix: str = "vault-backup-"
kv_mount_prefix: str = "charm-"
kv_secret_prefix: str = "vault-kv-"

Expand All @@ -617,6 +620,12 @@ def autounseal_approle_name(cls, relation_id: int) -> str:
"""Return the approle name for the relation."""
return f"{cls.autounseal_approle_prefix}{relation_id}"

@classmethod
def backup_s3_key_name(cls, model_name: str) -> str:
"""Return the key name for the S3 backend."""
timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
return f"{cls.backup_s3_key_prefix}{model_name}-{timestamp}"

@classmethod
def kv_secret_label(cls, unit_name: str) -> str:
"""Return the secret label for the KV backend."""
Expand Down Expand Up @@ -1276,3 +1285,165 @@ def remove_unit_credentials(juju_facade: JujuFacade, unit_name: str) -> None:
unit_name: The name of the unit for which to remove the secret
"""
juju_facade.remove_secret(Naming.kv_secret_label(unit_name=unit_name))


class BackupManager:
"""Encapsulates the business logic for managing backups in Vault from a Charm.

This class provides the business logic for creating, listing, and restoring
backups of the Vault data.
"""

REQUIRED_S3_PARAMETERS = ["bucket", "access-key", "secret-key", "endpoint"]

def __init__(
self,
charm: CharmBase,
s3_requirer: S3Requirer,
relation_name: str,
):
self._charm = charm
self._juju_facade = JujuFacade(charm)
self._s3_requirer = s3_requirer
self._relation_name = relation_name

def create_backup(self, vault_client: VaultClient) -> str:
"""Create a backup of the Vault data.

Stores the backup in the S3 bucket provided by the S3 relation.

Returns:
The S3 key of the backup.
"""
self._validate_s3_prerequisites()

s3_parameters = self._get_s3_parameters()

try:
s3 = S3(
access_key=s3_parameters["access-key"],
secret_key=s3_parameters["secret-key"],
endpoint=s3_parameters["endpoint"],
region=s3_parameters.get("region"),
)
except S3Error as e:
logger.error("Failed to create S3 session. %s", e)
raise ManagerError("Failed to create S3 session")

if not (s3.create_bucket(bucket_name=s3_parameters["bucket"])):
raise ManagerError("Failed to create S3 bucket")
backup_key = Naming.backup_s3_key_name(self._charm.model.name)

response = vault_client.create_snapshot()
content_uploaded = s3.upload_content(
content=response.raw, # type: ignore[reportArgumentType]
bucket_name=s3_parameters["bucket"],
key=backup_key,
)
if not content_uploaded:
raise ManagerError("Failed to upload backup to S3 bucket")
logger.info("Backup uploaded to S3 bucket %s", s3_parameters["bucket"])
return backup_key

def list_backups(self) -> list[str]:
"""List all the backups available in the S3 bucket.

Backups are identified by the key prefix from
``Naming.backup_s3_key_prefix``.

Returns:
A list of backup keys with the prefix.
"""
self._validate_s3_prerequisites()

s3_parameters = self._get_s3_parameters()

try:
s3 = S3(
access_key=s3_parameters["access-key"],
secret_key=s3_parameters["secret-key"],
endpoint=s3_parameters["endpoint"],
region=s3_parameters.get("region"),
)
except S3Error:
raise ManagerError("Failed to create S3 session")

try:
backup_ids = s3.get_object_key_list(
bucket_name=s3_parameters["bucket"], prefix=Naming.backup_s3_key_prefix
)
except S3Error as e:
raise ManagerError(f"Failed to list backups in S3 bucket: {e}")
return backup_ids

def restore_backup(self, vault_client: VaultClient, backup_key: str) -> None:
"""Restore the Vault data from the backup using the ``vault_client`` provided.

Args:
vault_client: The Vault client to use for restoring the snapshot
backup_key: The S3 key of the backup to restore
"""
self._validate_s3_prerequisites()

s3_parameters = self._get_s3_parameters()

try:
s3 = S3(
access_key=s3_parameters["access-key"],
secret_key=s3_parameters["secret-key"],
endpoint=s3_parameters["endpoint"],
region=s3_parameters.get("region"),
)
except S3Error:
raise ManagerError("Failed to create S3 session")

try:
snapshot = s3.get_content(
bucket_name=s3_parameters["bucket"],
object_key=backup_key,
)
except S3Error as e:
raise ManagerError(f"Failed to retrieve snapshot from S3: {e}")
if not snapshot:
raise ManagerError("Snapshot not found in S3 bucket")

try:
vault_client.restore_snapshot(snapshot=snapshot)
except VaultClientError as e:
raise ManagerError(f"Failed to restore snapshot: {e}")

def _validate_s3_prerequisites(self) -> str | None:
"""Validate the S3 pre-requisites are met.

Raises:
ManagerError: If any of the pre-requisites are not met.
"""
if not self._juju_facade.is_leader:
raise ManagerError("Only leader unit can perform backup operations")
if not self._juju_facade.relation_exists(self._relation_name):
raise ManagerError("S3 relation not created")
if missing_parameters := self._get_missing_s3_parameters():
raise ManagerError("S3 parameters missing ({})".format(", ".join(missing_parameters)))

def _get_missing_s3_parameters(self) -> list[str]:
"""Return the list of missing S3 parameters.

Returns:
List[str]: List of missing required S3 parameters.
"""
s3_parameters = self._s3_requirer.get_s3_connection_info()
return [param for param in self.REQUIRED_S3_PARAMETERS if param not in s3_parameters]

def _get_s3_parameters(self) -> dict[str, str]:
"""Retrieve S3 parameters from the S3 integrator relation.

Removes leading and trailing whitespaces from the parameters.

Returns:
Dict[str, str]: Dictionary of the S3 parameters.
"""
s3_parameters = self._s3_requirer.get_s3_connection_info()
for key, value in s3_parameters.items():
if isinstance(value, str):
s3_parameters[key] = value.strip()
return s3_parameters
Loading
Loading