Skip to content

Commit

Permalink
Build expression trees in PointerSymbol
Browse files Browse the repository at this point in the history
This enables correct handling of precedence, e.g.

s = sym.A + sym.B
s * sym.C

will be correctly handled as (sym.A + sym.B) * sym.C instead of
sym.A + (sym.B * sym.C).

Fixes #267.
  • Loading branch information
jgosmann committed Sep 30, 2020
1 parent 2ea7a72 commit 86a98d8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
29 changes: 18 additions & 11 deletions nengo_spa/ast/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import nengo
import numpy as np

from nengo_spa.ast import expr_tree
from nengo_spa.ast.base import infer_types, Fixed, TypeCheckedBinaryOp
from nengo_spa.exceptions import SpaTypeError
from nengo_spa.semantic_pointer import SemanticPointer
Expand Down Expand Up @@ -51,14 +52,20 @@ def evaluate(self):
def expr(self):
return repr(self.value)

@property
def _expr_tree(self):
return expr_tree.Leaf(self.expr)

def __neg__(self):
return FixedScalar(-self.value)


class PointerSymbol(Symbol):
def __init__(self, expr, type_=TAnyVocab):
super(PointerSymbol, self).__init__(type_=type_)
self._expr = expr
self._expr_tree = (
expr if isinstance(expr, expr_tree.Node) else expr_tree.Leaf(expr)
)

def connect_to(self, sink, **kwargs):
return nengo.Connection(self.construct(), sink, **kwargs)
Expand All @@ -76,20 +83,20 @@ def evaluate(self):

@property
def expr(self):
return self._expr
return str(self._expr_tree)

def __invert__(self):
return PointerSymbol("~" + self.expr, self.type)
return PointerSymbol(~self._expr_tree, self.type)

def __neg__(self):
return PointerSymbol("-" + self.expr, self.type)
return PointerSymbol(-self._expr_tree, self.type)

def __add__(self, other):
other = as_symbolic_node(other)
if not isinstance(other, PointerSymbol):
return NotImplemented
type_ = infer_types(self, other)
return PointerSymbol(self.expr + "+" + other.expr, type_)
return PointerSymbol(self._expr_tree + other._expr_tree, type_)

def __radd__(self, other):
return self + other
Expand All @@ -99,25 +106,25 @@ def __sub__(self, other):
if not isinstance(other, PointerSymbol):
return NotImplemented
type_ = infer_types(self, other)
return PointerSymbol(self.expr + "-" + other.expr, type_)
return PointerSymbol(self._expr_tree - other._expr_tree, type_)

def __rsub__(self, other):
return (-self) + other

@symbolic_op
def __mul__(self, other):
type_ = infer_types(self, other)
return PointerSymbol(self.expr + "*" + other.expr, type_)
return PointerSymbol(self._expr_tree * other._expr_tree, type_)

@symbolic_op
def __rmul__(self, other):
type_ = infer_types(self, other)
return PointerSymbol(other.expr + "*" + self.expr, type_)
return PointerSymbol(other._expr_tree * self._expr_tree, type_)

@symbolic_op
def __truediv__(self, other):
type_ = infer_types(self, other)
return PointerSymbol(self.expr + "/" + other.expr, type_)
return PointerSymbol(self._expr_tree / other._expr_tree, type_)

def dot(self, other):
other = as_symbolic_node(other)
Expand Down Expand Up @@ -145,7 +152,7 @@ def translate(self, vocab, populate=None, keys=None, solver=None):
return SemanticPointer(np.dot(tr, self.evaluate().v), vocab=vocab)

def __repr__(self):
return "PointerSymbol({!r}, {!r})".format(self.expr, self.type)
return "PointerSymbol({!r}, {!r})".format(self._expr_tree, self.type)


class PointerSymbolFactory:
Expand All @@ -165,7 +172,7 @@ def __getattribute__(self, key):
return PointerSymbol(key)

def __call__(self, expr):
return PointerSymbol(re.sub(r"\s+", "", expr))
return PointerSymbol("({})".format(re.sub(r"\s+", "", expr)))


sym = PointerSymbolFactory()
15 changes: 12 additions & 3 deletions nengo_spa/ast/tests/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import nengo_spa as spa
from nengo_spa import sym
from nengo_spa.ast import expr_tree
from nengo_spa.ast.symbolic import FixedScalar, PointerSymbol
from nengo_spa.types import TVocabulary

Expand Down Expand Up @@ -153,7 +154,15 @@ def test_pointer_symbol_factory():
assert ps.expr == "A"


def test_pointer_symbol_factory_expressions():
ps = sym("A + B * C")
@pytest.mark.parametrize(
"ps,expected",
[
(sym("A + B * C"), "(A+B*C)"),
(sym.A + sym.B * sym.C, "A + B * C"),
(sym("(A + B) * C"), "((A+B)*C)"),
((sym.A + sym.B) * sym.C, "(A + B) * C"),
],
)
def test_pointer_symbol_factory_expressions(ps, expected):
assert isinstance(ps, PointerSymbol)
assert ps.expr == (sym.A + sym.B * sym.C).expr
assert ps.expr == expected

0 comments on commit 86a98d8

Please sign in to comment.