diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index a266a748..8ed72ec1 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -300,7 +300,7 @@ def _on_topic_requested(self, event: TopicRequestedEvent): from enum import Enum from typing import Callable, Dict, List, Optional, Set, Tuple, Union -from ops import JujuVersion, Secret, SecretInfo, SecretNotFoundError +from ops import JujuVersion, Model, Secret, SecretInfo, SecretNotFoundError from ops.charm import ( CharmBase, CharmEvents, @@ -320,7 +320,7 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 28 +LIBPATCH = 29 PYDEPS = ["ops>=2.0.0"] @@ -479,7 +479,7 @@ class CachedSecret: def __init__( self, - charm: CharmBase, + model: Model, component: Union[Application, Unit], label: str, secret_uri: Optional[str] = None, @@ -488,7 +488,7 @@ def __init__( self._secret_content = {} self._secret_uri = secret_uri self.label = label - self.charm = charm + self._model = model self.component = component def add_secret(self, content: Dict[str, str], relation: Relation) -> Secret: @@ -499,7 +499,7 @@ def add_secret(self, content: Dict[str, str], relation: Relation) -> Secret: ) secret = self.component.add_secret(content, label=self.label) - if relation.app != self.charm.app: + if relation.app != self._model.app: # If it's not a peer relation, grant is to be applied secret.grant(relation) self._secret_uri = secret.id @@ -513,10 +513,10 @@ def meta(self) -> Optional[Secret]: if not (self._secret_uri or self.label): return try: - self._secret_meta = self.charm.model.get_secret(label=self.label) + self._secret_meta = self._model.get_secret(label=self.label) except SecretNotFoundError: if self._secret_uri: - self._secret_meta = self.charm.model.get_secret( + self._secret_meta = self._model.get_secret( id=self._secret_uri, label=self.label ) return self._secret_meta @@ -558,19 +558,31 @@ def get_info(self) -> Optional[SecretInfo]: if self.meta: return self.meta.get_info() + def remove(self) -> None: + """Remove secret.""" + if not self.meta: + raise SecretsUnavailableError("Non-existent secret was attempted to be removed.") + try: + self.meta.remove_all_revisions() + except SecretNotFoundError: + pass + self._secret_content = {} + self._secret_meta = None + self._secret_uri = None + class SecretCache: """A data structure storing CachedSecret objects.""" - def __init__(self, charm: CharmBase, component: Union[Application, Unit]): - self.charm = charm + def __init__(self, model: Model, component: Union[Application, Unit]): + self._model = model self.component = component self._secrets: Dict[str, CachedSecret] = {} def get(self, label: str, uri: Optional[str] = None) -> Optional[CachedSecret]: """Getting a secret from Juju Secret store or cache.""" if not self._secrets.get(label): - secret = CachedSecret(self.charm, self.component, label, uri) + secret = CachedSecret(self._model, self.component, label, uri) if secret.meta: self._secrets[label] = secret return self._secrets.get(label) @@ -580,16 +592,111 @@ def add(self, label: str, content: Dict[str, str], relation: Relation) -> Cached if self._secrets.get(label): raise SecretAlreadyExistsError(f"Secret {label} already exists") - secret = CachedSecret(self.charm, self.component, label) + secret = CachedSecret(self._model, self.component, label) secret.add_secret(content, relation) self._secrets[label] = secret return self._secrets[label] + def remove(self, label: str) -> None: + """Remove a secret from the cache.""" + if secret := self.get(label): + secret.remove() + self._secrets.pop(label) + else: + logging.error("Non-existing Juju Secret was attempted to be removed %s", label) + # Base DataRelation -class DataRelation(Object, ABC): +class DataDict(dict): + """Python Standard Library 'dict' - like representation of Relation Data.""" + + def __init__(self, relation_data: "DataRelation", relation_id: int): + self.relation_data = relation_data + self.relation_id = relation_id + + def all_data(self) -> Dict[str, str]: + """Return the full content of the Abstract Relation Data dictionary.""" + result = self.relation_data.fetch_my_relation_data([self.relation_id]) + try: + result_remote = self.relation_data.fetch_relation_data([self.relation_id]) + except NotImplementedError: + result_remote = {} + if result: + result_remote.update(result) + return result_remote.get(self.relation_id, {}) + + def __setitem__(self, key: str, item: str) -> None: + """Set an item of the Abstract Relation Data dictionary.""" + self.relation_data.update_relation_data(self.relation_id, {key: item}) + + def __getitem__(self, key: str) -> Optional[str]: + """Get an item of the Abstract Relation Data dictionary.""" + result = None + if not (result := self.relation_data.fetch_my_relation_field(self.relation_id, key)): + try: + result = self.relation_data.fetch_relation_field(self.relation_id, key) + except NotImplementedError: + pass + return result + + def __repr__(self) -> str: + """String representation Abstract Relation Data dictionary.""" + return repr(self.all_data()) + + def __len__(self) -> int: + """Length of the Abstract Relation Data dictionary.""" + return len(self.all_data()) + + def __delitem__(self, key: str) -> None: + """Delete an item of the Abstract Relation Data dictionary.""" + self.relation_data.delete_relation_data(self.relation_id, [key]) + + def has_key(self, k) -> bool: + """Does the key exist in the Abstract Relation Data dictionary?""" + return k in self.all_data() + + def update(self, items: Dict[str, str]): + """Update the Abstract Relation Data dictionary.""" + self.relation_data.update_relation_data(self.relation_id, items) + + def keys(self) -> List[str]: + """Keys of the Abstract Relation Data dictionary.""" + return list(self.all_data().keys()) + + def values(self) -> List[str]: + """Values of the Abstract Relation Data dictionary.""" + return list(self.all_data().values()) + + def items(self) -> List[Tuple[str, str]]: + """Items of the Abstract Relation Data dictionary.""" + return list(self.all_data().items()) + + def pop(self, item: str) -> str: + """Pop an item of the Abstract Relation Data dictionary.""" + result = self.relation_data.fetch_my_relation_field(self.relation_id, item) + if not result: + raise KeyError(f"Item {item} doesn't exist.") + self.relation_data.delete_relation_data(self.relation_id, [item]) + return result + + def __contains__(self, item: str) -> bool: + """Does the Abstract Relation Data dictionary contain item?""" + return item in self.all_data().values() + + def __iter__(self): + """Iterate through the Abstract Relation Data dictionary.""" + return iter(self.all_data()) + + def get(self, key: str, default: Optional[str] = None) -> Optional[str]: + """Safely get an item of the Abstract Relation Data dictionary.""" + if result := self[key]: + return result + return default + + +class DataRelation(ABC): """Base relation data mainpulation (abstract) class.""" SCOPE = Scope.APP @@ -603,26 +710,29 @@ class DataRelation(Object, ABC): "tls-ca": SecretGroup.TLS, } - def __init__(self, charm: CharmBase, relation_name: str) -> None: - super().__init__(charm, relation_name) - self.charm = charm - self.local_app = self.charm.model.app - self.local_unit = self.charm.unit + def __init__( + self, + charm: Union[CharmBase, Model], + relation_name: str, + ) -> None: + self._model = charm.model if isinstance(charm, CharmBase) else charm + self.local_app = self._model.app + self.local_unit = self._model.unit self.relation_name = relation_name - self.framework.observe( - charm.on[relation_name].relation_changed, - self._on_relation_changed_event, - ) self._jujuversion = None self.component = self.local_app if self.SCOPE == Scope.APP else self.local_unit - self.secrets = SecretCache(self.charm, self.component) + self.secrets = SecretCache(self._model, self.component) + + def as_dict(self, relation_id: int): + """Dict behavior representation of the Abstract Data.""" + return DataDict(self, relation_id) @property def relations(self) -> List[Relation]: """The list of Relation instances associated with this relation_name.""" return [ relation - for relation in self.charm.model.relations[self.relation_name] + for relation in self._model.relations[self.relation_name] if self._is_relation_active(relation) ] @@ -635,11 +745,6 @@ def secrets_enabled(self): # Mandatory overrides for internal/helper methods - @abstractmethod - def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation data has changed.""" - raise NotImplementedError - @abstractmethod def _get_relation_secret( self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None @@ -799,7 +904,7 @@ def _process_secret_fields( # self.local_app is sufficient to check (ignored if Requires, never has secrets -- works if Provides) fallback_to_databag = ( req_secret_fields - and self.local_unit.is_leader() + and (self.local_unit == self._model.unit and self.local_unit.is_leader()) and set(req_secret_fields) & set(relation.data[self.component]) ) @@ -860,16 +965,12 @@ def _fetch_relation_data_with_secrets( normal_fields = [] if not fields: - if component not in relation.data or not relation.data[component]: + if component not in relation.data: return {} all_fields = list(relation.data[component].keys()) normal_fields = [field for field in all_fields if not self._is_secret_field(field)] - - # There must have been secrets there - if all_fields != normal_fields and req_secret_fields: - # So we assemble the full fields list (without 'secret-' fields) - fields = normal_fields + req_secret_fields + fields = normal_fields + req_secret_fields if req_secret_fields else normal_fields if fields: result, normal_fields = self._process_secret_fields( @@ -917,7 +1018,7 @@ def _delete_relation_data_without_secrets( def get_relation(self, relation_name, relation_id) -> Relation: """Safe way of retrieving a relation.""" - relation = self.charm.model.get_relation(relation_name, relation_id) + relation = self._model.get_relation(relation_name, relation_id) if not relation: raise DataInterfacesError( @@ -1023,13 +1124,42 @@ def delete_relation_data(self, relation_id: int, fields: List[str]) -> None: return self._delete_relation_data(relation, fields) +class DataRelationEventHandlers(Object): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: DataRelation, unique_key: str = ""): + """Manager of base client relations.""" + if not unique_key: + unique_key = relation_data.relation_name + super().__init__(charm, unique_key) + + self.charm = charm + self.relation_data = relation_data + + self.framework.observe( + charm.on[self.relation_data.relation_name].relation_changed, + self._on_relation_changed_event, + ) + + # Event handlers + + @abstractmethod + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + + # Base DataProvides and DataRequires class DataProvides(DataRelation): """Base provides-side of the data products relation.""" - def __init__(self, charm: CharmBase, relation_name: str) -> None: + def __init__( + self, + charm: Union[CharmBase, Model], + relation_name: str, + ) -> None: super().__init__(charm, relation_name) def _diff(self, event: RelationChangedEvent) -> Diff: @@ -1136,8 +1266,6 @@ def _delete_relation_secret( ) return False - secret.set_content(new_content) - # Remove secret from the relation if it's fully gone if not new_content: field = self._generate_secret_field_name(group) @@ -1145,6 +1273,10 @@ def _delete_relation_secret( relation.data[self.component].pop(field) except KeyError: pass + label = self._generate_secret_label(self.relation_name, relation.id, group) + self.secrets.remove(label) + else: + secret.set_content(new_content) # Return the content that was removed return True @@ -1163,7 +1295,7 @@ def _get_relation_secret( if secret := self.secrets.get(label): return secret - relation = self.charm.model.get_relation(relation_name, relation_id) + relation = self._model.get_relation(relation_name, relation_id) if not relation: return @@ -1284,14 +1416,6 @@ def __init__( if additional_secret_fields: self._secret_fields += additional_secret_fields - self.framework.observe( - self.charm.on[relation_name].relation_created, self._on_relation_created_event - ) - self.framework.observe( - charm.on.secret_changed, - self._on_secret_changed_event, - ) - @property def secret_fields(self) -> Optional[List[str]]: """Local access to secrets field, in case they are being used.""" @@ -1328,7 +1452,7 @@ def _register_secret_to_relation( label = self._generate_secret_label(relation_name, relation_id, group) # Fetchin the Secret's meta information ensuring that it's locally getting registered with - CachedSecret(self.charm, self.component, label, secret_id).meta + CachedSecret(self._model, self.component, label, secret_id).meta def _register_secrets_to_relation(self, relation: Relation, params_name_list: List[str]): """Make sure that secrets of the provided list are locally 'registered' from the databag. @@ -1388,23 +1512,6 @@ def is_resource_created(self, relation_id: Optional[int] = None) -> bool: else False ) - # Event handlers - - def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: - """Event emitted when the relation is created.""" - if not self.local_unit.is_leader(): - return - - if self.secret_fields: - set_encoded_field( - event.relation, self.charm.app, REQ_SECRET_FIELDS, self.secret_fields - ) - - @abstractmethod - def _on_secret_changed_event(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation data has changed.""" - raise NotImplementedError - # Mandatory internal overrides @juju_secrets_only @@ -1463,11 +1570,48 @@ def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: fetch_my_relation_field = leader_only(DataRelation.fetch_my_relation_field) +class DataRequiresEventHandlers(DataRelationEventHandlers): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: DataRequires, unique_key: str = ""): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + + self.framework.observe( + self.charm.on[relation_data.relation_name].relation_created, + self._on_relation_created_event, + ) + self.framework.observe( + charm.on.secret_changed, + self._on_secret_changed_event, + ) + + # Event handlers + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the relation is created.""" + if not self.relation_data.local_unit.is_leader(): + return + + if self.relation_data.secret_fields: # pyright: ignore [reportAttributeAccessIssue] + set_encoded_field( + event.relation, + self.charm.app, + REQ_SECRET_FIELDS, + self.relation_data.secret_fields, # pyright: ignore [reportAttributeAccessIssue] + ) + + @abstractmethod + def _on_secret_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + + # Base DataPeer -class DataPeer(DataRequires, DataProvides): - """Represents peer relations.""" +class DataPeerData(DataRequires, DataProvides): + """Represents peer relations data.""" SECRET_FIELDS = ["operator-password"] SECRET_FIELD_NAME = "internal_secret" @@ -1484,7 +1628,11 @@ def __init__( ): """Manager of base client relations.""" DataRequires.__init__( - self, charm, relation_name, extra_user_roles, additional_secret_fields + self, + charm, + relation_name, + extra_user_roles, + additional_secret_fields, ) self.secret_field_name = secret_field_name if secret_field_name else self.SECRET_FIELD_NAME self.deleted_label = deleted_label @@ -1497,18 +1645,10 @@ def scope(self) -> Optional[Scope]: if isinstance(self.component, Unit): return Scope.UNIT - def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation has changed.""" - pass - - def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: - """Event emitted when the secret has changed.""" - pass - def _generate_secret_label( self, relation_name: str, relation_id: int, group_mapping: SecretGroup ) -> str: - members = [self.charm.app.name] + members = [self._model.app.name] if self.scope: members.append(self.scope.value) return f"{'.'.join(members)}" @@ -1533,7 +1673,7 @@ def _get_relation_secret( if not relation_name: relation_name = self.relation_name - relation = self.charm.model.get_relation(relation_name, relation_id) + relation = self._model.get_relation(relation_name, relation_id) if not relation: return @@ -1663,8 +1803,49 @@ def fetch_relation_field( fetch_my_relation_field = DataRelation.fetch_my_relation_field -class DataPeerUnit(DataPeer): - """Unit databag representation.""" +class DataPeerEventHandlers(DataRelationEventHandlers): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: DataRequires, unique_key: str = ""): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + pass + + def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: + """Event emitted when the secret has changed.""" + pass + + +class DataPeer(DataPeerData, DataPeerEventHandlers): + """Represents peer relations.""" + + def __init__( + self, + charm, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + unique_key: str = "", + ): + DataPeerData.__init__( + self, + charm, + relation_name, + extra_user_roles, + additional_secret_fields, + secret_field_name, + deleted_label, + ) + DataPeerEventHandlers.__init__(self, charm, self, unique_key) + + +class DataPeerUnitData(DataPeerData): + """Unit data abstraction representation.""" SCOPE = Scope.UNIT @@ -1672,6 +1853,75 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) +class DataPeerUnit(DataPeerUnitData, DataPeerEventHandlers): + """Unit databag representation.""" + + def __init__( + self, + charm, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + unique_key: str = "", + ): + DataPeerData.__init__( + self, + charm, + relation_name, + extra_user_roles, + additional_secret_fields, + secret_field_name, + deleted_label, + ) + DataPeerEventHandlers.__init__(self, charm, self, unique_key) + + +class DataPeerOtherUnitData(DataPeerUnitData): + """Unit data abstraction representation.""" + + def __init__(self, unit: Unit, *args, **kwargs): + super().__init__(*args, **kwargs) + self.local_unit = unit + self.component = unit + + +class DataPeerOtherUnitEventHandlers(DataPeerEventHandlers): + """Requires-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: DataPeerUnitData): + """Manager of base client relations.""" + unique_key = f"{relation_data.relation_name}-{relation_data.local_unit.name}" + super().__init__(charm, relation_data, unique_key=unique_key) + + +class DataPeerOtherUnit(DataPeerOtherUnitData, DataPeerOtherUnitEventHandlers): + """Unit databag representation for another unit than the executor.""" + + def __init__( + self, + unit: Unit, + charm: CharmBase, + relation_name: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + secret_field_name: Optional[str] = None, + deleted_label: Optional[str] = None, + unique_key: str = "", + ): + DataPeerData.__init__( + self, + charm, + relation_name, + extra_user_roles, + additional_secret_fields, + secret_field_name, + deleted_label, + ) + DataPeerEventHandlers.__init__(self, charm, self, unique_key) + + # General events @@ -1917,29 +2167,12 @@ class DatabaseRequiresEvents(CharmEvents): # Database Provider and Requires -class DatabaseProvides(DataProvides): - """Provider-side of the database relations.""" +class DatabaseProvidesData(DataProvides): + """Provider-side data of the database relations.""" - on = DatabaseProvidesEvents() # pyright: ignore [reportAssignmentType] - - def __init__(self, charm: CharmBase, relation_name: str) -> None: + def __init__(self, charm: Union[CharmBase, Model], relation_name: str) -> None: super().__init__(charm, relation_name) - def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation has changed.""" - # Leader only - if not self.local_unit.is_leader(): - return - # Check which data has changed to emit customs events. - diff = self._diff(event) - - # Emit a database requested event if the setup key (database name and optional - # extra user roles) was added to the relation databag by the application. - if "database" in diff.added: - getattr(self.on, "database_requested").emit( - event.relation, app=event.app, unit=event.unit - ) - def set_database(self, relation_id: int, database_name: str) -> None: """Set database name. @@ -2012,14 +2245,49 @@ def set_version(self, relation_id: int, version: str) -> None: self.update_relation_data(relation_id, {"version": version}) -class DatabaseRequires(DataRequires): - """Requires-side of the database relation.""" +class DatabaseProvidesEventHandlers(DataRelationEventHandlers): + """Provider-side of the database relation handlers.""" - on = DatabaseRequiresEvents() # pyright: ignore [reportAssignmentType] + on = DatabaseProvidesEvents() # pyright: ignore [reportAssignmentType] + + def __init__( + self, charm: CharmBase, relation_data: DatabaseProvidesData, unique_key: str = "" + ): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + # Just to calm down pyright, it can't parse that the same type is being used in the super() call above + self.relation_data = relation_data + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Leader only + if not self.relation_data.local_unit.is_leader(): + return + # Check which data has changed to emit customs events. + diff = self.relation_data._diff(event) + + # Emit a database requested event if the setup key (database name and optional + # extra user roles) was added to the relation databag by the application. + if "database" in diff.added: + getattr(self.on, "database_requested").emit( + event.relation, app=event.app, unit=event.unit + ) + + +class DatabaseProvides(DatabaseProvidesData, DatabaseProvidesEventHandlers): + """Provider-side of the database relations.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + DatabaseProvidesData.__init__(self, charm, relation_name) + DatabaseProvidesEventHandlers.__init__(self, charm, self) + + +class DatabaseRequiresData(DataRequires): + """Requires-side of the database relation.""" def __init__( self, - charm, + charm: Union[CharmBase, Model], relation_name: str, database_name: str, extra_user_roles: Optional[str] = None, @@ -2033,18 +2301,84 @@ def __init__( self.relations_aliases = relations_aliases self.external_node_connectivity = external_node_connectivity + def is_postgresql_plugin_enabled(self, plugin: str, relation_index: int = 0) -> bool: + """Returns whether a plugin is enabled in the database. + + Args: + plugin: name of the plugin to check. + relation_index: optional relation index to check the database + (default: 0 - first relation). + + PostgreSQL only. + """ + # Psycopg 3 is imported locally to avoid the need of its package installation + # when relating to a database charm other than PostgreSQL. + import psycopg + + # Return False if no relation is established. + if len(self.relations) == 0: + return False + + relation_id = self.relations[relation_index].id + host = self.fetch_relation_field(relation_id, "endpoints") + + # Return False if there is no endpoint available. + if host is None: + return False + + host = host.split(":")[0] + + content = self.fetch_relation_data([relation_id], ["username", "password"]).get( + relation_id, {} + ) + user = content.get("username") + password = content.get("password") + + connection_string = ( + f"host='{host}' dbname='{self.database}' user='{user}' password='{password}'" + ) + try: + with psycopg.connect(connection_string) as connection: + with connection.cursor() as cursor: + cursor.execute( + "SELECT TRUE FROM pg_extension WHERE extname=%s::text;", (plugin,) + ) + return cursor.fetchone() is not None + except psycopg.Error as e: + logger.exception( + f"failed to check whether {plugin} plugin is enabled in the database: %s", str(e) + ) + return False + + +class DatabaseRequiresEventHandlers(DataRequiresEventHandlers): + """Requires-side of the relation.""" + + on = DatabaseRequiresEvents() # pyright: ignore [reportAssignmentType] + + def __init__( + self, charm: CharmBase, relation_data: DatabaseRequiresData, unique_key: str = "" + ): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + # Define custom event names for each alias. - if relations_aliases: + if self.relation_data.relations_aliases: # Ensure the number of aliases does not exceed the maximum # of connections allowed in the specific relation. - relation_connection_limit = self.charm.meta.requires[relation_name].limit - if len(relations_aliases) != relation_connection_limit: + relation_connection_limit = self.charm.meta.requires[ + self.relation_data.relation_name + ].limit + if len(self.relation_data.relations_aliases) != relation_connection_limit: raise ValueError( f"The number of aliases must match the maximum number of connections allowed in the relation. " - f"Expected {relation_connection_limit}, got {len(relations_aliases)}" + f"Expected {relation_connection_limit}, got {len(self.relation_data.relations_aliases)}" ) - for relation_alias in relations_aliases: + if self.relation_data.relations_aliases: + for relation_alias in self.relation_data.relations_aliases: self.on.define_event(f"{relation_alias}_database_created", DatabaseCreatedEvent) self.on.define_event( f"{relation_alias}_endpoints_changed", DatabaseEndpointsChangedEvent @@ -2067,32 +2401,32 @@ def _assign_relation_alias(self, relation_id: int) -> None: relation_id: the identifier for a particular relation. """ # If no aliases were provided, return immediately. - if not self.relations_aliases: + if not self.relation_data.relations_aliases: return # Return if an alias was already assigned to this relation # (like when there are more than one unit joining the relation). - relation = self.charm.model.get_relation(self.relation_name, relation_id) - if relation and relation.data[self.local_unit].get("alias"): + relation = self.charm.model.get_relation(self.relation_data.relation_name, relation_id) + if relation and relation.data[self.relation_data.local_unit].get("alias"): return # Retrieve the available aliases (the ones that weren't assigned to any relation). - available_aliases = self.relations_aliases[:] - for relation in self.charm.model.relations[self.relation_name]: - alias = relation.data[self.local_unit].get("alias") + available_aliases = self.relation_data.relations_aliases[:] + for relation in self.charm.model.relations[self.relation_data.relation_name]: + alias = relation.data[self.relation_data.local_unit].get("alias") if alias: logger.debug("Alias %s was already assigned to relation %d", alias, relation.id) available_aliases.remove(alias) # Set the alias in the unit relation databag of the specific relation. - relation = self.charm.model.get_relation(self.relation_name, relation_id) + relation = self.charm.model.get_relation(self.relation_data.relation_name, relation_id) if relation: - relation.data[self.local_unit].update({"alias": available_aliases[0]}) + relation.data[self.relation_data.local_unit].update({"alias": available_aliases[0]}) # We need to set relation alias also on the application level so, # it will be accessible in show-unit juju command, executed for a consumer application unit - if self.local_unit.is_leader(): - self.update_relation_data(relation_id, {"alias": available_aliases[0]}) + if self.relation_data.local_unit.is_leader(): + self.relation_data.update_relation_data(relation_id, {"alias": available_aliases[0]}) def _emit_aliased_event(self, event: RelationChangedEvent, event_name: str) -> None: """Emit an aliased event to a particular relation if it has an alias. @@ -2116,60 +2450,11 @@ def _get_relation_alias(self, relation_id: int) -> Optional[str]: Returns: the relation alias or None if the relation was not found. """ - for relation in self.charm.model.relations[self.relation_name]: + for relation in self.charm.model.relations[self.relation_data.relation_name]: if relation.id == relation_id: - return relation.data[self.local_unit].get("alias") + return relation.data[self.relation_data.local_unit].get("alias") return None - def is_postgresql_plugin_enabled(self, plugin: str, relation_index: int = 0) -> bool: - """Returns whether a plugin is enabled in the database. - - Args: - plugin: name of the plugin to check. - relation_index: optional relation index to check the database - (default: 0 - first relation). - - PostgreSQL only. - """ - # Psycopg 3 is imported locally to avoid the need of its package installation - # when relating to a database charm other than PostgreSQL. - import psycopg - - # Return False if no relation is established. - if len(self.relations) == 0: - return False - - relation_id = self.relations[relation_index].id - host = self.fetch_relation_field(relation_id, "endpoints") - - # Return False if there is no endpoint available. - if host is None: - return False - - host = host.split(":")[0] - - content = self.fetch_relation_data([relation_id], ["username", "password"]).get( - relation_id, {} - ) - user = content.get("username") - password = content.get("password") - - connection_string = ( - f"host='{host}' dbname='{self.database}' user='{user}' password='{password}'" - ) - try: - with psycopg.connect(connection_string) as connection: - with connection.cursor() as cursor: - cursor.execute( - "SELECT TRUE FROM pg_extension WHERE extname=%s::text;", (plugin,) - ) - return cursor.fetchone() is not None - except psycopg.Error as e: - logger.exception( - f"failed to check whether {plugin} plugin is enabled in the database: %s", str(e) - ) - return False - def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: """Event emitted when the database relation is created.""" super()._on_relation_created_event(event) @@ -2179,32 +2464,32 @@ def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: # Sets both database and extra user roles in the relation # if the roles are provided. Otherwise, sets only the database. - if not self.local_unit.is_leader(): + if not self.relation_data.local_unit.is_leader(): return - event_data = {"database": self.database} + event_data = {"database": self.relation_data.database} - if self.extra_user_roles: - event_data["extra-user-roles"] = self.extra_user_roles + if self.relation_data.extra_user_roles: + event_data["extra-user-roles"] = self.relation_data.extra_user_roles # set external-node-connectivity field - if self.external_node_connectivity: + if self.relation_data.external_node_connectivity: event_data["external-node-connectivity"] = "true" - self.update_relation_data(event.relation.id, event_data) + self.relation_data.update_relation_data(event.relation.id, event_data) def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the database relation has changed.""" # Check which data has changed to emit customs events. - diff = self._diff(event) + diff = self.relation_data._diff(event) # Register all new secrets with their labels - if any(newval for newval in diff.added if self._is_secret_field(newval)): - self._register_secrets_to_relation(event.relation, diff.added) + if any(newval for newval in diff.added if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, diff.added) # Check if the database is created # (the database charm shared the credentials). - secret_field_user = self._generate_secret_field_name(SecretGroup.USER) + secret_field_user = self.relation_data._generate_secret_field_name(SecretGroup.USER) if ( "username" in diff.added and "password" in diff.added ) or secret_field_user in diff.added: @@ -2250,6 +2535,32 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: self._emit_aliased_event(event, "read_only_endpoints_changed") +class DatabaseRequires(DatabaseRequiresData, DatabaseRequiresEventHandlers): + """Provider-side of the database relations.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + database_name: str, + extra_user_roles: Optional[str] = None, + relations_aliases: Optional[List[str]] = None, + additional_secret_fields: Optional[List[str]] = [], + external_node_connectivity: bool = False, + ): + DatabaseRequiresData.__init__( + self, + charm, + relation_name, + database_name, + extra_user_roles, + relations_aliases, + additional_secret_fields, + external_node_connectivity, + ) + DatabaseRequiresEventHandlers.__init__(self, charm, self) + + # Kafka related events @@ -2343,30 +2654,12 @@ class KafkaRequiresEvents(CharmEvents): # Kafka Provides and Requires -class KafkaProvides(DataProvides): +class KafkaProvidesData(DataProvides): """Provider-side of the Kafka relation.""" - on = KafkaProvidesEvents() # pyright: ignore [reportAssignmentType] - def __init__(self, charm: CharmBase, relation_name: str) -> None: super().__init__(charm, relation_name) - def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation has changed.""" - # Leader only - if not self.local_unit.is_leader(): - return - - # Check which data has changed to emit customs events. - diff = self._diff(event) - - # Emit a topic requested event if the setup key (topic name and optional - # extra user roles) was added to the relation databag by the application. - if "topic" in diff.added: - getattr(self.on, "topic_requested").emit( - event.relation, app=event.app, unit=event.unit - ) - def set_topic(self, relation_id: int, topic: str) -> None: """Set topic name in the application relation databag. @@ -2404,10 +2697,43 @@ def set_zookeeper_uris(self, relation_id: int, zookeeper_uris: str) -> None: self.update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) -class KafkaRequires(DataRequires): - """Requires-side of the Kafka relation.""" +class KafkaProvidesEventHandlers(DataRelationEventHandlers): + """Provider-side of the Kafka relation.""" - on = KafkaRequiresEvents() # pyright: ignore [reportAssignmentType] + on = KafkaProvidesEvents() # pyright: ignore [reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: KafkaProvidesData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Leader only + if not self.relation_data.local_unit.is_leader(): + return + + # Check which data has changed to emit customs events. + diff = self.relation_data._diff(event) + + # Emit a topic requested event if the setup key (topic name and optional + # extra user roles) was added to the relation databag by the application. + if "topic" in diff.added: + getattr(self.on, "topic_requested").emit( + event.relation, app=event.app, unit=event.unit + ) + + +class KafkaProvides(KafkaProvidesData, KafkaProvidesEventHandlers): + """Provider-side of the Kafka relation.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + KafkaProvidesData.__init__(self, charm, relation_name) + KafkaProvidesEventHandlers.__init__(self, charm, self) + + +class KafkaRequiresData(DataRequires): + """Requires-side of the Kafka relation.""" def __init__( self, @@ -2437,11 +2763,22 @@ def topic(self, value): raise ValueError(f"Error on topic '{value}', cannot be a wildcard.") self._topic = value + +class KafkaRequiresEventHandlers(DataRequiresEventHandlers): + """Requires-side of the Kafka relation.""" + + on = KafkaRequiresEvents() # pyright: ignore [reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: KafkaRequiresData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: """Event emitted when the Kafka relation is created.""" super()._on_relation_created_event(event) - if not self.local_unit.is_leader(): + if not self.relation_data.local_unit.is_leader(): return # Sets topic, extra user roles, and "consumer-group-prefix" in the relation @@ -2450,7 +2787,7 @@ def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: for f in ["consumer-group-prefix", "extra-user-roles", "topic"] } - self.update_relation_data(event.relation.id, relation_data) + self.relation_data.update_relation_data(event.relation.id, relation_data) def _on_secret_changed_event(self, event: SecretChangedEvent): """Event notifying about a new value of a secret.""" @@ -2459,16 +2796,16 @@ def _on_secret_changed_event(self, event: SecretChangedEvent): def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the Kafka relation has changed.""" # Check which data has changed to emit customs events. - diff = self._diff(event) + diff = self.relation_data._diff(event) # Check if the topic is created # (the Kafka charm shared the credentials). # Register all new secrets with their labels - if any(newval for newval in diff.added if self._is_secret_field(newval)): - self._register_secrets_to_relation(event.relation, diff.added) + if any(newval for newval in diff.added if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, diff.added) - secret_field_user = self._generate_secret_field_name(SecretGroup.USER) + secret_field_user = self.relation_data._generate_secret_field_name(SecretGroup.USER) if ( "username" in diff.added and "password" in diff.added ) or secret_field_user in diff.added: @@ -2491,6 +2828,30 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: return +class KafkaRequires(KafkaRequiresData, KafkaRequiresEventHandlers): + """Provider-side of the Kafka relation.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + topic: str, + extra_user_roles: Optional[str] = None, + consumer_group_prefix: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ) -> None: + KafkaRequiresData.__init__( + self, + charm, + relation_name, + topic, + extra_user_roles, + consumer_group_prefix, + additional_secret_fields, + ) + KafkaRequiresEventHandlers.__init__(self, charm, self) + + # Opensearch related events @@ -2541,29 +2902,12 @@ class OpenSearchRequiresEvents(CharmEvents): # OpenSearch Provides and Requires Objects -class OpenSearchProvides(DataProvides): +class OpenSearchProvidesData(DataProvides): """Provider-side of the OpenSearch relation.""" - on = OpenSearchProvidesEvents() # pyright: ignore[reportAssignmentType] - def __init__(self, charm: CharmBase, relation_name: str) -> None: super().__init__(charm, relation_name) - def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation has changed.""" - # Leader only - if not self.local_unit.is_leader(): - return - # Check which data has changed to emit customs events. - diff = self._diff(event) - - # Emit an index requested event if the setup key (index name and optional extra user roles) - # have been added to the relation databag by the application. - if "index" in diff.added: - getattr(self.on, "index_requested").emit( - event.relation, app=event.app, unit=event.unit - ) - def set_index(self, relation_id: int, index: str) -> None: """Set the index in the application relation databag. @@ -2594,10 +2938,42 @@ def set_version(self, relation_id: int, version: str) -> None: self.update_relation_data(relation_id, {"version": version}) -class OpenSearchRequires(DataRequires): - """Requires-side of the OpenSearch relation.""" +class OpenSearchProvidesEventHandlers(DataRelationEventHandlers): + """Provider-side of the OpenSearch relation.""" - on = OpenSearchRequiresEvents() # pyright: ignore[reportAssignmentType] + on = OpenSearchProvidesEvents() # pyright: ignore[reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: OpenSearchProvidesData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + # Leader only + if not self.relation_data.local_unit.is_leader(): + return + # Check which data has changed to emit customs events. + diff = self.relation_data._diff(event) + + # Emit an index requested event if the setup key (index name and optional extra user roles) + # have been added to the relation databag by the application. + if "index" in diff.added: + getattr(self.on, "index_requested").emit( + event.relation, app=event.app, unit=event.unit + ) + + +class OpenSearchProvides(OpenSearchProvidesData, OpenSearchProvidesEventHandlers): + """Provider-side of the OpenSearch relation.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + OpenSearchProvidesData.__init__(self, charm, relation_name) + OpenSearchProvidesEventHandlers.__init__(self, charm, self) + + +class OpenSearchRequiresData(DataRequires): + """Requires data side of the OpenSearch relation.""" def __init__( self, @@ -2612,27 +2988,38 @@ def __init__( self.charm = charm self.index = index + +class OpenSearchRequiresEventHandlers(DataRequiresEventHandlers): + """Requires events side of the OpenSearch relation.""" + + on = OpenSearchRequiresEvents() # pyright: ignore[reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: OpenSearchRequiresData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: """Event emitted when the OpenSearch relation is created.""" super()._on_relation_created_event(event) - if not self.local_unit.is_leader(): + if not self.relation_data.local_unit.is_leader(): return # Sets both index and extra user roles in the relation if the roles are provided. # Otherwise, sets only the index. - data = {"index": self.index} - if self.extra_user_roles: - data["extra-user-roles"] = self.extra_user_roles + data = {"index": self.relation_data.index} + if self.relation_data.extra_user_roles: + data["extra-user-roles"] = self.relation_data.extra_user_roles - self.update_relation_data(event.relation.id, data) + self.relation_data.update_relation_data(event.relation.id, data) def _on_secret_changed_event(self, event: SecretChangedEvent): """Event notifying about a new value of a secret.""" if not event.secret.label: return - relation = self._relation_from_secret_label(event.secret.label) + relation = self.relation_data._relation_from_secret_label(event.secret.label) if not relation: logging.info( f"Received secret {event.secret.label} but couldn't parse, seems irrelevant" @@ -2658,14 +3045,14 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: This event triggers individual custom events depending on the changing relation. """ # Check which data has changed to emit customs events. - diff = self._diff(event) + diff = self.relation_data._diff(event) # Register all new secrets with their labels - if any(newval for newval in diff.added if self._is_secret_field(newval)): - self._register_secrets_to_relation(event.relation, diff.added) + if any(newval for newval in diff.added if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, diff.added) - secret_field_user = self._generate_secret_field_name(SecretGroup.USER) - secret_field_tls = self._generate_secret_field_name(SecretGroup.TLS) + secret_field_user = self.relation_data._generate_secret_field_name(SecretGroup.USER) + secret_field_tls = self.relation_data._generate_secret_field_name(SecretGroup.TLS) updates = {"username", "password", "tls", "tls-ca", secret_field_user, secret_field_tls} if len(set(diff._asdict().keys()) - updates) < len(diff): logger.info("authentication updated at: %s", datetime.now()) @@ -2695,3 +3082,25 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: event.relation, app=event.app, unit=event.unit ) # here check if this is the right design return + + +class OpenSearchRequires(OpenSearchRequiresData, OpenSearchRequiresEventHandlers): + """Requires-side of the OpenSearch relation.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + index: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ) -> None: + OpenSearchRequiresData.__init__( + self, + charm, + relation_name, + index, + extra_user_roles, + additional_secret_fields, + ) + OpenSearchRequiresEventHandlers.__init__(self, charm, self) diff --git a/tests/unit/test_data_interfaces.py b/tests/unit/test_data_interfaces.py index 1da96ef7..4ff63156 100644 --- a/tests/unit/test_data_interfaces.py +++ b/tests/unit/test_data_interfaces.py @@ -889,10 +889,12 @@ def __init__(self, *args): ) self.framework.observe(self.requirer.on.endpoints_changed, self._on_endpoints_changed) self.framework.observe( - self.requirer.on.read_only_endpoints_changed, self._on_read_only_endpoints_changed + self.requirer.on.read_only_endpoints_changed, + self._on_read_only_endpoints_changed, ) self.framework.observe( - self.requirer.on.cluster1_database_created, self._on_cluster1_database_created + self.requirer.on.cluster1_database_created, + self._on_cluster1_database_created, ) def log_relation_size(self, prefix=""):