Skip to content

Commit

Permalink
Fix bug in tensor.mttkrp that only showed up when ndims > 3. (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmdunla authored Jul 11, 2022
1 parent 3f597b5 commit 0bf1f5a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def mttkrp(self, U, n):
return Y.T @ Ul
else:
Ul = ttb.khatrirao(U[n+1:], reverse=True)
Ur = np.reshape(ttb.khatrirao(U[0:self.ndims - 2], reverse=True), (szl, 1, R), order='F')
Ur = np.reshape(ttb.khatrirao(U[0:n], reverse=True), (szl, 1, R), order='F')
Y = np.reshape(self.data, (-1, szr), order='F')
Y = Y @ Ul
Y = np.reshape(Y, (szl, szn, R), order='F')
Expand Down
38 changes: 38 additions & 0 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,7 @@ def test_tensor_mttkrp(sample_tensor_2way):
(params, tensorInstance) = sample_tensor_2way
tensorInstance = ttb.tensor.from_function(np.ones, (2, 3, 4))

# 2-way sparse tensor
weights = np.array([2., 2.])
fm0 = np.array([[1., 3.], [2., 4.]])
fm1 = np.array([[5., 8.], [6., 9.], [7., 10.]])
Expand All @@ -1195,6 +1196,43 @@ def test_tensor_mttkrp(sample_tensor_2way):
assert np.allclose(tensorInstance.mttkrp(ktensorInstance, 1), m1)
assert np.allclose(tensorInstance.mttkrp(ktensorInstance, 2), m2)

# 5-way dense tensor
shape = (2,3,4,5,6)
T = ttb.tensor.from_data(np.arange(1,np.prod(shape)+1), shape)
U = [];
for s in shape:
U.append(np.ones((s,2)))

data0 = np.array([[129600, 129600],
[129960, 129960]])
assert (T.mttkrp(U,0) == data0).all()

data1 = np.array([[86040, 86040],
[86520, 86520],
[87000, 87000]])
assert (T.mttkrp(U,1) == data1).all()

data2 = np.array([[63270, 63270],
[64350, 64350],
[65430, 65430],
[66510, 66510]])
assert (T.mttkrp(U,2) == data2).all()

data3 = np.array([[45000, 45000],
[48456, 48456],
[51912, 51912],
[55368, 55368],
[58824, 58824]])
assert (T.mttkrp(U,3) == data3).all()

data4 = np.array([[ 7260, 7260],
[21660, 21660],
[36060, 36060],
[50460, 50460],
[64860, 64860],
[79260, 79260]])
assert (T.mttkrp(U,4) == data4).all()

# tensor too small
with pytest.raises(AssertionError) as excinfo:
tensorInstance2 = ttb.tensor.from_data(np.array([1]))
Expand Down

0 comments on commit 0bf1f5a

Please sign in to comment.