From 8d220a2e43f62e2cc8c1af5e1f0d7411a733ce6b Mon Sep 17 00:00:00 2001 From: Ke Han <38852697+hanke580@users.noreply.github.com> Date: Thu, 4 Jun 2020 03:28:24 +0800 Subject: [PATCH] [Numpy]Fix einsum issue #18102 (#18419) * * Fix einsum Bug * * Fix sanity * * Fix one dim start bug * * Fix test case gt --- src/operator/numpy/np_einsum_op-inl.h | 4 +++- tests/python/unittest/test_numpy_op.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/operator/numpy/np_einsum_op-inl.h b/src/operator/numpy/np_einsum_op-inl.h index ca80c7bc20be..8dd679c3fd19 100644 --- a/src/operator/numpy/np_einsum_op-inl.h +++ b/src/operator/numpy/np_einsum_op-inl.h @@ -724,7 +724,9 @@ inline void NumpyEinsumProcess(const std::vector& inputs, int j = 0; for (idim = 0; idim < ndim_iter; ++idim) { if (op_axes_arrays[i][idim] == -1 || - opshape[i][op_axes_arrays[i][idim]] == 1) { + (iop != nop && opshape[i][op_axes_arrays[i][idim]] == 1 && + op_axes_arrays[iop][idim] != -1 && + opshape[iop][op_axes_arrays[iop][idim]] != 1)) { remainstride[iop][j++] = iterstride[iop][idim]; } else { opstride[iop][op_axes_arrays[i][idim]] = iterstride[iop][idim]; diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 441c7274cefa..45b6a9c7c217 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -7833,6 +7833,26 @@ def dbg(name, data): # broadcast bug ('ij, ij -> i', [(1, 4), (2, 4)], lambda *args: (_np.sum(args[1], axis=0)[None, :], _np.tile(args[0], [2, 1]))), + # one dimension bug + ('...ij, ...jk -> ...ik', [(1, 4), (4, 2)], lambda *args: (args[1].sum(axis=1)[None, :], + _np.tile(args[0].sum(axis=0)[: ,None], [1, 2]))), + ('...ij, ...jk -> ...ik', [(2, 4), (4, 2)], lambda *args: (_np.tile(args[1].sum(axis=1)[None, :], [2, 1]), + _np.tile(args[0].sum(axis=0)[: ,None], [1, 2]))), + ('...ij, ...jk -> ...ik', [(3, 2, 1, 4), (3, 2, 4, 2)], lambda *args: ( + args[1].sum(axis=3)[:, :, 
None, :], + _np.tile(args[0].sum(axis=2)[:, :, :, None], [1, 1, 1, 2]))), + ('...ij, ...ik -> ...jk', [(1, 1, 1, 4), (1, 1, 1, 3)], lambda *args: ( + _np.tile(args[1].sum(axis=3)[:, :, :, None], [1, 1, 1, 4]), + _np.tile(args[0].sum(axis=3)[:, :, : ,None], [1, 1, 1, 3]))), + ('...ij, ...jc -> ...ic', [(1, 1, 5, 3), (1, 1, 3, 2)], lambda *args: ( + _np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]), + _np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))), + ('...ij, ...jc -> ...ic', [(1, 2, 5, 4), (1, 2, 4, 2)], lambda *args: ( + _np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]), + _np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))), + ('...ij, ...jc -> ...ic', [(2, 1, 5, 4), (2, 1, 4, 2)], lambda *args: ( + _np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]), + _np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))), # issue #16576 # commented due to long running time # ('abiz,abjz->abij', [(64, 8, 128, 512), (64, 8, 128, 512)], lambda *args: (_np.matmul(_np.ones((64, 8, 128, 128)), args[1]),