Skip to content

Commit

Permalink
[Relay][Prelude] Remove Peano nats from the prelude (apache#3045)
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky authored and wweic committed Jun 27, 2019
1 parent 9b1d5ab commit 0bf54b8
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 145 deletions.
132 changes: 64 additions & 68 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""Adds certain standard global functions and ADT definitions to the module."""
from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard

Expand All @@ -34,6 +35,7 @@ def define_list_adt(self):
self.cons = Constructor("cons", [a, self.l(a)], self.l)
self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons])


def define_list_hd(self):
"""Defines a function to get the head of a list. Assume the list has at least one
element.
Expand All @@ -48,6 +50,7 @@ def define_list_hd(self):
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y)
self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a])


def define_list_tl(self):
"""Defines a function to get the tail of a list.
Expand All @@ -61,39 +64,44 @@ def define_list_tl(self):
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z)
self.mod[self.tl] = Function([x], Match(x, [cons_case]), self.l(a), [a])


def define_list_nth(self):
"""Defines a function to get the nth element of a list.
nth(l) : list[a] -> a
nth(l) : list[a] -> Tensor[(), int32] -> a
"""
self.nth = GlobalVar("nth")
a = TypeVar("a")
x = Var("x", self.l(a))
n = Var("n", self.nat())
n = Var("n", scalar_type('int32'))

body = If(equal(n, const(0)),
self.hd(x),
self.nth(self.tl(x), subtract(n, const(1))))

self.mod[self.nth] = Function([x, n], body, a, [a])

y = Var("y")
z_case = Clause(PatternConstructor(self.z), self.hd(x))
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y))
self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])

def define_list_update(self):
"""Defines a function to update the nth element of a list and return the updated list.
update(l, i, v) : list[a] -> nat -> a -> list[a]
update(l, i, v) : list[a] -> Tensor[(), int32] -> a -> list[a]
"""
self.update = GlobalVar("update")
a = TypeVar("a")
l = Var("l", self.l(a))
n = Var("n", self.nat())
n = Var("n", scalar_type('int32'))
v = Var("v", a)

y = Var("y")
body = If(equal(n, const(0)),
self.cons(v, self.tl(l)),
self.cons(self.hd(l),
self.update(self.tl(l),
subtract(n, const(1)),
v)))

z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l)))
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
self.cons(self.hd(l), self.update(self.tl(l), y, v)))
self.mod[self.update] = Function([l, n, v], body, self.l(a), [a])

self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a])

def define_list_map(self):
"""Defines a function for mapping a function over a list's
Expand All @@ -114,6 +122,7 @@ def define_list_map(self):
self.cons(f(y), self.map(f, z)))
self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), self.l(b), [a, b])


def define_list_foldl(self):
"""Defines a left-way fold over a list.
Expand All @@ -136,6 +145,7 @@ def define_list_foldl(self):
self.mod[self.foldl] = Function([f, av, bv],
Match(bv, [nil_case, cons_case]), a, [a, b])


def define_list_foldr(self):
"""Defines a right-way fold over a list.
Expand All @@ -158,6 +168,7 @@ def define_list_foldr(self):
self.mod[self.foldr] = Function([f, bv, av],
Match(av, [nil_case, cons_case]), b, [a, b])


def define_list_foldr1(self):
"""Defines a right-way fold over a nonempty list.
Expand Down Expand Up @@ -196,6 +207,7 @@ def define_list_concat(self):
self.foldr(updater, l2, l1),
self.l(a), [a])


def define_list_filter(self):
"""Defines a function that filters a list.
Expand All @@ -214,6 +226,7 @@ def define_list_filter(self):
If(f(h), self.cons(h, self.filter(f, t)), self.filter(f, t)))
self.mod[self.filter] = Function([f, l], Match(l, [nil_case, cons_case]), self.l(a), [a])


def define_list_zip(self):
"""Defines a function that combines two lists into a list of tuples of their elements.
Expand All @@ -238,6 +251,7 @@ def define_list_zip(self):
self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]),
self.l(TupleType([a, b])), [a, b])


def define_list_rev(self):
"""Defines a function that reverses a list.
Expand All @@ -253,6 +267,7 @@ def define_list_rev(self):
self.foldl(updater, self.nil(), l),
self.l(a), [a])


def define_list_map_accumr(self):
"""Defines an accumulative map, which is a fold that simulataneously updates
an accumulator value and a list of results.
Expand Down Expand Up @@ -282,6 +297,7 @@ def define_list_map_accumr(self):
TupleType([a, self.l(c)]),
[a, b, c])


def define_list_map_accuml(self):
"""Defines an accumulative map, which is a fold that simulataneously updates
an accumulator value and a list of results.
Expand Down Expand Up @@ -321,6 +337,7 @@ def define_optional_adt(self):
self.none = Constructor("none", [], self.optional)
self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none])


def define_list_unfoldr(self):
"""Defines a function that builds up a list starting from a seed value.
Expand All @@ -343,6 +360,7 @@ def define_list_unfoldr(self):
self.mod[self.unfoldr] = Function([f, s], Match(f(s), [none_case, some_case]),
self.l(b), [a, b])


def define_list_unfoldl(self):
"""Defines a function that builds up a list starting from a seed value.
Expand All @@ -362,52 +380,29 @@ def define_list_unfoldl(self):
self.rev(self.unfoldr(f, s)),
self.l(b), [a, b])

