-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-978] Second order gradient support for some unary operators #14613
Changes from all commits
45e1502
904adb4
d5dc994
0e69075
0c7cf98
492e4cd
45b334e
3bbfbac
4dc0907
c4034b2
3fe54e6
76aa6ad
8458717
f66610b
30ff1e9
8ecffcc
d9ba3da
1c93c7d
de721bc
0ac0942
f8e624e
3315124
8538980
1ee38b5
c18f317
689cfee
d56e132
2207815
0b6c2ef
31f671f
62fcca3
a0a0e75
451c4bd
b9b0c93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,13 +15,55 @@ | |
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import math | ||
|
||
import math | ||
from mxnet import nd, autograd | ||
from mxnet.test_utils import assert_almost_equal, random_arrays | ||
from mxnet.test_utils import assert_almost_equal, random_arrays, rand_shape_nd | ||
from common import with_seed | ||
|
||
|
||
@with_seed() | ||
def test_sin(): | ||
def sin(x): | ||
return nd.sin(x) | ||
|
||
def grad_grad_op(x): | ||
return -nd.sin(x) | ||
|
||
for dim in range(1, 5): | ||
shape = rand_shape_nd(dim) | ||
array = random_arrays(shape) | ||
check_second_order_unary(array, sin, grad_grad_op) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's only used in this test. If we add a different test file then it makes sense as you suggested. |
||
|
||
|
||
@with_seed() | ||
def test_cos(): | ||
def cos(x): | ||
return nd.cos(x) | ||
|
||
def grad_grad_op(x): | ||
return -nd.cos(x) | ||
|
||
for dim in range(1, 5): | ||
shape = rand_shape_nd(dim) | ||
array = random_arrays(shape) | ||
check_second_order_unary(array, cos, grad_grad_op) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can these There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This PR is only to verify second order gradient. Can we add test for Nth order gradient in a separate PR? |
||
|
||
|
||
@with_seed() | ||
def test_relu(): | ||
def relu(x): | ||
return nd.relu(x) | ||
|
||
def grad_grad_op(x): | ||
return nd.zeros_like(x) | ||
|
||
for dim in range(1, 5): | ||
shape = rand_shape_nd(dim) | ||
array = random_arrays(shape) | ||
check_second_order_unary(array, relu, grad_grad_op) | ||
|
||
|
||
@with_seed() | ||
def test_log(): | ||
def log(x): | ||
|
@@ -30,9 +72,9 @@ def log(x): | |
def grad_grad_op(x): | ||
return -1/(x**2) | ||
|
||
arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5)) | ||
|
||
for array in arrays: | ||
for dim in range(1, 5): | ||
shape = rand_shape_nd(dim) | ||
array = random_arrays(shape) | ||
check_second_order_unary(array, log, grad_grad_op) | ||
|
||
|
||
|
@@ -44,9 +86,9 @@ def log2(x): | |
def grad_grad_op(x): | ||
return -1/((x**2) * math.log(2)) | ||
|
||
arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5)) | ||
|
||
for array in arrays: | ||
for dim in range(1, 5): | ||
shape = rand_shape_nd(dim) | ||
array = random_arrays(shape) | ||
check_second_order_unary(array, log2, grad_grad_op) | ||
|
||
|
||
|
@@ -58,9 +100,9 @@ def log10(x): | |
def grad_grad_op(x): | ||
return -1/((x**2) * math.log(10)) | ||
|
||
arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5)) | ||
|
||
for array in arrays: | ||
for dim in range(1, 5): | ||
shape = rand_shape_nd(dim) | ||
array = random_arrays(shape) | ||
check_second_order_unary(array, log10, grad_grad_op) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a hidden operator so user do not see this.