Skip to content

Commit

Permalink
refactor: Factor out auto-unseal functionality into a manager class (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielArndt authored Nov 29, 2024
1 parent ec62186 commit 01b09f4
Show file tree
Hide file tree
Showing 17 changed files with 796 additions and 598 deletions.
163 changes: 85 additions & 78 deletions lib/charms/vault_k8s/v0/vault_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dataclasses import dataclass
from enum import Enum
from io import IOBase
from typing import List, Optional, Protocol
from typing import List, Protocol

import hvac
import requests
Expand All @@ -28,7 +28,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 21
LIBPATCH = 22


RAFT_STATE_ENDPOINT = "v1/sys/storage/raft/autopilot/state"
Expand Down Expand Up @@ -114,7 +114,7 @@ class VaultClientError(Exception):
"""Base class for exceptions raised by the Vault client."""


class Vault:
class VaultClient:
"""Class to interact with Vault through its API."""

def __init__(self, url: str, ca_cert_path: str | None):
Expand Down Expand Up @@ -169,6 +169,45 @@ def is_sealed(self) -> bool:
logging.error("Error while checking Vault seal status: %s", e)
raise VaultClientError(e) from e

def read(self, path: str) -> dict:
"""Read the data at the given path."""
try:
data = self._client.read(path)
except VaultError as e:
logger.error("Error while writing data to %s: %s", path, e)
return {}
if data is None:
return {}
if isinstance(data, requests.Response):
data = data.json()
return data.get("data", {})

def write(self, path: str, data: dict) -> bool:
"""Write the data at the given path."""
try:
response = self._client.write_data(path, data=data)
except VaultError as e:
logger.error("Error while writing data to %s: %s", path, e)
return False
logger.info("Wrote data to %s: %s", path, response)
return True

def list(self, path: str) -> List[str]:
"""List the keys at the given path."""
try:
data = self._client.list(path)
except VaultError as e:
logger.error("Error while listing keys at %s: %s", path, e)
return []
if data is None:
return []
if isinstance(data, requests.Response):
data = data.json()
try:
return data["data"]["keys"]
except KeyError:
return []

def needs_migration(self) -> bool:
"""Return true if the vault needs to be migrated, false otherwise."""
return self._client.seal_status["migration"] # type: ignore -- bad type hint in stubs
Expand Down Expand Up @@ -247,28 +286,44 @@ def enable_approle_auth_method(self) -> None:
except VaultError as e:
raise VaultClientError(e) from e

def configure_policy(self, policy_name: str, policy_path: str, **formatting_args: str) -> None:
"""Create/update a policy within vault.
def create_or_update_policy_from_file(
self, name: str, path: str, **formatting_args: str
) -> None:
"""Create/update a policy within vault, using the file contents as the policy.
Args:
policy_name: Name of the policy to create
policy_path: The path of the file where the policy is defined, ending with .hcl
name: Name of the policy to create
path: The path of the file where the policy is defined, ending with .hcl
**formatting_args: Additional arguments to format the policy
"""
with open(policy_path, "r") as f:
# TODO: Remove this method when it is no longer needed. Prefer create_or_update_policy.
with open(path, "r") as f:
policy = f.read()
try:
self._client.sys.create_or_update_policy(
name=policy_name,
name=name,
policy=policy if not formatting_args else policy.format(**formatting_args),
)
except VaultError as e:
raise VaultClientError(e) from e
logger.debug("Created or updated charm policy: %s", policy_name)
logger.debug("Created or updated charm policy: %s", name)

def create_or_update_policy(self, name: str, content: str) -> None:
"""Create/update a policy within vault.
Args:
name: Name of the policy to create
content: The policy content
"""
try:
self._client.sys.create_or_update_policy(name=name, policy=content)
except VaultError as e:
raise VaultClientError(e) from e
logger.debug("Created or updated charm policy: %s", name)

def configure_approle(
def create_or_update_approle(
self,
role_name: str,
name: str,
token_ttl=None,
token_max_ttl=None,
policies: List[str] | None = None,
Expand All @@ -278,23 +333,23 @@ def configure_approle(
"""Create/update a role within vault associating the supplied policies.
Args:
role_name: Name of the role to be created or updated
name: Name of the role to be created or updated
policies: The attached list of policy names this approle will have access to
token_ttl: Incremental lifetime for generated tokens, provided as a duration string such as "5m"
token_max_ttl: Maximum lifetime for generated tokens, provided as a duration string such as "5m"
token_period: The period within which the token must be renewed. See Vault documentation for more information.
cidrs: The list of IP networks that are allowed to authenticate
"""
self._client.auth.approle.create_or_update_approle(
role_name,
name,
bind_secret_id="true",
token_ttl=token_ttl,
token_max_ttl=token_max_ttl,
token_policies=policies,
token_bound_cidrs=cidrs,
token_period=token_period,
)
response = self._client.auth.approle.read_role_id(role_name)
response = self._client.auth.approle.read_role_id(name)
return response["data"]["role_id"]

def generate_role_secret_id(self, name: str, cidrs: List[str] | None = None) -> str:
Expand Down Expand Up @@ -440,24 +495,24 @@ def is_raft_cluster_healthy(self) -> bool:
"""Check if raft cluster is healthy."""
return self.get_raft_cluster_state()["healthy"]

def remove_raft_node(self, node_id: str) -> None:
def remove_raft_node(self, id: str) -> None:
"""Remove raft peer."""
try:
self._client.sys.remove_raft_node(server_id=node_id)
self._client.sys.remove_raft_node(server_id=id)
except (InternalServerError, ConnectionError) as e:
logger.warning("Error while removing raft node: %s", e)
return
logger.info("Removed raft node %s", node_id)
logger.info("Removed raft node %s", id)

def is_node_in_raft_peers(self, node_id: str) -> bool:
def is_node_in_raft_peers(self, id: str) -> bool:
"""Check if node is in raft peers."""
try:
raft_config = self._client.sys.read_raft_config()
except (InternalServerError, ConnectionError) as e:
logger.warning("Error while reading raft config: %s", e)
return False
for peer in raft_config["data"]["config"]["servers"]:
if peer["node_id"] == node_id:
if peer["node_id"] == id:
return True
return False

Expand All @@ -480,7 +535,7 @@ def is_common_name_allowed_in_pki_role(self, role: str, mount: str, common_name:
logger.warning("Role does not exist on the specified path.")
return False

def get_role_max_ttl(self, role: str, mount: str) -> Optional[int]:
def get_role_max_ttl(self, role: str, mount: str) -> int | None:
"""Get the max ttl for the specified PKI role in seconds."""
try:
return (
Expand Down Expand Up @@ -515,66 +570,18 @@ def make_latest_pki_issuer_default(self, mount: str) -> None:
except (TypeError, KeyError):
logger.error("Issuers config is not yet created")

def _get_autounseal_policy_name(self, relation_id: int) -> str:
"""Return the policy name for the given relation id."""
return f"charm-autounseal-{relation_id}"

def _get_autounseal_approle_name(self, relation_id: int) -> str:
"""Return the approle name for the given relation id."""
return f"charm-autounseal-{relation_id}"

def _get_autounseal_key_name(self, relation_id: int) -> str:
"""Return the key name for the given relation id."""
return str(relation_id)

def _create_autounseal_key(self, mount_point: str, relation_id: int) -> str:
"""Create a new autounseal key."""
key_name = self._get_autounseal_key_name(relation_id)
def create_transit_key(self, mount_point: str, key_name: str) -> None:
"""Create a new key in the transit backend."""
response = self._client.secrets.transit.create_key(mount_point=mount_point, name=key_name)
logging.debug(f"Created a new autounseal key: {response}")
return key_name

def _destroy_autounseal_key(self, mount_point, key_name):
"""Destroy the autounseal key."""
self._client.secrets.transit.delete_key(mount_point=mount_point, name=key_name)

def destroy_autounseal_credentials(self, relation_id: int, mount: str) -> None:
"""Destroy the approle and transit key for the given relation id."""
# Remove the approle
role_name = self._get_autounseal_approle_name(relation_id)
self._client.auth.approle.delete_role(role_name)
# Remove the policy
policy_name = self._get_autounseal_policy_name(relation_id)
self._client.sys.delete_policy(policy_name)
# Remove the transit key
# FIXME: This is currently disabled because we haven't figured out how
# to properly handle destroying the relation, yet. Destroying the key
# without migrating would make it impossible to recover the vault.
# key_name = self.get_autounseal_key_name(relation_id)
# self._destroy_autounseal_key(mount, key_name)

def create_autounseal_credentials(
self, relation_id: int, mount: str, policy_path: str
) -> tuple[str, str, str]:
"""Create auto-unseal credentials for the given relation id.
logging.debug("Created a new transit key. response=%s", response)

Args:
relation_id: The Juju relation id to use for the approle.
mount: The mount point for the transit backend.
policy_path: Path to a file that contains the autounseal policy.
def delete_role(self, name: str) -> None:
"""Delete the approle with the given name."""
return self._client.auth.approle.delete_role(name)

Returns:
A tuple containing the Role Id, Secret Id and Key Name.
"""
key_name = self._create_autounseal_key(mount, relation_id)
policy_name = self._get_autounseal_policy_name(relation_id)
self.configure_policy(policy_name, policy_path, mount=mount, key_name=key_name)

role_name = self._get_autounseal_approle_name(relation_id)
role_id = self.configure_approle(role_name, policies=[policy_name], token_period="60s")
secret_id = self.generate_role_secret_id(role_name)
return key_name, role_id, secret_id
def delete_policy(self, name: str) -> None:
"""Delete the policy with the given name."""
return self._client.sys.delete_policy(name)


def generate_pem_bundle(certificate: str, private_key: str) -> str:
Expand Down
Loading

0 comments on commit 01b09f4

Please sign in to comment.