def define_nat_adt(self):
"""Defines a Peano (unary) natural number ADT.
Zero is represented by z(). s(n) adds 1 to a nat n."""
self.nat = GlobalTypeVar("nat")
self.z = Constructor("z", [], self.nat)
self.s = Constructor("s", [self.nat()], self.nat)
self.mod[self.nat] = TypeData(self.nat, [], [self.z, self.s])

def define_nat_double(self):
"""Defines a function that doubles a nat."""
self.double = GlobalVar("double")
x = Var("x", self.nat())
y = Var("y")
z_case = Clause(PatternConstructor(self.z), self.z())
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
self.s(self.s(self.double(y))))
self.mod[self.double] = Function([x], Match(x, [z_case, s_case]))

def define_nat_add(self):
"""Defines a function that adds two nats."""
self.add = GlobalVar("add")
x = Var("x", self.nat())
y = Var("y", self.nat())
a = Var("a")
z_case = Clause(PatternConstructor(self.z), y)
s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]),
self.s(self.add(a, y)))
self.mod[self.add] = Function([x, y], Match(x, [z_case, s_case]))

def define_list_sum(self):
"""Defines a function that computes the sum of a list of nats."""
"""Defines a function that computes the sum of a list of integer scalars."""
self.sum = GlobalVar("sum")
a = Var("a", self.l(self.nat()))
self.mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a))
a = Var("a", self.l(scalar_type('int32')))
x = Var('x')
y = Var('y')
addf = Function([x, y], add(x, y))
self.mod[self.sum] = Function([a], self.foldl(addf, const(0), a))


def define_list_length(self):
"""Defines a function that returns the length of a list as a nat"""
"""Defines a function that returns the length of a list"""
self.length = GlobalVar("length")
a = TypeVar("a")
x = Var("x", self.l(a))
y = Var("y")
nil_case = Clause(PatternConstructor(self.nil), self.z())
nil_case = Clause(PatternConstructor(self.nil), const(0))
cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]),
self.s(self.length(y)))
add(const(1), self.length(y)))
self.mod[self.length] = Function([x],
Match(x, [nil_case, cons_case]), None, [a])
Match(x, [nil_case, cons_case]), scalar_type('int32'), [a])


def define_tree_adt(self):
"""Defines a tree ADT. A tree can contain any type.
Expand All @@ -420,6 +415,7 @@ def define_tree_adt(self):
self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree)
self.mod[self.tree] = TypeData(self.tree, [a], [self.rose])


def define_tree_map(self):
"""Defines a function that maps over a tree. The function
is applied to each subtree's contents.
Expand All @@ -439,23 +435,24 @@ def define_tree_map(self):
self.mod[self.tmap] = Function([f, t],
Match(t, [rose_case]), self.tree(b), [a, b])


def define_tree_size(self):
"""Defines a function that computes the size of a tree as a nat.
"""Defines a function that computes the size of a tree.
Signature: fn<a>(t : tree[a]) -> nat
Signature: fn<a>(t : tree[a]) -> Tensor[(), int32]
"""
self.size = GlobalVar("size")
a = TypeVar("a")
t = Var("t", self.tree(a))
x = Var("x", self.tree(a))
z = Var("z")
rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]),
self.s(self.sum(self.map(Function([x], self.size(x)), z))))
add(const(1), self.sum(self.map(self.size, z))))
self.mod[self.size] = Function([t],
Match(t, [rose_case]), self.nat(), [a])
Match(t, [rose_case]), scalar_type('int32'), [a])


def define_id(self):
"""Defines a function that return it's argument.
"""Defines a function that return its argument.
Signature: fn<a>(x : a) -> a
"""
Expand All @@ -466,7 +463,7 @@ def define_id(self):


def define_compose(self):
"""Defines a function that compose two function.
"""Defines a function that composes two function.
Signature: fn<a, b, c>(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c
"""
Expand All @@ -484,24 +481,26 @@ def define_compose(self):


def define_iterate(self):
"""Define a function that take a number n, a function f,
and return a closure that apply f n time on it's argument.
"""Defines a function that take a number n and a function f;
returns a closure that takes an argument and applies f
n times to its argument.
Signature: fn<a>(n : nat, f : fn(a) -> a) -> fn(a) -> a
Signature: fn<a>(f : fn(a) -> a, n : Tensor[(), int32]) -> fn(a) -> a
"""
self.iterate = GlobalVar("iterate")
a = TypeVar("a")
f = Var("f", FuncType([a], a))
x = Var("x", self.nat())
y = Var("y", self.nat())
z_case = Clause(PatternConstructor(self.z), self.id)
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
self.compose(f, self.iterate(f, y)))
x = Var("x", scalar_type('int32'))
body = If(equal(x, const(0)),
self.id,
self.compose(f,
self.iterate(f, subtract(x, const(1)))))
self.mod[self.iterate] = Function([f, x],
Match(x, [z_case, s_case]),
body,
FuncType([a], a),
[a])


def __init__(self, mod):
self.mod = mod
self.define_list_adt()
Expand All @@ -522,9 +521,6 @@ def __init__(self, mod):
self.define_list_unfoldr()
self.define_list_unfoldl()

self.define_nat_adt()
self.define_nat_double()
self.define_nat_add()
self.define_list_length()
self.define_list_nth()
self.define_list_update()
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@

from .config import ctx_list
from .init import create_workload
from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
Loading

0 comments on commit 0bf54b8

Please sign in to comment.