Amendments for gradients (apache#5941)
* Amendments for gradients

- We fix the dtype handling of consts in generated gradients.
- We add a collapse_sum_to operator mirroring collapse_sum_like.
  While collapse_sum_like remains the first choice for general
  definitions (potentially dynamic shapes), switching to collapse_sum_to
  once shapes are static greatly simplifies the graph.
  (That simplification is not part of this PR; a short sketch of the
  difference follows the file summary below.)

* Fix Broadcast rel description in comment

Thank you, @MarisaKirisame
t-vi authored and zhiics committed Jul 2, 2020
1 parent a997995 commit af1f80a
Showing 8 changed files with 137 additions and 35 deletions.
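Illustration (not part of the committed diff) of the collapse_sum_like / collapse_sum_to distinction described above, assuming a TVM build that already contains this change: collapse_sum_like needs a second tensor purely as a shape template, while collapse_sum_to takes the static target shape directly, so the template tensor can drop out of the graph.

from tvm import relay

# An incoming gradient with the broadcast shape (3, 4, 5) that has to be
# summed back to the parameter shape (4, 5).
grad = relay.var("grad", shape=(3, 4, 5), dtype="float32")

# Shape-template form: works for dynamic shapes, but the template tensor
# has to stay alive in the gradient graph.
param = relay.var("param", shape=(4, 5), dtype="float32")
collapsed_like = relay.collapse_sum_like(grad, param)

# Static form added by this commit: the target shape is given directly.
collapsed_to = relay.collapse_sum_to(grad, (4, 5))

print(relay.Function([grad, param], collapsed_like))
print(relay.Function([grad], collapsed_to))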
24 changes: 16 additions & 8 deletions python/tvm/relay/op/_tensor_grad.py
@@ -69,7 +69,7 @@ def log2_grad(orig, grad):
"""Returns [grad * 1 / (log(2) * x)]"""
x = orig.args[0]
ones = ones_like(x)
-    two = const(2.0)
+    two = const(2.0, dtype=x.checked_type.dtype)
return [grad * ones / (log(two) * x)]


@@ -78,7 +78,7 @@ def log10_grad(orig, grad):
"""Returns [grad * 1 / (log(10) * x)]"""
x = orig.args[0]
ones = ones_like(x)
-    ten = const(10.0)
+    ten = const(10.0, dtype=x.checked_type.dtype)
return [grad * ones / (log(ten) * x)]


@@ -175,8 +175,9 @@ def exp_grad(orig, grad):
@register_gradient("sqrt")
def sqrt_grad(orig, grad):
"""Returns [grad * 0.5 * (x ^ -0.5)]"""
-    a = const(0.5) # (TODO) type?
-    return [grad * a * power(orig.args[0], negative(a))]
+    x = orig.args[0]
+    a = const(0.5, dtype=x.checked_type.dtype)
+    return [grad * a * power(x, negative(a))]


@register_gradient("sigmoid")
@@ -261,6 +262,13 @@ def collapse_sum_like_grad(orig, grad):
return [broadcast_to_like(grad, x), zeros_like(y)]


@register_gradient("collapse_sum_to")
def collapse_sum_to_grad(orig, grad):
"""Returns [broadcast_to_like(grad, x), 0]"""
x, y = orig.args
return [broadcast_to_like(grad, x), zeros_like(y)]


@register_gradient("abs")
def abs_grad(orig, grad):
"""Returns grad * (select(x < 0, -1, 1))."""
@@ -284,8 +292,8 @@ def clip_grad(orig, grad):
x = orig.args[0]
a_min = orig.attrs.get_int("a_min")
a_max = orig.attrs.get_int("a_max")
-    a_mins = broadcast_to_like(const(a_min), x)
-    a_maxs = broadcast_to_like(const(a_max), x)
+    a_mins = broadcast_to_like(const(a_min, dtype=x.checked_type.dtype), x)
+    a_maxs = broadcast_to_like(const(a_max, dtype=x.checked_type.dtype), x)
zeros = zeros_like(x)
ones = ones_like(x)
return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))]
@@ -591,7 +599,7 @@ def cross_entropy_grad(orig, grad):
x, y = orig.args
shape = shape_of(x)
batch_size = take(shape, const(0, dtype='int32'), axis=0)
-    grad = grad / batch_size.astype('float32')
+    grad = grad / batch_size.astype(x.checked_type.dtype)
return [-grad * y / x, -grad * log(x)]


@@ -600,5 +608,5 @@ def cross_entropy_with_logits_grad(orig, grad):
x, y = orig.args
shape = shape_of(x)
batch_size = take(shape, const(0, dtype='int32'), axis=0)
-    grad = grad / batch_size.astype('float32')
+    grad = grad / batch_size.astype(x.checked_type.dtype)
return [-grad * y, -grad * x]
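The dtype changes above matter as soon as gradients are taken of non-float32 graphs: relay.const(2.0) defaults to float32, so the generated gradient of e.g. log2 on a float64 input used to mix dtypes and fail type checking. A minimal sketch of the failure mode (illustration only, using the public relay API):

import tvm
from tvm import relay

x = relay.var("x", shape=(10, 4), dtype="float64")

bad_two = relay.const(2.0)                    # dtype defaults to float32
good_two = relay.const(2.0, dtype="float64")  # matches x, as in log2_grad above

# Same expression as the body of log2_grad.
body = relay.ones_like(x) / (relay.log(good_two) * x)
mod = tvm.IRModule.from_expr(relay.Function([x], body))
print(relay.transform.InferType()(mod))       # type checks cleanly

# Building the same body with bad_two instead makes InferType reject the
# function with an incompatible-broadcast-types error.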
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
@@ -57,6 +57,7 @@
_reg.register_injective_schedule("sequence_mask")
_reg.register_injective_schedule("one_hot")
_reg.register_reduce_schedule("collapse_sum_like")
_reg.register_reduce_schedule("collapse_sum_to")
_reg.register_injective_schedule("unravel_index")
_reg.register_injective_schedule("sparse_to_dense")

21 changes: 21 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -660,6 +660,27 @@ def collapse_sum_like(data, collapse_type):
return _make.collapse_sum_like(data, collapse_type)


def collapse_sum_to(data, shape):
"""Return a summation of data to the specified shape.
Parameters
----------
data : relay.Expr
The input tensor.
shape : relay.Expr
Shape to collapse to.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
if isinstance(shape, (list, tuple)):
shape = const(list(shape), "int32")
return _make.collapse_sum_to(data, shape)
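A short usage sketch for the wrapper above (illustration only): the shape may be a Python list or tuple, which is converted to an int32 constant as shown, or an already-built relay expression such as shape_of for shapes that are only known at runtime (type inference then yields a dynamic result shape).

from tvm import relay

x = relay.var("x", shape=(3, 4, 5), dtype="float32")

# Static target shape: the tuple becomes relay.const([4, 5], "int32").
y_static = relay.collapse_sum_to(x, (4, 5))

# Target shape taken from another tensor.
template = relay.var("t", shape=(4, 5), dtype="float32")
y_dynamic = relay.collapse_sum_to(x, relay.shape_of(template))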


def split(data, indices_or_sections, axis=0):
"""Split input tensor along axis by sections or indices.
48 changes: 48 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -1713,6 +1713,54 @@ RELAY_REGISTER_OP("collapse_sum_like")
.set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

// CollapseSumTo: <A, B> -> B where Broadcast(A, B) = A
bool CollapseSumToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
const auto* target_shape = types[1].as<TensorTypeNode>();
DataType out_dtype = types[0].as<TensorTypeNode>()->dtype;

const IntImmNode* shape_shape = target_shape->shape[0].as<IntImmNode>();
CHECK(shape_shape) << "Parameter shape must have static shape";

std::vector<IndexExpr> oshape;
if (param->shape) {
const Array<Integer>& cshape_array = param->shape.value();
for (size_t i = 0; i < cshape_array.size(); ++i) {
oshape.push_back(cshape_array[i]);
}
} else {
for (int i = 0; i < shape_shape->value; ++i) {
oshape.push_back(Any());
}
}
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return BroadcastRel({types[0], types[2], types[0]}, 2, Attrs(), reporter);
}

Expr MakeCollapseSumTo(Expr data, Expr shape) {
static const Op& op = Op::Get("collapse_sum_to");
auto attrs = make_object<InitOpAttrs>();
if (const auto* cshape = shape.as<ConstantNode>()) {
attrs->shape = ToVector(cshape->data);
}
return Call(op, {data, shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_to").set_body_typed(MakeCollapseSumTo);

RELAY_REGISTER_OP("collapse_sum_to")
.describe(R"code(Broadcast the first input to match the shape argument.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape", "Tensor", "Target shape.")
.set_support_level(4)
.add_type_rel("CollapseSumTo", CollapseSumToRel)
.set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

// BroadCastTo: <A, B> -> B where BroadCast(A, B) = B
bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
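The relation comment above reads CollapseSumTo: <A, B> -> B where Broadcast(A, B) = A, i.e. the result shape must broadcast back to the input shape. As a numeric sketch (an assumption based on the op reusing CollapseSumLikeCompute), the op sums over the broadcast axes, and the gradient registered earlier simply broadcasts the incoming gradient back:

import numpy as np

data = np.random.rand(3, 4, 5, 6).astype("float32")

# collapse_sum_to(data, (4, 5, 6)): sum over the leading broadcast axis.
collapsed = data.sum(axis=0)
assert collapsed.shape == (4, 5, 6)
assert np.broadcast_to(collapsed, data.shape).shape == data.shape

# Gradient: the upstream gradient of shape (4, 5, 6) is broadcast back to
# the input shape (3, 4, 5, 6), matching collapse_sum_to_grad above.
upstream = np.ones_like(collapsed)
grad_wrt_data = np.broadcast_to(upstream, data.shape)
assert grad_wrt_data.shape == (3, 4, 5, 6)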
13 changes: 7 additions & 6 deletions tests/python/relay/test_op_grad_level1.py
@@ -36,9 +36,8 @@ def relu(x):


def test_unary_op():
-    def check_single_op(opfunc, ref):
+    def check_single_op(opfunc, ref, dtype):
shape = (10, 4)
-        dtype = 'float32'
tp = relay.TensorType(shape, dtype)
x = relay.var("x", tp)
y = opfunc(x)
@@ -76,16 +75,17 @@ def check_single_op(opfunc, ref):
(tvm.relay.acosh, lambda x: 1./ (x**2 - 1.)**(1./2.)),
(tvm.relay.asinh, lambda x: 1./ (x**2 + 1.)**(1./2.)),
(tvm.relay.atanh, lambda x: -1./ (x**2 - 1.))]:
-        check_single_op(opfunc, ref)
+        for dtype in ('float32', 'float64'):
+            check_single_op(opfunc, ref, dtype)


def test_binary_op():
def inst(vars, sh):
return [vars.get(s, s) for s in sh]

-    def check_binary_op(opfunc, ref):
+    def check_binary_op(opfunc, ref, dtype):
s = (5, 10, 5)
-        t = relay.TensorType((5, 10, 5))
+        t = relay.TensorType((5, 10, 5), dtype=dtype)
x = relay.var("x", t)
y = relay.var("y", t)
z = opfunc(x, y)
@@ -107,7 +107,8 @@ def check_binary_op(opfunc, ref):
(relay.subtract, lambda x, y: [np.ones_like(x), -np.ones_like(y)]),
(relay.multiply, lambda x, y: [y, x]),
(relay.divide, lambda x, y: [1 / y, - x / (y**2)])]:
-        check_binary_op(opfunc, ref)
+        for dtype in ('float32', 'float64'):
+            check_binary_op(opfunc, ref, dtype)


def test_softmax_grad():
14 changes: 8 additions & 6 deletions tests/python/relay/test_op_grad_level10.py
@@ -21,15 +21,17 @@


def test_cross_entropy_grad():
x = relay.var("x", shape=(2, 5))
y = relay.var("y", shape=(2, 5))
check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1)
for dtype in ('float32', 'float64'):
x = relay.var("x", shape=(2, 5), dtype=dtype)
y = relay.var("y", shape=(2, 5), dtype=dtype)
check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1)


def test_cross_entropy_with_logits_grad():
x = relay.var("x", shape=(2, 5))
y = relay.var("y", shape=(2, 5))
check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
for dtype in ('float32', 'float64'):
x = relay.var("x", shape=(2, 5), dtype=dtype)
y = relay.var("y", shape=(2, 5), dtype=dtype)
check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)

def test_checkpoint():
inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)]
31 changes: 16 additions & 15 deletions tests/python/relay/test_op_grad_level3.py
@@ -25,21 +25,22 @@


def test_clip():
-    ref = (lambda x: np.where(x > 10.0, np.zeros_like(x),
-                              np.where(x < 1.0, np.zeros_like(x), np.ones_like(x))))
-    x = relay.var("x", relay.TensorType((10, 4), "float32"))
-    y = tvm.relay.clip(x, 1.0, 10.0)

-    data = np.random.rand(10, 4).astype("float32") * 11.0
-    ref_grad = ref(data)
-    fwd_func = relay.Function([x], y)
-    fwd_func = run_infer_type(fwd_func)
-    bwd_func = run_infer_type(gradient(fwd_func))

-    for target, ctx in ctx_list():
-        intrp = relay.create_executor(ctx=ctx, target=target)
-        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
-        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
+    for dtype in ('float32', 'float64'):
+        ref = (lambda x: np.where(x > 10.0, np.zeros_like(x),
+                                  np.where(x < 1.0, np.zeros_like(x), np.ones_like(x))))
+        x = relay.var("x", relay.TensorType((10, 4), dtype))
+        y = tvm.relay.clip(x, 1.0, 10.0)

+        data = np.random.rand(10, 4).astype(dtype) * 11.0
+        ref_grad = ref(data)
+        fwd_func = relay.Function([x], y)
+        fwd_func = run_infer_type(fwd_func)
+        bwd_func = run_infer_type(gradient(fwd_func))

+        for target, ctx in ctx_list():
+            intrp = relay.create_executor(ctx=ctx, target=target)
+            op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+            np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def verify_transpose_grad(d_shape, axes=None):
20 changes: 20 additions & 0 deletions tests/python/relay/test_op_level10.py
@@ -168,6 +168,26 @@ def test_collapse_sum_like():
op_res = intrp.evaluate(func)(x, y)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)


def test_collapse_sum_to():
shape = (3, 4, 5, 6)
shape_to = (4, 5, 6)
dtype = "float32"
x = relay.Var("x", relay.ty.TensorType(shape , dtype))
z = relay.collapse_sum_to(x, shape_to)
zz = run_infer_type(z)
assert zz.checked_type == relay.ty.TensorType(shape_to, dtype)

func = relay.Function([x], z)
x = np.random.uniform(size=shape).astype(dtype)
ref_res = np.sum(x, 0)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)


def test_broadcast_to():
shape = (4, 1, 6)
shape_like = (3, 4, 5, 6)
