From b7083ef05419ee9910f8f8c52b6c90b5c412941f Mon Sep 17 00:00:00 2001 From: Daniel Himmelstein Date: Wed, 12 Jul 2023 13:35:43 -0400 Subject: [PATCH] allow custom Node_Info subclasses merges https://github.com/related-sciences/nxontology/pull/26 Override NXOntology._get_node_info_cls to set a custom Node_Info subclass. --- nxontology/ontology.py | 20 +++++++++++++++++--- nxontology/similarity.py | 10 +++++----- nxontology/tests/ontology_test.py | 27 +++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 8 deletions(-) diff --git a/nxontology/ontology.py b/nxontology/ontology.py index 4a9b58b..e44ecdc 100644 --- a/nxontology/ontology.py +++ b/nxontology/ontology.py @@ -29,7 +29,10 @@ class NXOntology(Freezable, Generic[Node]): Edges should go from general to more specific. """ - def __init__(self, graph: nx.DiGraph | None = None): + def __init__( + self, + graph: nx.DiGraph | None = None, + ): self.graph = nx.DiGraph(graph) if graph is None: # Store the nxontology version that created the graph as metadata, @@ -209,15 +212,26 @@ def compute_similarities( metrics = self.similarity_metrics(node_0, node_1, ic_metric=ic_metric) yield metrics + @classmethod + def _get_node_info_cls(cls) -> type[Node_Info[Node]]: + """ + Return the Node_Info class to use for this ontology. + Subclasses can override this to use a custom Node_Info class. + For the complexity of typing this method, see + . + """ + return Node_Info + def node_info(self, node: Node) -> Node_Info[Node]: """ Return Node_Info instance for `node`. If frozen, cache node info in `self._node_info_cache`. """ + node_info_cls = self._get_node_info_cls() if not self.frozen: - return Node_Info(self, node) + return node_info_cls(self, node) if node not in self._node_info_cache: - self._node_info_cache[node] = Node_Info(self, node) + self._node_info_cache[node] = node_info_cls(self, node) return self._node_info_cache[node] @cache_on_frozen diff --git a/nxontology/similarity.py b/nxontology/similarity.py index 48c2298..5fbe596 100644 --- a/nxontology/similarity.py +++ b/nxontology/similarity.py @@ -125,17 +125,17 @@ class SimilarityIC(Similarity[Node]): def __init__( self, - graph: NXOntology[Node], + nxo: NXOntology[Node], node_0: Node, node_1: Node, ic_metric: str = "intrinsic_ic_sanchez", ): - super().__init__(graph, node_0, node_1) - - if ic_metric not in Node_Info.ic_metrics: + super().__init__(nxo, node_0, node_1) + ic_metrics = nxo._get_node_info_cls().ic_metrics + if ic_metric not in ic_metrics: raise ValueError( f"{ic_metric!r} is not a supported ic_metric. " - f"Choose from: {', '.join(Node_Info.ic_metrics)}." + f"Choose from: {', '.join(ic_metrics)}." ) self.ic_metric = ic_metric self.ic_metric_scaled = f"{ic_metric}_scaled" diff --git a/nxontology/tests/ontology_test.py b/nxontology/tests/ontology_test.py index 3a67f75..2f26767 100644 --- a/nxontology/tests/ontology_test.py +++ b/nxontology/tests/ontology_test.py @@ -1,10 +1,12 @@ import pathlib from datetime import date +from typing import Type import networkx import pytest from nxontology.exceptions import DuplicateError, NodeNotFound +from nxontology.node import Node_Info from nxontology.ontology import NXOntology @@ -151,3 +153,28 @@ def test_node_info_by_name() -> None: def test_node_info_not_found(metal_nxo_frozen: NXOntology[str]) -> None: with pytest.raises(NodeNotFound, match="not-a-metal not in graph"): metal_nxo_frozen.node_info("not-a-metal") + + +def test_custom_node_info_class() -> None: + class CustomNodeInfo(Node_Info[str]): + @property + def custom_property(self) -> str: + return "custom" + + class CustomNxontology(NXOntology[str]): + @classmethod + def _get_node_info_cls(cls) -> Type[CustomNodeInfo]: + return CustomNodeInfo + + def node_info(self, node: str) -> CustomNodeInfo: + info = super().node_info(node) + assert isinstance(info, CustomNodeInfo) + return info + + nxo = CustomNxontology() + nxo.add_node("a", name="a_name") + assert nxo.node_info("a").custom_property == "custom" + assert nxo.node_info_by_name("a_name").custom_property == "custom" # type: ignore [attr-defined] + similarity = nxo.similarity("a", "a") + assert similarity.info_0.custom_property == "custom" # type: ignore [attr-defined] + assert similarity.info_1.custom_property == "custom" # type: ignore [attr-defined]