Skip to content

Commit

Permalink
chore: update charm libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
telcobot committed Dec 13, 2024
1 parent 21c13ca commit 58eab09
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent):
from typing import List, Mapping

from jsonschema import exceptions, validate # type: ignore[import-untyped]
from ops import Relation
from ops.charm import CharmBase, CharmEvents, RelationBrokenEvent, RelationChangedEvent
from ops.framework import EventBase, EventSource, Handle, Object

Expand All @@ -112,7 +113,7 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent):

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 8
LIBPATCH = 9

PYDEPS = ["jsonschema"]

Expand Down Expand Up @@ -391,3 +392,11 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None:
None
"""
self.on.certificate_removed.emit(relation_id=event.relation.id)

def is_ready(self, relation: Relation) -> bool:
"""Check if the relation is ready by checking that it has valid relation data."""
relation_data = _load_relation_data(relation.data[relation.app])
if not self._relation_data_is_valid(relation_data):
logger.warning("Provider relation data did not pass JSON Schema validation: ")
return False
return True
57 changes: 51 additions & 6 deletions lib/charms/tls_certificates_interface/v4/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from ops import BoundEvent, CharmBase, CharmEvents, SecretExpiredEvent
from ops import BoundEvent, CharmBase, CharmEvents, SecretExpiredEvent, SecretRemoveEvent
from ops.framework import EventBase, EventSource, Handle, Object
from ops.jujuversion import JujuVersion
from ops.model import (
Expand All @@ -52,7 +52,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 1
LIBPATCH = 3

PYDEPS = ["cryptography", "pydantic"]

Expand Down Expand Up @@ -305,6 +305,37 @@ def from_string(cls, certificate: str) -> "Certificate":
validity_start_time=validity_start_time,
)

def matches_private_key(self, private_key: PrivateKey) -> bool:
"""Check if this certificate matches a given private key.
Args:
private_key (PrivateKey): The private key to validate against.
Returns:
bool: True if the certificate matches the private key, False otherwise.
"""
try:
cert_object = x509.load_pem_x509_certificate(self.raw.encode())
key_object = serialization.load_pem_private_key(
private_key.raw.encode(), password=None
)

cert_public_key = cert_object.public_key()
key_public_key = key_object.public_key()

if not isinstance(cert_public_key, rsa.RSAPublicKey):
logger.warning("Certificate does not use RSA public key")
return False

if not isinstance(key_public_key, rsa.RSAPublicKey):
logger.warning("Private key is not an RSA key")
return False

return cert_public_key.public_numbers() == key_public_key.public_numbers()
except Exception as e:
logger.warning("Failed to validate certificate and private key match: %s", e)
return False


@dataclass(frozen=True)
class CertificateSigningRequest:
Expand Down Expand Up @@ -974,6 +1005,7 @@ def __init__(
self.framework.observe(charm.on[relationship_name].relation_created, self._configure)
self.framework.observe(charm.on[relationship_name].relation_changed, self._configure)
self.framework.observe(charm.on.secret_expired, self._on_secret_expired)
self.framework.observe(charm.on.secret_remove, self._on_secret_remove)
for event in refresh_events:
self.framework.observe(event, self._configure)

Expand All @@ -996,6 +1028,10 @@ def _configure(self, _: EventBase):
def _mode_is_valid(self, mode) -> bool:
return mode in [Mode.UNIT, Mode.APP]

def _on_secret_remove(self, event: SecretRemoveEvent) -> None:
"""Handle Secret Removed Event."""
event.secret.remove_revision(event.revision)

def _on_secret_expired(self, event: SecretExpiredEvent) -> None:
"""Handle Secret Expired Event.
Expand Down Expand Up @@ -1069,7 +1105,7 @@ def _get_app_or_unit(self) -> Union[Application, Unit]:
raise TLSCertificatesError("Invalid mode")

@property
def private_key(self) -> PrivateKey | None:
def private_key(self) -> Optional[PrivateKey]:
"""Return the private key."""
if not self._private_key_generated():
return None
Expand Down Expand Up @@ -1238,7 +1274,7 @@ def _send_certificate_requests(self):

def get_assigned_certificate(
self, certificate_request: CertificateRequestAttributes
) -> Tuple[ProviderCertificate | None, PrivateKey | None]:
) -> Tuple[Optional[ProviderCertificate], Optional[PrivateKey]]:
"""Get the certificate that was assigned to the given certificate request."""
for requirer_csr in self.get_csrs_from_requirer_relation_data():
if certificate_request == CertificateRequestAttributes.from_csr(
Expand All @@ -1248,7 +1284,9 @@ def get_assigned_certificate(
return self._find_certificate_in_relation_data(requirer_csr), self.private_key
return None, None

def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateKey | None]:
def get_assigned_certificates(
self,
) -> Tuple[List[ProviderCertificate], Optional[PrivateKey]]:
"""Get a list of certificates that were assigned to this or app."""
assigned_certificates = []
for requirer_csr in self.get_csrs_from_requirer_relation_data():
Expand All @@ -1259,12 +1297,19 @@ def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateK
def _find_certificate_in_relation_data(
self, csr: RequirerCertificateRequest
) -> Optional[ProviderCertificate]:
"""Return the certificate that match the given CSR."""
"""Return the certificate that matches the given CSR, validated against the private key."""
if not self.private_key:
return None
for provider_certificate in self.get_provider_certificates():
if (
provider_certificate.certificate_signing_request == csr.certificate_signing_request
and provider_certificate.certificate.is_ca == csr.is_ca
):
if not provider_certificate.certificate.matches_private_key(self.private_key):
logger.warning(
"Certificate does not match the private key. Ignoring invalid certificate."
)
continue
return provider_certificate
return None

Expand Down

0 comments on commit 58eab09

Please sign in to comment.