Skip to content

Commit

Permalink
added tests for split
Browse files Browse the repository at this point in the history
  • Loading branch information
maxme1 committed Jun 27, 2023
1 parent 8a91010 commit 9713985
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 48 deletions.
1 change: 1 addition & 0 deletions connectome/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .edges import *
from .metaclasses import Mixin, Source, Transform
from .nodes import Input, InverseInput, InverseOutput, Output, Parameter
from .split import Split
55 changes: 55 additions & 0 deletions connectome/interface/split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Union, Collection, Callable, Iterable

from ..containers import EdgesBag
from ..engine import IdentityEdge
from ..interface.factory import TransformFactory
from ..interface.metaclasses import APIMeta
from ..layers.split import SplitBase


class SplitFactory(TransformFactory):
_part_name = '__part__'
_split_name = '__split__'
layer_cls = SplitBase

def __init__(self, layer: str, scope):
self._split: Callable = None
super().__init__(layer, scope)

def _prepare_layer_arguments(self, container: EdgesBag, properties: Iterable[str]):
assert not properties, properties
return self._split, container

def _before_collect(self):
super()._before_collect()
self.edges.append(IdentityEdge().bind(self.inputs[self._part_name], self.parameters[self._part_name]))
self.magic_dispatch[self._split_name] = self._handle_split

def _handle_split(self, value):
assert self._split is None, self._split
self._split = value

def _after_collect(self):
super()._after_collect()
assert self._split is not None
assert not self.special_methods, self.special_methods


# TODO: Examples
class Split(SplitBase, metaclass=APIMeta, __factory=SplitFactory):
"""
Split a dataset entries into several parts.
This layer requires a `__split__` magic method, which takes an entry id, and returns a list of parts -
(part_id, part_context) pairs, "part_id" will become the part's id, and "part_context" is accessible in other
methods as part-specific useful info.
"""

__inherit__: Union[str, Collection[str], bool] = ()
__exclude__: Union[str, Collection[str]] = ()

def __init__(self, *args, **kwargs):
raise NotImplementedError

def __split__(*args, **kwargs):
raise NotImplementedError
57 changes: 9 additions & 48 deletions connectome/layers/split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Collection, Callable, Iterable, Generator, Any
from typing import Callable, Generator, Any

from .base import Layer
from .chain import connect
Expand All @@ -9,8 +9,6 @@
StaticHash, IdentityEdge, FunctionEdge, HashBarrier, Node, CacheEdge, Details, TreeNode
)
from ..exceptions import DependencyError
from ..interface.factory import TransformFactory
from ..interface.metaclasses import APIMeta
from ..utils import extract_signature, AntiSet, node_to_dict


Expand Down Expand Up @@ -94,51 +92,6 @@ def _connect(self, previous: EdgesBag) -> EdgesBag:
)


class SplitFactory(TransformFactory):
_part_name = '__part__'
_split_name = '__split__'
layer_cls = SplitBase

def __init__(self, layer: str, scope):
self._split: Callable = None
super().__init__(layer, scope)

def _prepare_layer_arguments(self, container: EdgesBag, properties: Iterable[str]):
assert not properties, properties
return self._split, container

def _before_collect(self):
super()._before_collect()
self.edges.append(IdentityEdge().bind(self.inputs[self._part_name], self.parameters[self._part_name]))
self.magic_dispatch[self._split_name] = self._handle_split

def _handle_split(self, value):
assert self._split is None, self._split
self._split = value

def _after_collect(self):
super()._after_collect()
assert self._split is not None
assert not self.special_methods, self.special_methods


# TODO: Examples
class Split(SplitBase, metaclass=APIMeta, __factory=SplitFactory):
"""
Split a dataset entries into several parts.
This layer requires a `__split__` magic method, which takes an entry id, and returns a list of parts -
(part_id, part_context) pairs, "part_id" will become the part's id, and "part_context" is accessible in other
methods as part-specific useful info.
"""

__inherit__: Union[str, Collection[str], bool] = ()
__exclude__: Union[str, Collection[str]] = ()

def __split__(*args, **kwargs):
raise NotImplementedError


class SplitMapping(StaticGraph, StaticHash):
def __init__(self, graph: Graph):
super().__init__(arity=1)
Expand All @@ -157,3 +110,11 @@ def evaluate(self) -> Generator[Request, Response, Any]:
mapping[new] = key, part

return mapping


def chain_edges(inputs, output, *edges):
tmp = inputs
for idx, edge in enumerate(edges, 1):
inputs = tmp
tmp = output if idx == len(edges) else Node('$aux')
yield edge.bind(inputs, tmp)
56 changes: 56 additions & 0 deletions tests/test_layers/test_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest

from connectome import Source, Split, meta


class A(Source):
@meta
def ids():
return tuple('012')

def x(i):
return [f'x-{i}-{j}' for j in range(int(i) + 1)]

def y(i):
return [f'y-{i}-{j}' for j in range(int(i) + 1)]


def test_split():
class SplitList(Split):
def __split__(id, x):
for idx, entry in enumerate(x):
yield f'{id}-{idx}', idx

def x(x, __part__):
return x[__part__]

def y(y, __part__):
return y[__part__]

a = A()
ds = a >> SplitList()
assert ds.ids == ('0-0', '1-0', '1-1', '2-0', '2-1', '2-2')
assert [ds.x('0-0')] == a.x('0')
assert [ds.x('1-0'), ds.x('1-1')] == a.x('1')


@pytest.mark.xfail
def test_split_with_args():
class SplitListWithArgs(Split):
_separator: str

def __split__(id, x, _separator):
for idx, entry in enumerate(x):
yield f'{id}{_separator}{idx}', idx

def x(x, __part__):
return x[__part__]

def y(y, __part__):
return y[__part__]

a = A()
ds = a >> SplitListWithArgs('@')
assert ds.ids == ('0@0', '1@0', '1@1', '2@0', '2@1', '2@2')
assert [ds.x('0@0')] == a.x('0')
assert [ds.x('1@0'), ds.x('1@1')] == a.x('1')

0 comments on commit 9713985

Please sign in to comment.