Skip to content

Commit

Permalink
Fix ADT tests to account for using scalars instead of nats
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed May 15, 2019
1 parent 0c74795 commit 5aafd8b
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions tests/python/relay/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import infer_type
Expand Down Expand Up @@ -153,6 +154,12 @@ def tree_to_dict(t):
ret['children'].append(l)
return ret


# turns a scalar-valued relay tensor value into a python number
def get_scalar(tv):
return tv.asnumpy().item()


def test_nat_value():
assert count(make_nat(10)) == 10
assert count(intrp.evaluate(s(s(z())))) == 2
Expand Down Expand Up @@ -198,11 +205,10 @@ def test_nth():
for i in reversed(expected):
l = cons(relay.const(i), l)

got = []
for i in range(len(expected)):
got.append(intrp.evaluate(nth(l, relay.const(i))))
item = intrp.evaluate(nth(l, relay.const(i)))
assert get_scalar(item) == i

assert got == expected

def test_update():
expected = list(range(10))
Expand All @@ -225,7 +231,7 @@ def test_length():
a = relay.TypeVar("a")
assert mod[length].checked_type == relay.FuncType([l(a)], relay.scalar_type('int32'), [a])
res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil())))))
assert res == 3
assert get_scalar(res) == 3


def test_map():
Expand Down Expand Up @@ -303,7 +309,7 @@ def test_foldr1():
def test_sum():
assert mod[sum].checked_type == relay.FuncType([l(relay.scalar_type('int32'))], relay.scalar_type('int32'))
res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil()))))
assert count(res) == 3
assert get_scalar(res) == 3


def test_concat():
Expand Down Expand Up @@ -578,7 +584,7 @@ def test_size():
nil())))
t = rose(z(), cons(root, cons(root, cons(root, nil()))))
res = intrp.evaluate(size(t))
assert res == 10
assert get_scalar(res) == 10


def test_wildcard_match_solo():
Expand Down

0 comments on commit 5aafd8b

Please sign in to comment.