diff --git a/include/tvm/te/autodiff.h b/include/tvm/te/autodiff.h
index a3e5d68c0f204..180ec0bf676cc 100644
--- a/include/tvm/te/autodiff.h
+++ b/include/tvm/te/autodiff.h
@@ -20,9 +20,6 @@
 /*!
  * \file tvm/te/autodiff.h
  * \brief Automatic differentiation of tensor expressions.
- * The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
- * in [Automatic differentiation for tensor expressions](#2498)
- * and [Zero elimination](#2634)
  */

 #ifndef TVM_TE_AUTODIFF_H_
@@ -71,7 +68,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input);
  * \return The tensor of shape `prefix + input.shape`
  *         representing the partial adjoint of \p input wrt one of its consumers (output)
  */
-Tensor PartialAdjoint(const Tensor& output, const Tensor& input, const Tensor& head);
+Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Tensor& head);

 /*!
  * \brief Perform reverse mode automatic differentiation.
diff --git a/python/tvm/te/autodiff.py b/python/tvm/te/autodiff.py
index e6cad14955dd9..f8650839948d1 100644
--- a/python/tvm/te/autodiff.py
+++ b/python/tvm/te/autodiff.py
@@ -15,12 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-"""
-Automatic differentiation of tensor expressions.
-The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
-in [Automatic differentiation for tensor expressions](#2498)
-and [Zero elimination](#2634)
-"""
+"""Automatic differentiation of tensor expressions."""
 from . import _ffi_api


diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index e6df7ddf1d32c..077ac35f69a08 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=no-else-raise
 """ TVM testing utilities """
 import logging
 import numpy as np
@@ -33,7 +32,7 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):


 def check_numerical_grads(function, input_values, grad_values, function_value=None,
-                          delta=1e-3, atol=1e-2, rtol=0.1, acceptable_fail_percentage=None):
+                          delta=1e-3, atol=1e-2, rtol=0.1):
     """A helper function that checks that numerical gradients of a function are
     equal to gradients computed in some different way (analytical gradients).

@@ -69,10 +68,6 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No

     rtol : float, optional
         Relative tolerance.
-
-    acceptable_fail_percentage : float, optional
-        If not None, raise an error only when the fraction of wrong elements for a gradient is
-        higher than this value.
     """
     # If input_values is a list then function accepts positional arguments
     # In this case transform it to a function taking kwargs of the form {"0": ..., "1": ...}
@@ -139,7 +134,7 @@ def compare_derivative(j, n_der, grad):

             ngrad.reshape(-1)[j] = nder

-        wrong_percentage = len(wrong_positions)/np.prod(grad.shape)
+        wrong_percentage = int(100*len(wrong_positions)/np.prod(grad.shape))

         dist = np.sqrt(np.sum((ngrad - grad)**2))
         grad_norm = np.sqrt(np.sum(ngrad**2))
@@ -154,22 +149,14 @@ def compare_derivative(j, n_der, grad):

         sqrt_n = np.sqrt(float(np.prod(grad.shape)))
         if dist > atol*sqrt_n + rtol*grad_norm:
-            enough_failures = (acceptable_fail_percentage is None or
-                               wrong_percentage > acceptable_fail_percentage)
-            if enough_failures:
-                raise AssertionError(
-                    "Analytical and numerical grads wrt '{}' differ too much\n"
-                    "analytical grad = {}\n numerical grad = {}\n"
-                    "{}% of elements differ, first 10 of wrong positions: {}\n"
-                    "distance > atol*sqrt(n) + rtol*grad_norm\n"
-                    "distance {} > {}*{} + {}*{}"
-                    .format(x_name, grad, ngrad, int(100*wrong_percentage),
-                            wrong_positions[:10], dist, atol, sqrt_n, rtol, grad_norm))
-            else:
-                logging.warning("Analytical and numerical grads wrt '%s' differ, however "
-                                "there were not enough wrong elements to raise an error "
-                                "(only %d%%)",
-                                x_name, int(100*wrong_percentage))
+            raise AssertionError(
+                "Analytical and numerical grads wrt '{}' differ too much\n"
+                "analytical grad = {}\n numerical grad = {}\n"
+                "{}% of elements differ, first 10 of wrong positions: {}\n"
+                "distance > atol*sqrt(n) + rtol*grad_norm\n"
+                "distance {} > {}*{} + {}*{}"
+                .format(x_name, grad, ngrad, wrong_percentage, wrong_positions[:10],
+                        dist, atol, sqrt_n, rtol, grad_norm))

         max_diff = np.max(np.abs(ngrad - grad))
         avg_diff = np.mean(np.abs(ngrad - grad))
diff --git a/src/te/autodiff/ad_util.cc b/src/te/autodiff/ad_util.cc
index 0f8f31e316685..3a90beff48220 100644
--- a/src/te/autodiff/ad_util.cc
+++ b/src/te/autodiff/ad_util.cc
@@ -20,9 +20,6 @@
 /*!
  * \file ad_util.cc
  * \brief Utility for tensor-level auto-differentiation.
- * The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
- * in [Automatic differentiation for tensor expressions](#2498)
- * and [Zero elimination](#2634)
  */
 #include
 #include
