Skip to content

Commit

Permalink
Merge pull request #96 from neuro-ml/dev
Browse files (browse the repository at this point in the history)
caching the mapping in groupby; add CheckIds layer
  • Loading branch information
STNLd2 authored Jul 26, 2023
2 parents 4084679 + f57bdf4 commit bc8cb67
Show file tree
Hide file tree
Showing 13 changed files with 124 additions and 175 deletions.
2 changes: 1 addition & 1 deletion connectome/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.7.1'
__version__ = '0.8.0'
3 changes: 1 addition & 2 deletions connectome/containers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def __init__(self, inputs: Nodes, outputs: Nodes, edges: BoundEdges, context: Op
self.persistent: NameSet = persistent
self.optional: NodeSet = optional
self.context = context
self.backend = None

def freeze(self, parent: Union[Details, None] = None) -> 'EdgesBag':
"""
Expand Down Expand Up @@ -93,7 +92,7 @@ def freeze(self, parent: Union[Details, None] = None) -> 'EdgesBag':

def compile(self) -> GraphCompiler:
return GraphCompiler(
self.inputs, self.outputs, self.edges, self.virtual, self.optional, self.backend
self.inputs, self.outputs, self.edges, self.virtual, self.optional
)

def loopback(self, func: Callable, inputs: StringsLike, output: StringsLike) -> 'EdgesBag':
Expand Down
1 change: 0 additions & 1 deletion connectome/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .base import *
from .compiler import GraphCompiler
from .edges import *
from .executor import DefaultExecutor, SyncExecutor
from .graph import Graph
from .node_hash import ApplyHash, CustomHash, FilterHash, GraphHash, LeafHash, NodeHash, NodeHashes, TupleHash
2 changes: 1 addition & 1 deletion connectome/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class Command(Enum):
ParentHash, CurrentHash, ParentValue, Payload, Await, Call = range(6)
Send, Store, Item, ComputeHash, Evaluate, AwaitFuture, Tuple, Return = range(-8, 0)
Send, Store, Item, ComputeHash, Evaluate, Tuple, Return = range(-7, 0)


HashOutput = Tuple[NodeHash, Any]
Expand Down
9 changes: 3 additions & 6 deletions connectome/engine/compiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from concurrent.futures import Executor
from typing import Tuple, Union

from ..exceptions import DependencyError, FieldError
Expand All @@ -9,8 +8,7 @@


class GraphCompiler:
def __init__(self, inputs: Nodes, outputs: Nodes, edges: BoundEdges, virtuals: NameSet, optionals: Nodes,
executor: Executor):
def __init__(self, inputs: Nodes, outputs: Nodes, edges: BoundEdges, virtuals: NameSet, optionals: Nodes):
check_for_duplicates(inputs)
check_for_duplicates(outputs)
self._mapping = TreeNode.from_edges(edges)
Expand All @@ -24,7 +22,6 @@ def __init__(self, inputs: Nodes, outputs: Nodes, edges: BoundEdges, virtuals: N
# some optional nodes might be unreachable
self._optionals = {self._mapping[x] for x in optionals if x in self._mapping}
self._virtuals = virtuals
self._executor = executor

self._cache = {}
self._dependencies = self._outputs = None
Expand Down Expand Up @@ -71,7 +68,7 @@ def get_node(out):
# TODO: signature
return identity

return Graph(self._inputs, node, self._executor)
return Graph(self._inputs, node)

inputs, outputs = [], []
for name in item:
Expand All @@ -83,7 +80,7 @@ def get_node(out):
outputs.append(node)

product = TreeNode(f'({", ".join(item)})', (ProductEdge(len(item)), outputs), None)
return Graph(self._inputs | set(inputs), product, self._executor)
return Graph(self._inputs | set(inputs), product)

def __getitem__(self, item):
# TODO: deprecate
Expand Down
81 changes: 0 additions & 81 deletions connectome/engine/executor.py

This file was deleted.

19 changes: 8 additions & 11 deletions connectome/engine/graph.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import inspect
from collections import defaultdict
from concurrent.futures import Executor
from typing import Any, Sequence

from .base import Command, TreeNode, TreeNodes
from .executor import DefaultExecutor
from .node_hash import GraphHash, LeafHash, NodeHash
from .utils import EvictionCache
from .vm import execute


class Graph:
def __init__(self, inputs: TreeNodes, output: TreeNode, executor: Executor = None):
def __init__(self, inputs: TreeNodes, output: TreeNode):
validate_graph(inputs, output)
# TODO: need a cumulative eviction policy
counts = count_entries(inputs, output, multiplier=2)
Expand All @@ -23,7 +21,6 @@ def __init__(self, inputs: TreeNodes, output: TreeNode, executor: Executor = Non
self.inputs = inputs
self.output = output
self.counts = counts
self.executor = DefaultExecutor if executor is None else executor
self.__signature__ = signature
# TODO: deprecate
self.call = self.__call__
Expand All @@ -32,7 +29,7 @@ def __call__(*args, **kwargs):
self, *args = args
scope = self.__signature__.bind(*args, **kwargs)
hashes, cache = self._prepare_cache(scope.arguments)
return evaluate(self.output, hashes, cache, self.executor)
return evaluate(self.output, hashes, cache)

def __str__(self):
inputs = ', '.join(x.name for x in self.inputs)
Expand All @@ -54,22 +51,22 @@ def get_hash(self, *inputs: Any):
assert all(not isinstance(v, NodeHash) for v in inputs)

hashes, cache = self._prepare_cache({n.name: v for n, v in zip(self.inputs, inputs)})
result, _ = compute_hash(self.output, hashes, cache, self.executor)
result, _ = compute_hash(self.output, hashes, cache)
return result, (hashes, cache)

def get_value(self, hashes, cache) -> Any:
return evaluate(self.output, hashes, cache, self.executor)
return evaluate(self.output, hashes, cache)

def hash(self) -> GraphHash:
return GraphHash(hash_graph(self.inputs, self.output))


def evaluate(node: TreeNode, hashes: EvictionCache, cache: EvictionCache, executor: Executor):
return execute(Command.Evaluate, node, hashes, cache, executor)
def evaluate(node: TreeNode, hashes: EvictionCache, cache: EvictionCache):
return execute(Command.Evaluate, node, hashes, cache)


def compute_hash(node: TreeNode, hashes: EvictionCache, cache: EvictionCache, executor: Executor):
return execute(Command.ComputeHash, node, hashes, cache, executor)
def compute_hash(node: TreeNode, hashes: EvictionCache, cache: EvictionCache):
return execute(Command.ComputeHash, node, hashes, cache)


def validate_graph(inputs: TreeNodes, output: TreeNode):
Expand Down
84 changes: 19 additions & 65 deletions connectome/engine/vm.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,21 @@
from concurrent.futures import Executor
from collections import deque

from .base import Command
from .executor import AsyncLoop, Frame


# TODO: replace cache by a thunk tree
def execute(cmd, node, hashes, cache, executor: Executor):
root = Frame([node], [(Command.Return,), (cmd,)])
loop = AsyncLoop(root)
push, pop, peek = loop.push, loop.pop, loop.peek
push_command, pop_command = loop.push_command, loop.pop_command
next_frame, enqueue_frame = loop.next_frame, loop.enqueue_frame
dispose_frame = loop.dispose_frame
def execute(cmd, node, hashes, cache):
stack, commands = deque([node]), deque([(Command.Return,), (cmd,)])
push, pop, peek = stack.append, stack.pop, lambda: stack[-1]
push_command, pop_command = commands.append, commands.pop

while True:
assert not loop.frame.ready
cmd, *args = pop_command()

# return
if cmd == Command.Return:
assert not args
assert len(loop.frame.stack) == 1, len(loop.frame.stack)
value = pop()
if loop.frame is root:
loop.clear()
return value

loop.frame.value = value
loop.frame.ready = True
dispose_frame()
assert len(stack) == 1, len(stack)
return pop()

# communicate with edges
elif cmd == Command.Send:
Expand Down Expand Up @@ -56,18 +43,8 @@ def execute(cmd, node, hashes, cache, executor: Executor):
assert not args
node = pop()
if node in hashes:
value = hashes[node]
if value is _CACHE_SENTINEL:
# restore state
push_command((cmd, *args))
push(node)
# switch context
next_frame()

else:
push(value)
push(hashes[node])
else:
hashes[node] = _CACHE_SENTINEL
push_command((Command.Store, hashes, node))
push_command((Command.Send, node, node.edge.compute_hash()))
push(None)
Expand All @@ -76,20 +53,9 @@ def execute(cmd, node, hashes, cache, executor: Executor):
elif cmd == Command.Evaluate:
assert not args
node = pop()

if node in cache:
value = cache[node]
if value is _CACHE_SENTINEL:
# restore state
push_command((cmd, *args))
push(node)
# switch context
next_frame()

else:
push(value)
push(cache[node])
else:
cache[node] = _CACHE_SENTINEL
push_command((Command.Store, cache, node))
push_command((Command.Send, node, node.edge.evaluate()))
push(None)
Expand Down Expand Up @@ -122,44 +88,32 @@ def execute(cmd, node, hashes, cache, executor: Executor):

elif cmd == Command.Await:
node = pop()
push_command((Command.Tuple, len(args)))
for arg in args:
local = Frame([node], [(Command.Return,), arg])
push_command((Command.AwaitFuture, local))
enqueue_frame(local)
push_command((Command.Tuple, node, len(args), list(args)))

elif cmd == Command.Call:
pop() # pop the node
func, pos, kw = args
push_command((Command.AwaitFuture, executor.submit(func, *pos, **kw)))
next_frame()
push(func(*pos, **kw))

# utils
elif cmd == Command.Store:
storage, key = args
sentinel = storage[key]
assert sentinel is _CACHE_SENTINEL, sentinel

assert key not in storage
storage[key] = peek()

elif cmd == Command.Item:
key, = args
push(pop()[key])

elif cmd == Command.Tuple:
n, = args
push(tuple(pop() for _ in range(n)))

elif cmd == Command.AwaitFuture:
child, = args
if child.done():
push(child.result())
node, n, requests = args
if not requests:
push(tuple(pop() for _ in range(n)))
else:
push_command((cmd, *args))
next_frame()
request = requests.pop()
push_command((Command.Tuple, node, n, requests))
push_command(request)
push(node)

else:
raise RuntimeError('Unknown command', cmd) # pragma: no cover


_CACHE_SENTINEL = object()
3 changes: 2 additions & 1 deletion connectome/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .apply import Apply
from .base import CallableLayer, Chain, Layer, LazyChain, chained
from .cache import CacheLayer, CacheToDisk, CacheToRam
from .check_ids import CheckIds
from .columns import CacheColumns
from .debug import HashDigest
from .filter import Filter
from .group import GroupBy
from .join import Join, JoinMode
from .merge import Merge
from .group import GroupBy
Loading

0 comments on commit bc8cb67

Please sign in to comment.