[TE] reverse-mode autodiff without any optimization
Co-authored-by: Sergei Grechanik <[email protected]>
yzhliu and sgrechanik-h committed Mar 24, 2020
1 parent 686911e commit b7407e8
Showing 10 changed files with 1,020 additions and 11 deletions.
100 changes: 100 additions & 0 deletions include/tvm/te/autodiff.h
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/te/autodiff.h
* \brief Automatic differentiation of tensor expressions.
* The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
* in [Automatic differentiation for tensor expressions](#2498)
* and [Zero elimination](#2634)
*/

#ifndef TVM_TE_AUTODIFF_H_
#define TVM_TE_AUTODIFF_H_

#include <tvm/runtime/object.h>
#include <tvm/tir/expr.h>
#include "tensor.h"

namespace tvm {
/*! \brief Tensor expression language DSL. */
namespace te {

/*!
* \brief Take the derivative of the expression with respect to the given variable.
* \param expr The expression to differentiate.
* \param var The variable to differentiate with respect to.
* \return The expression for the derivative.
*/
PrimExpr Derivative(const PrimExpr& expr, const Var& var);

/*!
* \brief Get the tensor representing the Jacobian of the output with respect to the input.
*
* Note that if \p output depends on \p input indirectly (by using some other tensor
* depending on \p input), this dependency won't contribute to the resulting Jacobian.
* For such cases use the function ::Gradient.
*
* \param output The tensor to differentiate.
* \param input The input tensor, which \p output should directly use.
* \return The tensor representing the Jacobian of shape `output.shape + input.shape`.
*/
Tensor Jacobian(const Tensor& output, const Tensor& input);

/*!
* \brief The building block for reverse-mode AD.
*
* Differentiate \p output wrt \p input and multiply the result by \p head on the left using a
* tensor dot product. \p input must be an immediate dependency of \p output, i.e. it must be
* used directly within the body of \p output. In other words, the function computes one summand
* of the adjoint for \p input given the adjoint for \p output (which is called \p head here).
*
* \param output The tensor to differentiate.
* \param input The input tensor, which \p output should directly use.
* \param head The adjoint of \p output. Must be of shape `prefix + output.shape`.
* \return The tensor of shape `prefix + input.shape`
* representing the partial adjoint of \p input wrt one of its consumers, \p output.
*/
Tensor PartialAdjoint(const Tensor& output, const Tensor& input, const Tensor& head);

/*!
* \brief Perform reverse mode automatic differentiation.
*
* Each item of the resulting array is an adjoint for the corresponding item of
* \p inputs, i.e. \p head multiplied by the Jacobian of \p output with respect to
* that input.
*
* \param output The tensor to differentiate.
* \param inputs The array of input tensors. When the array is empty, differentiation is
* performed wrt all tensors the output depends on.
* \param head The adjoint of the output, in other words, some tensor, by which the Jacobians
* will be multiplied (using tensordot axes=`output.shape`).
* Its shape must be of the form `prefix + output.shape`. If the null pointer is provided,
* the identity tensor of shape `output.shape + output.shape` will be used.
* \return An array of adjoints corresponding to \p inputs.
*/
TVM_DLL Array<Tensor> Gradient(
    const Tensor& output,
    const Array<Tensor>& inputs,
    const Tensor& head = Tensor());

} // namespace te
} // namespace tvm

#endif // TVM_TE_AUTODIFF_H_
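The shape contract documented above (a head of shape `prefix + output.shape`, a Jacobian of shape `output.shape + input.shape`, and an adjoint of shape `prefix + input.shape`, contracted over `output.shape`) can be illustrated with a small NumPy sketch. This is a shape-level analogy with made-up toy values, not TVM code:

```python
import numpy as np

# Toy case: output has shape (2,), input has shape (3,).
jacobian = np.arange(6.0).reshape(2, 3)    # output.shape + input.shape

# Default head in Gradient: identity of shape output.shape + output.shape,
# i.e. prefix == output.shape, so the adjoint is just the full Jacobian.
head_identity = np.eye(2)
adjoint = np.tensordot(head_identity, jacobian, axes=1)
assert adjoint.shape == (2, 3)             # prefix + input.shape

# A head of all ones with an empty prefix corresponds to differentiating sum(output).
head_ones = np.ones(2)
grad_of_sum = np.tensordot(head_ones, jacobian, axes=1)
assert grad_of_sum.shape == (3,)           # prefix + input.shape == input.shape
```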
1 change: 1 addition & 0 deletions python/tvm/te/__init__.py
@@ -33,3 +33,4 @@
from .operation import thread_axis, reduce_axis

from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
from .autodiff import gradient
72 changes: 72 additions & 0 deletions python/tvm/te/autodiff.py
@@ -0,0 +1,72 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Automatic differentiation of tensor expressions.
The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
in [Automatic differentiation for tensor expressions](#2498)
and [Zero elimination](#2634)
"""
from . import _ffi_api


def gradient(output, inputs, head=None):
    """Perform reverse-mode automatic differentiation.

    Parameters
    ----------
    output : Tensor
        The tensor to differentiate.

    inputs : List[Tensor]
        The list of input tensors to be differentiated wrt.

    head : Tensor
        The adjoint of the output, in other words, some tensor, by which the Jacobians
        will be multiplied. Its shape must be of the form `prefix + output.shape`.
        If `None` is passed, the identity tensor of shape `output.shape + output.shape`
        will be used.

    Returns
    -------
    tensors : List[Tensor]
        The resulting gradients, in the same order as the inputs.

    Example
    -------
    .. code-block:: python

        x = tvm.te.placeholder((32, 3, 28, 28), name='x')
        w1 = tvm.te.placeholder((10, 3, 3, 3), name='w1')
        w2 = tvm.te.placeholder((10, 10, 3, 3), name='w2')
        z1 = topi.nn.conv2d(x, w1, 1, 1, 1)
        z2 = topi.nn.conv2d(z1, w2, 1, 1, 1)
        y = topi.sum(z2)

        # produce gradients
        [dw1, dw2] = tvm.te.gradient(y, [w1, w2])

        # produce Jacobians
        [jw1, jw2] = tvm.te.gradient(z2, [w1, w2])

        # produce gradients, the head adjoint for z2 is provided manually
        [dw1, dw2] = tvm.te.gradient(z2, [w1, w2], topi.full_like(z2, 1.0))
    """
    if not isinstance(inputs, list):
        inputs = [inputs]
    return _ffi_api.Gradient(output, inputs, head)
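A minimal usage sketch of the Python entry point. It assumes the `te` API names available around this commit (`te.placeholder`, `te.compute`, `te.create_schedule`), an LLVM-enabled build, and purely illustrative shapes:

```python
import tvm
from tvm import te

x = te.placeholder((3, 4), name="x")
y = te.compute((3, 4), lambda i, j: x[i, j] * x[i, j], name="y")

# With the default head (identity of shape y.shape + y.shape),
# the result is the full Jacobian of y wrt x.
[jx] = te.gradient(y, [x])
print(jx.shape)  # expected: [3, 4, 3, 4] == y.shape + x.shape

# The adjoint is an ordinary tensor expression, so it can be scheduled and built.
s = te.create_schedule(jx.op)
f = tvm.build(s, [x, jx], target="llvm")
```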
33 changes: 23 additions & 10 deletions python/tvm/testing.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-raise
""" TVM testing utilities """
import logging
import numpy as np
@@ -32,7 +33,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):


def check_numerical_grads(function, input_values, grad_values, function_value=None,
delta=1e-3, atol=1e-2, rtol=0.1):
delta=1e-3, atol=1e-2, rtol=0.1, acceptable_fail_percentage=None):
"""A helper function that checks that numerical gradients of a function are
equal to gradients computed in some different way (analytical gradients).
@@ -68,6 +69,10 @@ def check_numerical_grads(function, input_values, grad_values, function_value=None,
rtol : float, optional
Relative tolerance.
acceptable_fail_percentage : float, optional
If not None, raise an error only when the fraction of mismatching gradient elements
(a value between 0 and 1) exceeds this threshold; otherwise only a warning is logged.
"""
# If input_values is a list then function accepts positional arguments
# In this case transform it to a function taking kwargs of the form {"0": ..., "1": ...}
@@ -134,7 +139,7 @@ def compare_derivative(j, n_der, grad):

ngrad.reshape(-1)[j] = nder

wrong_percentage = int(100*len(wrong_positions)/np.prod(grad.shape))
wrong_percentage = len(wrong_positions)/np.prod(grad.shape)

dist = np.sqrt(np.sum((ngrad - grad)**2))
grad_norm = np.sqrt(np.sum(ngrad**2))
@@ -149,14 +154,22 @@ def compare_derivative(j, n_der, grad):
sqrt_n = np.sqrt(float(np.prod(grad.shape)))

if dist > atol*sqrt_n + rtol*grad_norm:
raise AssertionError(
"Analytical and numerical grads wrt '{}' differ too much\n"
"analytical grad = {}\n numerical grad = {}\n"
"{}% of elements differ, first 10 of wrong positions: {}\n"
"distance > atol*sqrt(n) + rtol*grad_norm\n"
"distance {} > {}*{} + {}*{}"
.format(x_name, grad, ngrad, wrong_percentage, wrong_positions[:10],
dist, atol, sqrt_n, rtol, grad_norm))
enough_failures = (acceptable_fail_percentage is None or
wrong_percentage > acceptable_fail_percentage)
if enough_failures:
raise AssertionError(
"Analytical and numerical grads wrt '{}' differ too much\n"
"analytical grad = {}\n numerical grad = {}\n"
"{}% of elements differ, first 10 of wrong positions: {}\n"
"distance > atol*sqrt(n) + rtol*grad_norm\n"
"distance {} > {}*{} + {}*{}"
.format(x_name, grad, ngrad, int(100*wrong_percentage),
wrong_positions[:10], dist, atol, sqrt_n, rtol, grad_norm))
else:
logging.warning("Analytical and numerical grads wrt '%s' differ, however "
"there were not enough wrong elements to raise an error "
"(only %d%%)",
x_name, int(100*wrong_percentage))

max_diff = np.max(np.abs(ngrad - grad))
avg_diff = np.mean(np.abs(ngrad - grad))
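A hedged sketch of how the new `acceptable_fail_percentage` parameter is meant to be used. The function `f`, the array size, and the deliberately corrupted gradient below are made up for illustration; only `check_numerical_grads` and the new parameter come from this diff:

```python
import numpy as np
from tvm.testing import check_numerical_grads

def f(a):
    return np.sum(a * a)          # analytical gradient of f is 2 * a

a = np.random.randn(4, 4)
grad = 2 * a
grad[0, 0] += 100.0               # corrupt a single element (1/16 = 6.25% of them)

# With the default behaviour this mismatch should raise an AssertionError; allowing
# up to 10% of elements to disagree should turn it into a logged warning instead.
check_numerical_grads(f, {"a": a}, {"a": grad}, acceptable_fail_percentage=0.1)
```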
67 changes: 67 additions & 0 deletions src/te/autodiff/ad_util.cc
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file ad_util.cc
* \brief Utility for tensor-level auto-differentiation.
* The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
* in [Automatic differentiation for tensor expressions](#2498)
* and [Zero elimination](#2634)
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <string>
#include "ad_util.h"

namespace tvm {
namespace te {

std::pair<Array<IterVar>, Map<Var, PrimExpr>> CloneIterVars(const Array<IterVar>& vars) {
  Array<IterVar> new_vars;
  Map<Var, PrimExpr> vmap;
  for (const IterVar& iv : vars) {
    IterVar new_v =
        IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""),
                          iv->iter_type, iv->thread_tag);
    new_vars.push_back(new_v);
    vmap.Set(iv->var, new_v->var);
  }
  return std::make_pair(std::move(new_vars), std::move(vmap));
}

PrimExpr CloneReduction(const PrimExpr& expr) {
  if (const ReduceNode* red = expr.as<ReduceNode>()) {
    Array<IterVar> new_axis;
    Map<Var, PrimExpr> vmap;
    std::tie(new_axis, vmap) = CloneIterVars(red->axis);

    Array<PrimExpr> src_with_newaxis;
    for (const auto& src : red->source) {
      src_with_newaxis.push_back(tir::Substitute(src, vmap));
    }

    return ReduceNode::make(red->combiner, src_with_newaxis,
                            new_axis, tir::Substitute(red->condition, vmap), red->value_index);
  } else {
    return expr;
  }
}

} // namespace te
} // namespace tvm
55 changes: 55 additions & 0 deletions src/te/autodiff/ad_util.h
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file ad_util.h
* \brief Helper utilities to implement auto-differentiation.
* The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
* in [Automatic differentiation for tensor expressions](#2498)
* and [Zero elimination](#2634)
*/
#ifndef TVM_TE_AUTODIFF_AD_UTIL_H_
#define TVM_TE_AUTODIFF_AD_UTIL_H_

#include <tvm/tir/expr.h>
#include <tvm/te/operation.h>
#include <vector>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace te {

/*!
* \brief Clone iter vars and return both the new vars and the substitution from old to new.
*
* \param vars The original iter vars.
* \return A pair containing the array of new iter vars and the map from old vars to new ones.
*/
std::pair<Array<IterVar>, Map<Var, PrimExpr>> CloneIterVars(const Array<IterVar>& vars);

/*!
* \brief Clone reduction by cloning the axis variables.
* \param expr A reduction expr to clone. Non-reduction expressions are left intact.
*/
PrimExpr CloneReduction(const PrimExpr& expr);

} // namespace te
} // namespace tvm
#endif // TVM_TE_AUTODIFF_AD_UTIL_H_
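For readers who prefer not to parse the C++, here is a toy Python sketch of the idea behind CloneIterVars and CloneReduction, using made-up `Var`/`Reduce` classes rather than the real TVM nodes: fresh axis variables are created and substituted into the reduction's source and condition, so the cloned reduction shares no iteration variables with the original:

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass(frozen=True)
class Var:
    name: str

@dataclass
class Reduce:
    source: List[object]   # expressions over the reduction vars (toy: nested tuples of Vars)
    axis: List[Var]
    condition: object

def substitute(expr, vmap: Dict[Var, Var]):
    # Toy substitution over nested tuples of Vars and constants.
    if isinstance(expr, Var):
        return vmap.get(expr, expr)
    if isinstance(expr, tuple):
        return tuple(substitute(e, vmap) for e in expr)
    return expr

def clone_iter_vars(axis: List[Var]) -> Tuple[List[Var], Dict[Var, Var]]:
    # Make fresh vars and remember the old-to-new mapping.
    vmap = {v: Var(v.name + "_c") for v in axis}
    return list(vmap.values()), vmap

def clone_reduction(red: Reduce) -> Reduce:
    # Rewrite source and condition to use the fresh axis vars; everything else is kept.
    new_axis, vmap = clone_iter_vars(red.axis)
    return Reduce(source=[substitute(s, vmap) for s in red.source],
                  axis=new_axis,
                  condition=substitute(red.condition, vmap))
```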