[Relay] GradientCell Relay Pass (#5039)
* save

* gradient.rly

* fix

* NOT WORKING: gradient cell pass

* test gradient pass

* fixed basic call ops

* more tests

* fix bug

* transform calls to ones, ones_like, zeros, zeros_like

* maintenance stuff

* fix linting

* linting

* linting

* throw default

* remove unrelated changes

* import gradient.rly in pass

* comment

* linting

* remove changes to test files

* move gradient_cell.cc to transforms

* revert change

* update files with new commits

* type

* wrapper function to preserve the outermost function type of main

* fix linting

* fix unsigned and signed int comparison

* review

* GetConstructor definition in module and change op comparison

* update node instantiations

* increase code readability

Co-authored-by: Marisa Kirisame <[email protected]>
hypercubestart and MarisaKirisame authored Mar 24, 2020
1 parent a6de507 commit e6dd8e1
Showing 8 changed files with 818 additions and 1 deletion.
8 changes: 8 additions & 0 deletions include/tvm/ir/module.h
@@ -162,6 +162,14 @@ class IRModuleNode : public Object {
*/
TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;

/*!
 * \brief Find the constructor of an ADT by name.
 * \param adt The name of the ADT the constructor belongs to.
 * \param cons The name of the constructor.
 * \returns The matching constructor; fails with a fatal error if not found.
 */
TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const;

/*!
* \brief Look up a global function by its variable.
* \param var The global var to lookup.
14 changes: 14 additions & 0 deletions include/tvm/relay/transform.h
@@ -77,6 +77,20 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
*/
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);

/*!
 * \brief Convert all expressions of TensorType into GradCell,
 * an algebraic data type defined in gradient.rly.
 *
 * This pass delays, and can reduce, memory allocation: calls to
 * ones, ones_like, zeros, and zeros_like no longer instantiate a tensor
 * immediately, but only when the value is actually needed. The pass also
 * defines + and * operations between GradCell values, which can improve
 * performance when operands are zero-filled or one-filled tensors, as is
 * common in reverse-mode AD.
 *
 * \return The pass.
 */
TVM_DLL Pass LazyGradientInit();

/*!
* \brief Fold constant expressions.
*
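For orientation, a minimal Python sketch of invoking the pass through the binding registered in python/tvm/relay/transform/transform.py below; the module contents (shape, dtype, and the ones_like example) are illustrative assumptions, not part of this commit:

import tvm
from tvm import relay

# Minimal sketch (assumed shape/dtype): after the pass runs, the ones_like
# call below no longer allocates a tensor eagerly; it becomes a GradCell
# thunk that is forced only where the value is actually consumed.
t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
f = relay.Function([x], relay.add(x, relay.ones_like(x)))
mod = tvm.IRModule.from_expr(f)
mod = relay.transform.LazyGradientInit()(mod)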
55 changes: 55 additions & 0 deletions python/tvm/relay/std/gradient.rly
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
v0.0.4

/*
 * Stores the gradient value of a tensor of type T.
 * Raw holds an already-materialized tensor; One and Zero hold thunks that
 * produce a one-filled or zero-filled tensor only when forced.
 * Note that the gradient of T is stored inside a Ref(GradCell[T]) rather
 * than a bare GradCell[T].
 */
type GradCell[T] {
  Raw(T),
  One(fn() -> T),
  Zero(fn() -> T)
}

def @FromGradCell[T](%g: GradCell[T]) -> T {
  match (%g) {
    Raw(%x) => %x,
    One(%x) => %x(),
    Zero(%x) => %x()
  }
}

def @MultiplyGradCell[T](%multiply: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] {
  match ((%l, %r)) {
    (Zero(_), _) => %l,
    (_, Zero(_)) => %r,
    (One(_), _) => %r,
    (_, One(_)) => %l,
    _ => Raw(%multiply(@FromGradCell(%l), @FromGradCell(%r)))
  }
}

def @AddGradCell[T](%add: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] {
  match ((%l, %r)) {
    (Zero(_), _) => %r,
    (_, Zero(_)) => %l,
    _ => Raw(%add(@FromGradCell(%l), @FromGradCell(%r)))
  }
}
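To make the lazy semantics concrete, here is a hypothetical, self-contained Python analogue of GradCell and the two helpers above, using numpy payloads; it illustrates the algebra only and is not TVM API:

import numpy as np

# Hypothetical model of GradCell: Raw holds a materialized tensor, while
# One/Zero hold thunks that build the tensor only when forced.
class GradCell:
    def __init__(self, tag, payload):
        self.tag = tag          # "raw", "one", or "zero"
        self.payload = payload  # tensor for "raw", thunk otherwise

    @staticmethod
    def raw(t):
        return GradCell("raw", t)

    @staticmethod
    def one(shape):
        return GradCell("one", lambda: np.ones(shape))

    @staticmethod
    def zero(shape):
        return GradCell("zero", lambda: np.zeros(shape))

    def from_grad_cell(self):
        # Mirrors @FromGradCell: force the thunk only for One/Zero.
        return self.payload if self.tag == "raw" else self.payload()

def add_grad_cell(l, r):
    # Mirrors @AddGradCell: x + 0 == x, so the zero side is never built.
    if l.tag == "zero":
        return r
    if r.tag == "zero":
        return l
    return GradCell.raw(l.from_grad_cell() + r.from_grad_cell())

def mul_grad_cell(l, r):
    # Mirrors @MultiplyGradCell: x * 0 == 0 and x * 1 == x.
    if l.tag == "zero" or r.tag == "one":
        return l
    if r.tag == "zero" or l.tag == "one":
        return r
    return GradCell.raw(l.from_grad_cell() * r.from_grad_cell())

# x + zeros_like(x) returns x without ever allocating the zero tensor.
x = GradCell.raw(np.arange(4.0))
print(add_grad_cell(x, GradCell.zero((4,))).from_grad_cell())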
13 changes: 13 additions & 0 deletions python/tvm/relay/transform/transform.py
@@ -219,6 +219,19 @@ def DeadCodeElimination(inline_once=False):
"""
return _ffi_api.DeadCodeElimination(inline_once)

def LazyGradientInit():
    """Reduce memory usage of gradient tensors.

    Returns
    -------
    ret : tvm.relay.Pass
        A pass which delays and/or reduces memory allocation by lazily
        allocating zero- or one-filled tensors.
    """
    return _ffi_api.LazyGradientInit()

def FoldConstant():
"""Fold the constant expressions in a Relay program.
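A sketch of the intended use in a reverse-mode AD pipeline, following the pass description in transform.h; the function x * x and the shape are illustrative assumptions:

import tvm
from tvm import relay

# Illustrative sketch: differentiate x * x, then run LazyGradientInit so the
# one-filled seed gradient and zero-filled accumulators are allocated lazily.
t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
mod = tvm.IRModule.from_expr(relay.Function([x], x * x))
mod = relay.transform.InferType()(mod)
mod["main"] = relay.transform.gradient(mod["main"], mod=mod)
mod = relay.transform.LazyGradientInit()(mod)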
12 changes: 12 additions & 0 deletions src/ir/module.cc
@@ -96,6 +96,18 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
return (*it).second;
}

Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const {
  // Scan the ADT's constructors for a matching name.
  TypeData typeDef = this->LookupTypeDef(adt);
  for (Constructor c : typeDef->constructors) {
    if (cons == c->name_hint) {
      return c;
    }
  }

  LOG(FATAL) << adt << " does not contain constructor " << cons;
  throw std::runtime_error("Constructor Not Found.");  // unreachable; satisfies the return type
}

tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
std::vector<GlobalTypeVar> global_type_vars;
for (const auto& pair : global_type_var_map_) {
Expand Down