diff --git a/src/te/autodiff/ad_util.h b/src/te/autodiff/ad_util.h
index d791e384ebc8d..7e511b1c5a22e 100644
--- a/src/te/autodiff/ad_util.h
+++ b/src/te/autodiff/ad_util.h
@@ -20,9 +20,6 @@
 /*!
  * \file ad_util.h
  * \brief Helper utilities to implement auto-differentiation.
- * The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
- * in [Automatic differentiation for tensor expressions](#2498)
- * and [Zero elimination](#2634)
  */
 #ifndef TVM_TE_AUTODIFF_AD_UTIL_H_
 #define TVM_TE_AUTODIFF_AD_UTIL_H_
diff --git a/src/te/autodiff/adjoint.cc b/src/te/autodiff/adjoint.cc
index 40b06011ed0ca..0c54764e601ad 100644
--- a/src/te/autodiff/adjoint.cc
+++ b/src/te/autodiff/adjoint.cc
@@ -29,9 +29,6 @@
  *  (2) multiply the Jacobian (PartialAdjoint),
  *  (3) and sum them together to get the adjoint of the input itself.
  * The three steps are computed recursively.
- * The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
- * in [Automatic differentiation for tensor expressions](#2498)
- * and [Zero elimination](#2634)
  */
 #include
 #include
@@ -62,7 +59,7 @@ Tensor Identity(const Tensor& output) {
   return te::compute(shape, func, "identity");
 }

-Tensor PartialAdjoint(const Tensor& output, const Tensor& input, const Tensor& head) {
+Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Tensor& head) {
   Tensor jac = Jacobian(output, input);
   Tensor result = topi::tensordot(head, jac, /*axes=*/output->shape.size(),
                                   output->op->name + "." + input->op->name + ".grad");
@@ -118,10 +115,11 @@ Array Gradient(const Tensor& output,
     } else {
       // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied
       // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian
-      // and the multiplication is done in the function PartialAdjoint
+      // and the multiplication is done in the function VectorJacobianProduct
       for (const Tensor& direct_consumer : direct_consumers) {
         // part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor)
-        Tensor part = PartialAdjoint(direct_consumer, tensor, compute_adjoint(direct_consumer));
+        Tensor part = VectorJacobianProduct(
+            direct_consumer, tensor, compute_adjoint(direct_consumer));
         res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part;
       }
     }
diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc
index 08baa052e1370..607336f7e9286 100644
--- a/src/te/autodiff/jacobian.cc
+++ b/src/te/autodiff/jacobian.cc
@@ -22,9 +22,6 @@
  * \brief Calculate Jacobian of two tensors dY/dX.
  *        X must be direct input tensor of Y.
  *        The result Jacobian shape will be (Y.shape, X.shape)
- * The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
- * in [Automatic differentiation for tensor expressions](#2498)
- * and [Zero elimination](#2634)
  */
 #include
 #include
@@ -325,14 +322,8 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) {
   // We have to clone the iteration axes because otherwise the original expression
   // cannot be used together with the derivative (it will lead to errors during lowering)
   Array new_axis;
-  std::unordered_map vmap;
-  for (IterVar iv : op->axis) {
-    IterVar new_v =
-        IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""),
-                          iv->iter_type, iv->thread_tag);
-    new_axis.push_back(new_v);
-    vmap[iv->var.get()] = new_v;
-  }
+  Map vmap;
+  std::tie(new_axis, vmap) = te::CloneIterVars(op->axis);

   Array input_indices;
   size_t i = 0;
diff --git a/tests/python/unittest/test_te_autodiff.py b/tests/python/unittest/test_te_autodiff.py
index 0511bff52526d..c5b8d1349dd20 100644
--- a/tests/python/unittest/test_te_autodiff.py
+++ b/tests/python/unittest/test_te_autodiff.py
@@ -17,14 +17,14 @@

 import tvm
 from tvm import te
-from tvm.testing import check_numerical_grads
+from tvm.testing import check_numerical_grads, assert_allclose
 import topi
 from topi.util import get_const_tuple

 import numpy as np


-def check_grad(out, inputs, data_range=(-10, 10), acceptable_fail_percentage=None):
+def check_grad(out, inputs, data_range=(-10, 10), desired_grads=None):
     inputs = inputs if isinstance(inputs, list) else [inputs]

     def check_device(device, host="llvm"):
@@ -50,6 +50,7 @@ def check_device(device, host="llvm"):
         grads = te.gradient(out, inputs, head=ones)
         grad_sched = te.create_schedule([grad.op for grad in grads])
         mgrad = tvm.build(grad_sched, list(grads) + inputs)
+        # print(tvm.lower(grad_sched, list(grads) + inputs, simple_mode=True))

         grad_data = [tvm.nd.empty(get_const_tuple(i.shape), g.dtype)
                      for i, g in zip(inputs, grads)]
@@ -57,12 +58,16 @@ def check_device(device, host="llvm"):
         mgrad(*grad_data, *input_data)
         g_res = [g.asnumpy() for g in grad_data]

-        def forward(*in_data):
-            out_data = tvm.nd.empty(out_shape, out.dtype)
-            mout(out_data, *[tvm.nd.array(d) for d in list(in_data)])
-            return out_data.asnumpy().sum()
-        check_numerical_grads(forward, [d.asnumpy() for d in input_data], g_res,
-                              acceptable_fail_percentage=acceptable_fail_percentage)
+        if desired_grads:
+            assert isinstance(desired_grads, list)
+            for actual, desired in zip(g_res, desired_grads):
+                assert_allclose(actual, desired, rtol=0.1, atol=1e-2)
+        else:
+            def forward(*in_data):
+                out_data = tvm.nd.empty(out_shape, out.dtype)
+                mout(out_data, *[tvm.nd.array(d) for d in list(in_data)])
+                return out_data.asnumpy().sum()
+            check_numerical_grads(forward, [d.asnumpy() for d in input_data], g_res)

     check_device("cpu")

@@ -74,6 +79,7 @@ def test_basic_operation():
     l = te.reduce_axis((0, 10), name="l")
     A0 = te.placeholder(shape, name='A0')
     A1 = te.placeholder(shape, name='A1')
+    zeros = np.zeros(shape)

     B = te.compute(shape, lambda i, j: A0[i, j], name='B')
     check_grad(B, [A0])
@@ -85,16 +91,16 @@ def test_basic_operation():
     check_grad(B, A0)

     B = te.compute(shape, lambda i, j: te.floor(A0[i, j]), name='B')
-    check_grad(B, A0, acceptable_fail_percentage=0.05)
+    check_grad(B, A0, desired_grads=[zeros])

     B = te.compute(shape, lambda i, j: te.ceil(A0[i, j]), name='B')
-    check_grad(B, A0, acceptable_fail_percentage=0.05)
+    check_grad(B, A0, desired_grads=[zeros])

     B = te.compute(shape, lambda i, j: te.trunc(A0[i, j]), name='B')
-    check_grad(B, A0, acceptable_fail_percentage=0.05)
+    check_grad(B, A0, desired_grads=[zeros])

     B = te.compute(shape, lambda i, j: te.round(A0[i, j]), name='B')
-    check_grad(B, A0, acceptable_fail_percentage=0.05)
+    check_grad(B, A0, desired_grads=[zeros])

     B = te.compute(shape, lambda i, j: A0[i, j] + te.exp(A0[j, i]), name='B')
     check_grad(B, A0)