From 0bf1f5a6067712d19563def0d87934335e169753 Mon Sep 17 00:00:00 2001 From: "Daniel M. Dunlavy" Date: Mon, 11 Jul 2022 14:17:58 -0600 Subject: [PATCH] Fix bug in tensor.mttkrp that only showed up when ndims > 3. (#36) --- pyttb/tensor.py | 2 +- tests/test_tensor.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/pyttb/tensor.py b/pyttb/tensor.py index d771fe34..4b4ce3d6 100644 --- a/pyttb/tensor.py +++ b/pyttb/tensor.py @@ -600,7 +600,7 @@ def mttkrp(self, U, n): return Y.T @ Ul else: Ul = ttb.khatrirao(U[n+1:], reverse=True) - Ur = np.reshape(ttb.khatrirao(U[0:self.ndims - 2], reverse=True), (szl, 1, R), order='F') + Ur = np.reshape(ttb.khatrirao(U[0:n], reverse=True), (szl, 1, R), order='F') Y = np.reshape(self.data, (-1, szr), order='F') Y = Y @ Ul Y = np.reshape(Y, (szl, szn, R), order='F') diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 21ed1dc0..f3fce346 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -1175,6 +1175,7 @@ def test_tensor_mttkrp(sample_tensor_2way): (params, tensorInstance) = sample_tensor_2way tensorInstance = ttb.tensor.from_function(np.ones, (2, 3, 4)) + # 2-way sparse tensor weights = np.array([2., 2.]) fm0 = np.array([[1., 3.], [2., 4.]]) fm1 = np.array([[5., 8.], [6., 9.], [7., 10.]]) @@ -1195,6 +1196,43 @@ def test_tensor_mttkrp(sample_tensor_2way): assert np.allclose(tensorInstance.mttkrp(ktensorInstance, 1), m1) assert np.allclose(tensorInstance.mttkrp(ktensorInstance, 2), m2) + # 5-way dense tensor + shape = (2,3,4,5,6) + T = ttb.tensor.from_data(np.arange(1,np.prod(shape)+1), shape) + U = []; + for s in shape: + U.append(np.ones((s,2))) + + data0 = np.array([[129600, 129600], + [129960, 129960]]) + assert (T.mttkrp(U,0) == data0).all() + + data1 = np.array([[86040, 86040], + [86520, 86520], + [87000, 87000]]) + assert (T.mttkrp(U,1) == data1).all() + + data2 = np.array([[63270, 63270], + [64350, 64350], + [65430, 65430], + [66510, 66510]]) + assert (T.mttkrp(U,2) == data2).all() + + data3 = np.array([[45000, 45000], + [48456, 48456], + [51912, 51912], + [55368, 55368], + [58824, 58824]]) + assert (T.mttkrp(U,3) == data3).all() + + data4 = np.array([[ 7260, 7260], + [21660, 21660], + [36060, 36060], + [50460, 50460], + [64860, 64860], + [79260, 79260]]) + assert (T.mttkrp(U,4) == data4).all() + # tensor too small with pytest.raises(AssertionError) as excinfo: tensorInstance2 = ttb.tensor.from_data(np.array([1]))