Commit: address review comments
yzhliu committed Mar 24, 2020
1 parent b7407e8 commit b0ee501
Showing 8 changed files with 36 additions and 68 deletions.
5 changes: 1 addition & 4 deletions include/tvm/te/autodiff.h
@@ -20,9 +20,6 @@
/*!
* \file tvm/te/autodiff.h
* \brief Automatic differentiation of tensor expressions.
* The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
* in [Automatic differentiation for tensor expressions](#2498)
* and [Zero elimination](#2634)
*/

#ifndef TVM_TE_AUTODIFF_H_
@@ -71,7 +68,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input);
* \return The tensor of shape `prefix + input.shape`
* representing the partial adjoint of \p input wrt one of its consumers (output)
*/
Tensor PartialAdjoint(const Tensor& output, const Tensor& input, const Tensor& head);
Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head);

/*!
* \brief Perform reverse mode automatic differentiation.
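The rename above makes the shape contract explicit: `head` has shape `prefix + output.shape`, the Jacobian has shape `output.shape + input.shape`, and contracting over the output axes yields `prefix + input.shape`. A minimal NumPy sketch of that contract (shapes are illustrative; the real routine builds TE tensors with topi::tensordot):

```python
import numpy as np

prefix = (4,)          # leading dims carried by `head`
out_shape = (2, 3)     # output.shape
in_shape = (5,)        # input.shape

head = np.random.rand(*(prefix + out_shape))   # prefix + output.shape
jac = np.random.rand(*(out_shape + in_shape))  # output.shape + input.shape

# Contract over all of output's axes, mirroring
# topi::tensordot(head, jac, /*axes=*/output->shape.size(), ...).
result = np.tensordot(head, jac, axes=len(out_shape))
assert result.shape == prefix + in_shape       # prefix + input.shape
```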
7 changes: 1 addition & 6 deletions python/tvm/te/autodiff.py
@@ -15,12 +15,7 @@
# specific language governing permissions and limitations
# under the License.

"""
Automatic differentiation of tensor expressions.
The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
in [Automatic differentiation for tensor expressions](#2498)
and [Zero elimination](#2634)
"""
"""Automatic differentiation of tensor expressions."""
from . import _ffi_api


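The trimmed docstring leaves `gradient` as this module's entry point. A hedged usage sketch, mirroring how the unit test below calls `te.gradient` and builds the result; the `topi.full_like` head is an assumption about how the test constructs its `ones` tensor:

```python
import tvm
from tvm import te
import topi  # imported the same way as in the unit test below

n = 8
A = te.placeholder((n, n), name="A")
B = te.compute((n, n), lambda i, j: A[i, j] * A[j, i], name="B")

# head has shape prefix + B.shape; with an empty prefix each returned
# gradient has the same shape as its input (assumed construction of `ones`).
ones = topi.full_like(B, 1.0)
grads = te.gradient(B, [A], head=ones)

sched = te.create_schedule([g.op for g in grads])
mod = tvm.build(sched, list(grads) + [A])
```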
33 changes: 10 additions & 23 deletions python/tvm/testing.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-raise
""" TVM testing utilities """
import logging
import numpy as np
@@ -33,7 +32,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):


def check_numerical_grads(function, input_values, grad_values, function_value=None,
delta=1e-3, atol=1e-2, rtol=0.1, acceptable_fail_percentage=None):
delta=1e-3, atol=1e-2, rtol=0.1):
"""A helper function that checks that numerical gradients of a function are
equal to gradients computed in some different way (analytical gradients).
@@ -69,10 +68,6 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
rtol : float, optional
Relative tolerance.
acceptable_fail_percentage : float, optional
If not None, raise an error only when the fraction of wrong elements for a gradient is
higher than this value.
"""
# If input_values is a list then function accepts positional arguments
# In this case transform it to a function taking kwargs of the form {"0": ..., "1": ...}
@@ -139,7 +134,7 @@ def compare_derivative(j, n_der, grad):

ngrad.reshape(-1)[j] = nder

wrong_percentage = len(wrong_positions)/np.prod(grad.shape)
wrong_percentage = int(100*len(wrong_positions)/np.prod(grad.shape))

dist = np.sqrt(np.sum((ngrad - grad)**2))
grad_norm = np.sqrt(np.sum(ngrad**2))
@@ -154,22 +149,14 @@ def compare_derivative(j, n_der, grad):
sqrt_n = np.sqrt(float(np.prod(grad.shape)))

if dist > atol*sqrt_n + rtol*grad_norm:
enough_failures = (acceptable_fail_percentage is None or
wrong_percentage > acceptable_fail_percentage)
if enough_failures:
raise AssertionError(
"Analytical and numerical grads wrt '{}' differ too much\n"
"analytical grad = {}\n numerical grad = {}\n"
"{}% of elements differ, first 10 of wrong positions: {}\n"
"distance > atol*sqrt(n) + rtol*grad_norm\n"
"distance {} > {}*{} + {}*{}"
.format(x_name, grad, ngrad, int(100*wrong_percentage),
wrong_positions[:10], dist, atol, sqrt_n, rtol, grad_norm))
else:
logging.warning("Analytical and numerical grads wrt '%s' differ, however "
"there were not enough wrong elements to raise an error "
"(only %d%%)",
x_name, int(100*wrong_percentage))
raise AssertionError(
"Analytical and numerical grads wrt '{}' differ too much\n"
"analytical grad = {}\n numerical grad = {}\n"
"{}% of elements differ, first 10 of wrong positions: {}\n"
"distance > atol*sqrt(n) + rtol*grad_norm\n"
"distance {} > {}*{} + {}*{}"
.format(x_name, grad, ngrad, wrong_percentage, wrong_positions[:10],
dist, atol, sqrt_n, rtol, grad_norm))

max_diff = np.max(np.abs(ngrad - grad))
avg_diff = np.mean(np.abs(ngrad - grad))
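With `acceptable_fail_percentage` removed, the helper now always raises once the L2 distance between the analytical and numerical gradients exceeds `atol*sqrt(n) + rtol*grad_norm`. A standalone NumPy sketch of that criterion (an illustration, not the helper itself):

```python
import numpy as np

def within_tolerance(analytical, numerical, atol=1e-2, rtol=0.1):
    # Same criterion as above: distance <= atol*sqrt(n) + rtol*grad_norm.
    dist = np.sqrt(np.sum((numerical - analytical) ** 2))
    grad_norm = np.sqrt(np.sum(numerical ** 2))
    sqrt_n = np.sqrt(float(np.prod(analytical.shape)))
    return dist <= atol * sqrt_n + rtol * grad_norm

x = np.random.rand(3, 4)
assert within_tolerance(2 * x, 2 * x + 1e-4)      # tiny mismatch: accepted
assert not within_tolerance(2 * x + 1.0, 2 * x)   # constant offset of 1: rejected
```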
3 changes: 0 additions & 3 deletions src/te/autodiff/ad_util.cc
@@ -20,9 +20,6 @@
/*!
* \file ad_util.cc
* \brief Utility for tensor-level auto-differentiation.
* The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
* in [Automatic differentiation for tensor expressions](#2498)
* and [Zero elimination](#2634)
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
3 changes: 0 additions & 3 deletions src/te/autodiff/ad_util.h
@@ -20,9 +20,6 @@
/*!
* \file ad_util.h
* \brief Helper utilities to implement auto-differentiation.
* The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
* in [Automatic differentiation for tensor expressions](#2498)
* and [Zero elimination](#2634)
*/
#ifndef TVM_TE_AUTODIFF_AD_UTIL_H_
#define TVM_TE_AUTODIFF_AD_UTIL_H_
10 changes: 4 additions & 6 deletions src/te/autodiff/adjoint.cc
@@ -29,9 +29,6 @@
* (2) multiply the Jacobian (PartialAdjoint),
* (3) and sum them together to get the adjoint of the input itself.
* The three steps are computed recursively.
* The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
* in [Automatic differentiation for tensor expressions](#2498)
* and [Zero elimination](#2634)
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/autodiff.h>
@@ -62,7 +59,7 @@ Tensor Identity(const Tensor& output) {
return te::compute(shape, func, "identity");
}

Tensor PartialAdjoint(const Tensor& output, const Tensor& input, const Tensor& head) {
Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head) {
Tensor jac = Jacobian(output, input);
Tensor result = topi::tensordot(head, jac, /*axes=*/output->shape.size(),
output->op->name + "." + input->op->name + ".grad");
@@ -118,10 +115,11 @@ Array<Tensor> Gradient(const Tensor& output,
} else {
// The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied
// by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian
// and the multiplication is done in the function PartialAdjoint
// and the multiplication is done in the function VectorJacobianProduct
for (const Tensor& direct_consumer : direct_consumers) {
// part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor)
Tensor part = PartialAdjoint(direct_consumer, tensor, compute_adjoint(direct_consumer));
Tensor part = VectorJacobianProduct(
direct_consumer, tensor, compute_adjoint(direct_consumer));
res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part;
}
}
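The comment above describes the reverse-mode recursion: seed the output's adjoint with `head`, then for every tensor sum the `VectorJacobianProduct` contributions of its direct consumers. A schematic Python rendering of that recursion (the real code operates on `te::Tensor` objects, memoizes in a Map, and substitutes a zero tensor when a tensor does not influence the output):

```python
def gradient(output, inputs, head, direct_consumers, vjp):
    """Schematic reverse-mode sweep.

    direct_consumers(t) -> tensors that read t on a path to `output`;
    vjp(consumer, t, adjoint_of_consumer) -> adjoint contribution for t.
    """
    memo = {}

    def compute_adjoint(tensor):
        if tensor in memo:
            return memo[tensor]
        if tensor is output:
            adjoint = head                       # seed: adjoint of the output
        else:
            adjoint = None                       # real code uses a zero tensor
            for consumer in direct_consumers(tensor):
                part = vjp(consumer, tensor, compute_adjoint(consumer))
                adjoint = part if adjoint is None else adjoint + part
        memo[tensor] = adjoint
        return adjoint

    return [compute_adjoint(t) for t in inputs]
```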
13 changes: 2 additions & 11 deletions src/te/autodiff/jacobian.cc
@@ -22,9 +22,6 @@
* \brief Calculate Jacobian of two tensors dY/dX.
* X must be direct input tensor of Y.
* The result Jacobian shape will be (Y.shape, X.shape)
* The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
* in [Automatic differentiation for tensor expressions](#2498)
* and [Zero elimination](#2634)
*/
#include <tvm/te/autodiff.h>
#include <tvm/runtime/registry.h>
@@ -325,14 +322,8 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) {
// We have to clone the iteration axes because otherwise the original expression
// cannot be used together with the derivative (it will lead to errors during lowering)
Array<IterVar> new_axis;
std::unordered_map<const VarNode*, PrimExpr> vmap;
for (IterVar iv : op->axis) {
IterVar new_v =
IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""),
iv->iter_type, iv->thread_tag);
new_axis.push_back(new_v);
vmap[iv->var.get()] = new_v;
}
Map<Var, PrimExpr> vmap;
std::tie(new_axis, vmap) = te::CloneIterVars(op->axis);

Array<PrimExpr> input_indices;
size_t i = 0;
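Independent of the switch to `te::CloneIterVars`, the file header states that `Jacobian(Y, X)` has shape `Y.shape + X.shape`. A NumPy illustration of that layout for a simple matrix-vector product (illustrative only):

```python
import numpy as np

X = np.random.rand(3)
W = np.random.rand(3, 4)
Y = X @ W                       # Y.shape == (4,)

# dY[j]/dX[k] = W[k, j]; laid out as an array of shape Y.shape + X.shape,
# the Jacobian is simply the transpose of W.
jac = W.T                       # shape (4, 3) == Y.shape + X.shape
assert jac.shape == Y.shape + X.shape
```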
30 changes: 18 additions & 12 deletions tests/python/unittest/test_te_autodiff.py
@@ -17,14 +17,14 @@

import tvm
from tvm import te
from tvm.testing import check_numerical_grads
from tvm.testing import check_numerical_grads, assert_allclose
import topi
from topi.util import get_const_tuple

import numpy as np


def check_grad(out, inputs, data_range=(-10, 10), acceptable_fail_percentage=None):
def check_grad(out, inputs, data_range=(-10, 10), desired_grads=None):
inputs = inputs if isinstance(inputs, list) else [inputs]

def check_device(device, host="llvm"):
@@ -50,19 +50,24 @@ def check_device(device, host="llvm"):
grads = te.gradient(out, inputs, head=ones)
grad_sched = te.create_schedule([grad.op for grad in grads])
mgrad = tvm.build(grad_sched, list(grads) + inputs)
# print(tvm.lower(grad_sched, list(grads) + inputs, simple_mode=True))

grad_data = [tvm.nd.empty(get_const_tuple(i.shape), g.dtype)
for i, g in zip(inputs, grads)]

mgrad(*grad_data, *input_data)
g_res = [g.asnumpy() for g in grad_data]

def forward(*in_data):
out_data = tvm.nd.empty(out_shape, out.dtype)
mout(out_data, *[tvm.nd.array(d) for d in list(in_data)])
return out_data.asnumpy().sum()
check_numerical_grads(forward, [d.asnumpy() for d in input_data], g_res,
acceptable_fail_percentage=acceptable_fail_percentage)
if desired_grads:
assert isinstance(desired_grads, list)
for actual, desired in zip(g_res, desired_grads):
assert_allclose(actual, desired, rtol=0.1, atol=1e-2)
else:
def forward(*in_data):
out_data = tvm.nd.empty(out_shape, out.dtype)
mout(out_data, *[tvm.nd.array(d) for d in list(in_data)])
return out_data.asnumpy().sum()
check_numerical_grads(forward, [d.asnumpy() for d in input_data], g_res)

check_device("cpu")

@@ -74,6 +79,7 @@ def test_basic_operation():
l = te.reduce_axis((0, 10), name="l")
A0 = te.placeholder(shape, name='A0')
A1 = te.placeholder(shape, name='A1')
zeros = np.zeros(shape)

B = te.compute(shape, lambda i, j: A0[i, j], name='B')
check_grad(B, [A0])
@@ -85,16 +91,16 @@
check_grad(B, A0)

B = te.compute(shape, lambda i, j: te.floor(A0[i, j]), name='B')
check_grad(B, A0, acceptable_fail_percentage=0.05)
check_grad(B, A0, desired_grads=[zeros])

B = te.compute(shape, lambda i, j: te.ceil(A0[i, j]), name='B')
check_grad(B, A0, acceptable_fail_percentage=0.05)
check_grad(B, A0, desired_grads=[zeros])

B = te.compute(shape, lambda i, j: te.trunc(A0[i, j]), name='B')
check_grad(B, A0, acceptable_fail_percentage=0.05)
check_grad(B, A0, desired_grads=[zeros])

B = te.compute(shape, lambda i, j: te.round(A0[i, j]), name='B')
check_grad(B, A0, acceptable_fail_percentage=0.05)
check_grad(B, A0, desired_grads=[zeros])

B = te.compute(shape, lambda i, j: A0[i, j] + te.exp(A0[j, i]), name='B')
check_grad(B, A0)
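Replacing `acceptable_fail_percentage=0.05` with `desired_grads=[zeros]` works because floor, ceil, trunc, and round are piecewise constant: away from integer boundaries their derivative is exactly zero, so an explicit zero reference is stricter than tolerating a small fraction of failed numerical checks. An illustrative check of that claim (not part of the test):

```python
import numpy as np

x = np.random.uniform(-10, 10, size=(10, 10))
delta = 1e-3
# Central difference of floor: zero unless [x - delta, x + delta] crosses an integer.
numerical = (np.floor(x + delta) - np.floor(x - delta)) / (2 * delta)
print(np.count_nonzero(numerical), "of", numerical.size, "entries are nonzero")
```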
