-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Add Large Tensor Test for linalg_syrk #18782
Changes from 3 commits
6c1a904
0cefe25
ae4318a
8d9b188
b49f999
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,7 @@ | |
LARGE_SIZE = LARGE_X * SMALL_Y | ||
LARGE_TENSOR_SHAPE = 2**32 | ||
RNN_LARGE_TENSOR = 2**28 | ||
LARGE_SQ_X = 70000 | ||
|
||
|
||
def test_nn(): | ||
|
@@ -1791,6 +1792,29 @@ def test_sparse_dot(): | |
assert out.shape == (2, 2) | ||
|
||
|
||
def test_linalg_operators():
    # Large-tensor checks for linalg operators; each nested helper covers one op.
    def check_syrk_batch():
        # Test both forward and backward of batch syrk.
        # Batch syrk is applied to the last two dimensions, so A holds two
        # LARGE_SQ_X x LARGE_SQ_X matrices: batch 0 is the identity and
        # batch 1 is 0.1 * identity.
        A = nd.zeros((2, LARGE_SQ_X, LARGE_SQ_X))
        for i in range(LARGE_SQ_X):
            A[0, i, i] = 1
            A[1, i, i] = 0.1
        A.attach_grad()
        with mx.autograd.record():
            out = nd.linalg.syrk(A, alpha=2, transpose=False)

        # Every element read blocks on the result, so probing all LARGE_SQ_X
        # positions is needlessly slow. The matrices are uniform along the
        # diagonal, so the first and the last index are enough; the last one
        # also exercises the largest element offsets of each matrix.
        probes = (0, LARGE_SQ_X - 1)

        # Forward: out = alpha * A * A^T, i.e. 2 * 1 * 1 = 2 on the diagonal of
        # batch 0 and 2 * 0.1 * 0.1 = 0.02 on the diagonal of batch 1.
        for i in probes:
            assert out[0, i, i] == 2
            assert_almost_equal(out[1, i, i], nd.array([0.02]), rtol=1e-3, atol=1e-5)

        out.backward()

        # Backward, checked on the first row. With an all-ones head gradient
        # the gradient of alpha * A * A^T w.r.t. A is 2 * alpha * J * A
        # (J = all-ones matrix), so every entry equals 2 * 2 * c for A = c * I:
        # 4 for batch 0 and 0.4 (not 0.04) for batch 1.
        # NOTE(review): reviewers asked to confirm these values against a
        # small (2x2 / 3x3) input run as well.
        for i in probes:
            assert A.grad[0, 0, i] == 4
            assert_almost_equal(A.grad[1, 0, i], nd.array([0.4]), rtol=1e-3, atol=1e-5)

    check_syrk_batch()
|
||
|
||
if __name__ == "__main__":
    # Executed as a script: let nose discover and run the tests in this module.
    import nose
    nose.runmodule()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can check just two places, (0, y, y) and (1, y, y). There is no need to check all 70000 locations.