diff --git a/pyproject.toml b/pyproject.toml index fc2e25b..11c8a24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "cosl" -version = "0.0.49" +version = "0.0.50" authors = [ { name = "sed-i", email = "82407168+sed-i@users.noreply.github.com" }, ] diff --git a/src/cosl/__init__.py b/src/cosl/__init__.py index b756575..dc95203 100644 --- a/src/cosl/__init__.py +++ b/src/cosl/__init__.py @@ -4,7 +4,7 @@ """Utils for observability Juju charms.""" from .cos_tool import CosTool -from .grafana_dashboard import GrafanaDashboard +from .grafana_dashboard import DashboardPath40UID, GrafanaDashboard, LZMABase64 from .juju_topology import JujuTopology from .mandatory_relation_pairs import MandatoryRelationPairs from .rules import AlertRules, RecordingRules @@ -13,6 +13,8 @@ "JujuTopology", "CosTool", "GrafanaDashboard", + "LZMABase64", + "DashboardPath40UID", "AlertRules", "RecordingRules", "MandatoryRelationPairs", diff --git a/src/cosl/grafana_dashboard.py b/src/cosl/grafana_dashboard.py index 09aa730..145ccc0 100644 --- a/src/cosl/grafana_dashboard.py +++ b/src/cosl/grafana_dashboard.py @@ -4,10 +4,13 @@ """Grafana Dashboard.""" import base64 +import binascii +import hashlib import json import logging import lzma -from typing import Any, Dict, Union +import warnings +from typing import Any, ClassVar, Dict, Tuple, Union logger = logging.getLogger(__name__) @@ -21,15 +24,19 @@ class GrafanaDashboard(str): @staticmethod def _serialize(raw_json: Union[str, bytes]) -> "GrafanaDashboard": - if not isinstance(raw_json, bytes): - raw_json = raw_json.encode("utf-8") - encoded = base64.b64encode(lzma.compress(raw_json)).decode("utf-8") - return GrafanaDashboard(encoded) + warnings.warn( + "GrafanaDashboard._serialize is deprecated; use LZMABase64.compress(json.dumps(...)) instead.", + category=DeprecationWarning, + ) + return GrafanaDashboard(LZMABase64.compress(raw_json)) def _deserialize(self) -> Dict[str, Any]: + warnings.warn( + "GrafanaDashboard._deserialize is deprecated; use json.loads(LZMABase64.decompress(...)) instead.", + category=DeprecationWarning, + ) try: - raw = lzma.decompress(base64.b64decode(self.encode("utf-8"))).decode() - return json.loads(raw) + return json.loads(LZMABase64.decompress(self)) except json.decoder.JSONDecodeError as e: logger.error("Invalid Dashboard format: %s", e) return {} @@ -37,3 +44,74 @@ def _deserialize(self) -> Dict[str, Any]: def __repr__(self): """Return string representation of self.""" return "" + + +class LZMABase64: + """A helper class for LZMA-compressed-base64-encoded strings. + + This is useful for transferring over juju relation data, which can only have keys of type string. + """ + + @classmethod + def compress(cls, raw_json: Union[str, bytes]) -> str: + """LZMA-compress and base64-encode into a string.""" + if not isinstance(raw_json, bytes): + raw_json = raw_json.encode("utf-8") + return base64.b64encode(lzma.compress(raw_json)).decode("utf-8") + + @classmethod + def decompress(cls, compressed: str) -> str: + """Decompress from base64-encoded-lzma-compressed string.""" + return lzma.decompress(base64.b64decode(compressed.encode("utf-8"))).decode() + + +class DashboardPath40UID: + """A helper class for dashboard UID of length 40, generated from charm name and dashboard path.""" + + length: ClassVar[int] = 40 + + @classmethod + def _hash(cls, components: Tuple[str, ...], length: int) -> str: + return hashlib.shake_256("-".join(components).encode("utf-8")).hexdigest(length) + + @classmethod + def generate(cls, charm_name: str, dashboard_path: str) -> str: + """Generate a dashboard uid from charm name and dashboard path. + + The combination of charm name and dashboard path (relative to the charm root) is guaranteed to be unique across + the ecosystem. By design, this intentionally does not take into account instances of the same charm with + different charm revisions, which could have different dashboard versions. + Ref: https://github.com/canonical/observability/pull/206 + + The max length grafana allows for a dashboard uid is 40. + Ref: https://grafana.com/docs/grafana/latest/developers/http_api/dashboard/#identifier-id-vs-unique-identifier-uid + + Args: + charm_name: The name of the charm (not app!) that owns the dashboard. + dashboard_path: Path (relative to charm root) to the dashboard file. + + Returns: A uid based on the input args. + """ + # Since the digest is bytes, we need to convert it to a charset that grafana accepts. + # Let's use hexdigest, which means 2 chars per byte, reducing our effective digest size to 20. + return cls._hash((charm_name, dashboard_path), cls.length // 2) + + @classmethod + def is_valid(cls, uid: str) -> bool: + """Check if the given UID is a valid "Path-40" UID. + + The UID must be of a particular length, and since we generate it with hexdigest() then we also know it must + unhexlify. + + This is not a bullet-proof check, because it's plausible that some dashboard would have a 40-length hexstring as + its uid, but given the current state of the ecosystem, it's quite unlikely. + """ + if not uid: + return False + if len(uid) != cls.length: + return False + try: + binascii.unhexlify(uid) + except binascii.Error: + return False + return True diff --git a/tests/test_grafana_dashboard.py b/tests/test_grafana_dashboard.py index a4f70a5..961651c 100644 --- a/tests/test_grafana_dashboard.py +++ b/tests/test_grafana_dashboard.py @@ -4,15 +4,56 @@ import json import unittest -from cosl import GrafanaDashboard +from cosl import DashboardPath40UID, GrafanaDashboard, LZMABase64 -class TestDashboard(unittest.TestCase): - """Tests the GrafanaDashboard class.""" +class TestRoundTripEncDec(unittest.TestCase): + """Tests the round-trip encoding/decoding of the GrafanaDashboard class.""" - def test_serializes_and_deserializes(self): - expected_output = {"msg": "this is the expected output after passing through the class."} + def test_round_trip(self): + d = { + "some": "dict", + "with": "keys", + "even": [{"nested": "types", "and_integers": [42, 42]}], + } + self.assertDictEqual(d, GrafanaDashboard._serialize(json.dumps(d))._deserialize()) - dash = GrafanaDashboard._serialize(json.dumps(expected_output)) - assert dash._deserialize() == expected_output +class TestLZMABase64(unittest.TestCase): + """Tests the round-trip encoding/decoding of the GrafanaDashboard class.""" + + def test_round_trip(self): + s = "starting point" + self.assertEqual(s, LZMABase64.decompress(LZMABase64.compress(s))) + + +class TestGenerateUID(unittest.TestCase): + """Spec for the UID generation logic.""" + + def test_uid_length_is_40(self): + self.assertEqual(40, len(DashboardPath40UID.generate("my-charm", "my-dash.json"))) + + def test_collisions(self): + """A very naive and primitive collision check that is meant to catch trivial errors.""" + self.assertNotEqual( + DashboardPath40UID.generate("some-charm", "dashboard1.json"), + DashboardPath40UID.generate("some-charm", "dashboard2.json"), + ) + + self.assertNotEqual( + DashboardPath40UID.generate("some-charm", "dashboard.json"), + DashboardPath40UID.generate("diff-charm", "dashboard.json"), + ) + + def test_validity(self): + """Make sure validity check fails for trivial cases.""" + self.assertFalse(DashboardPath40UID.is_valid("1234")) + self.assertFalse(DashboardPath40UID.is_valid("short non-hex string")) + self.assertFalse(DashboardPath40UID.is_valid("non-hex string, crafted to be 40 chars!!")) + + self.assertTrue(DashboardPath40UID.is_valid("0" * 40)) + self.assertTrue( + DashboardPath40UID.is_valid( + DashboardPath40UID.generate("some-charm", "dashboard.json") + ) + )