Skip to content

Commit

Permalink
Remove unnecessary argument in Transcode functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
jgosmann committed Feb 21, 2018
1 parent a2d0cf3 commit 365be7f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions nengo_spa/modules/tests/test_transcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def stimulus(t, x):


def test_transcode(Simulator, seed):
def transcode_fn(t, sp, vocab):
assert t < 0.15 or vocab.parse('A').dot(sp) > 0.8
def transcode_fn(t, sp):
assert t < 0.15 or sp.vocab.parse('A').dot(sp) > 0.8
return 'B'

with spa.Network(seed=seed) as model:
Expand Down Expand Up @@ -123,10 +123,10 @@ class OutputFn(object):
def __init__(self):
self.called = False

def __call__(self, t, v, vocab):
def __call__(self, t, v):
if t > 0.001:
self.called = True
assert_almost_equal(vocab.parse('A').v, v.v)
assert_almost_equal(v.vocab.parse('A').v, v.v)

output_fn = OutputFn()

Expand All @@ -142,7 +142,7 @@ def __call__(self, t, v, vocab):


def test_decode_with_output(Simulator, seed):
def decode_fn(t, v, vocab):
def decode_fn(t, v):
return [t]

with spa.Network(seed=seed) as model:
Expand All @@ -156,7 +156,7 @@ def decode_fn(t, v, vocab):


def test_decode_size_out(Simulator, seed):
def decode_fn(t, v, vocab):
def decode_fn(t, v):
return [t]

with spa.Network(seed=seed) as model:
Expand Down
6 changes: 3 additions & 3 deletions nengo_spa/modules/transcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __call__(self, value):

def make_sp_func(fn, vocab):
def sp_func(t, v):
return fn(t, SemanticPointer(v), vocab)
return fn(t, SemanticPointer(v, vocab=vocab))
return sp_func


Expand Down Expand Up @@ -57,8 +57,8 @@ def coerce(self, obj, fn):
def coerce_callable(self, obj, fn):
t = 0.
if obj.input_vocab is not None:
args = (t, SemanticPointer(obj.input_vocab.dimensions),
obj.input_vocab)
args = (t, SemanticPointer(
obj.input_vocab.dimensions, vocab=obj.input_vocab))
elif obj.size_in is not None:
args = (t, np.zeros(obj.size_in))
else:
Expand Down

0 comments on commit 365be7f

Please sign in to comment.