Skip to content

Commit

Permalink
Add solver argument to translate in action rules.
Browse files Browse the repository at this point in the history
Addresses #57.
  • Loading branch information
jgosmann committed Aug 20, 2017
1 parent 27e2118 commit 60f0eba
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ Release History
Nengo networks.
(`#65 <https://github.com/nengo/nengo_spa/pull/65>`_,
`#26 <https://github.com/nengo/nengo_spa/issues/26>`_)
- Add a ``solver`` argument to the action rule's ``translate`` to use a solver
instead of an outer product to obtain the transformation matrix which can
give slightly better results.
(`#56 <https://github.com/nengo/nengo_spa/pull/56>`_,
`#57 <https://github.com/nengo/nengo_spa/issues/57>`_)

**Changed**

Expand Down
8 changes: 5 additions & 3 deletions nengo_spa/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
Negative: '-' Source
DotProduct: 'dot(' Source ',' Source ')'
Reinterpret: 'reinterpret(' Source (',' VocabArg)? ')'
Translate: 'translate(' Source (',' VocabArg)? ')'
Translate: 'translate(' Source (',' TranslateArg)* ')'
TranslateArg : <valid Python identifier> '=' <valid Python argument> | VocabArg
VocabArg: 'vocab='? (Module | <valid Python identifier>)
Sink: <valid Python identifier> | <valid Python identifier> '.' Sink
Effect: Sink '=' Source
Expand Down Expand Up @@ -899,12 +900,13 @@ def __str__(self):


class Translate(Source):
def __init__(self, source, vocab=None, populate=None):
def __init__(self, source, vocab=None, populate=None, solver=None):
source = ensure_node(source)
super(Translate, self).__init__(staticity=source.staticity)
self.source = source
self.vocab = vocab
self.populate = populate
self.solver = solver

def infer_types(self, root_network, context_type):
if self.vocab is None:
Expand All @@ -926,7 +928,7 @@ def infer_types(self, root_network, context_type):

def construct(self, context):
tr = self.source.type.vocab.transform_to(
self.type.vocab, populate=self.populate)
self.type.vocab, populate=self.populate, solver=self.solver)
artifacts = self.source.construct(context)
return [a.add_transform(tr) for a in artifacts]

Expand Down
3 changes: 2 additions & 1 deletion nengo_spa/tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def test_no_magic_vocab_transform():
(16, 16, 'reinterpret(a, b)', 'v1'),
(16, 32, 'translate(a, v2)', 'v2'),
(16, 32, 'translate(a)', 'v2'),
(16, 32, 'translate(a, b)', 'v2')])
(16, 32, 'translate(a, b)', 'v2'),
(16, 32, 'translate(a, solver=nengo.solvers.Lstsq())', 'v2')])
def test_casting_vocabs(d1, d2, method, lookup, Simulator, plt, rng):
v1 = spa.Vocabulary(d1, rng=rng)
v1.populate('A')
Expand Down

0 comments on commit 60f0eba

Please sign in to comment.