From f26beb2951ec5e2b7a45e17943dc21f85fc7c6e7 Mon Sep 17 00:00:00 2001 From: Daniel Obraczka Date: Mon, 25 Mar 2024 15:39:05 +0100 Subject: [PATCH] Add left/right for non-multi case in MGB (#36) --- sylloge/moviegraph_benchmark_loader.py | 13 +++++++++++++ tests/test_moviebenchmark.py | 3 +++ 2 files changed, 16 insertions(+) diff --git a/sylloge/moviegraph_benchmark_loader.py b/sylloge/moviegraph_benchmark_loader.py index a3abb1b..c898130 100644 --- a/sylloge/moviegraph_benchmark_loader.py +++ b/sylloge/moviegraph_benchmark_loader.py @@ -88,6 +88,11 @@ def __init__( dataset_names=ds_names, ds_prefix_tuples=GP_TO_DS_PREFIX[graph_pair], ) + if graph_pair != MULTI: + self.rel_triples_left = self.rel_triples[0] + self.rel_triples_right = self.rel_triples[1] + self.attr_triples_left = self.attr_triples[0] + self.attr_triples_right = self.attr_triples[1] def initial_read(self, backend: BACKEND_LITERAL): assert self._ds_prefixes @@ -130,6 +135,14 @@ def initial_read(self, backend: BACKEND_LITERAL): "folds": folds, } + def __repr__(self) -> str: + repr_str = super().__repr__() + if self.graph_pair != MULTI: + return repr_str.replace("triples_0", "triples_left").replace( + "triples_1", "triples_right" + ) + return repr_str + @property def _canonical_name(self) -> str: return f"{self.__class__.__name__}_{self.graph_pair}" diff --git a/tests/test_moviebenchmark.py b/tests/test_moviebenchmark.py index 6b7926f..d406677 100644 --- a/tests/test_moviebenchmark.py +++ b/tests/test_moviebenchmark.py @@ -131,3 +131,6 @@ def test_movie_benchmark_mock( assert fold.test is not None assert fold.val is not None assert ds._ds_prefixes is not None + if ds.graph_pair != MULTI: + assert "triples_left" in ds.__repr__() + assert ds.rel_triples_left is not None