Skip to content

Commit

Permalink
TENSOR: Fix slices ref shen return value isn't scalar or vector. sand…
Browse files Browse the repository at this point in the history
  • Loading branch information
ntjohnson1 committed Feb 21, 2023
1 parent bc83b26 commit 2b259e6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
13 changes: 2 additions & 11 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,7 +1271,6 @@ def __getitem__(self, item):
kpdims = [] # dimensions to keep
rmdims = [] # dimensions to remove

# Determine the new size and what dimensions to keep
# Determine the new size and what dimensions to keep
for i in range(0, len(region)):
if isinstance(region[i], slice):
Expand All @@ -1289,19 +1288,11 @@ def __getitem__(self, item):

# If the size is zero, then the result is returned as a scalar
# otherwise, we convert the result to a tensor

if newsiz.size == 0:
a = newdata
else:
if rmdims.size == 0:
a = ttb.tensor.from_data(newdata)
else:
# If extracted data is a vector then no need to tranpose it
if len(newdata.shape) == 1:
a = ttb.tensor.from_data(newdata)
else:
a = ttb.tensor.from_data(np.transpose(newdata, np.concatenate((kpdims, rmdims))))
return ttb.tt_subsubsref(a, item)
a = ttb.tensor.from_data(newdata)
return a

# *** CASE 2a: Subscript indexing ***
if len(item) > 1 and isinstance(item[-1], str) and item[-1] == 'extract':
Expand Down
3 changes: 3 additions & 0 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ def test_tensor__getitem__(sample_tensor_2way):
assert tensorInstance[0, 0] == params['data'][0, 0]
# Case 1 Subtensor
assert (tensorInstance[:, :] == tensorInstance).data.all()
three_way_data = np.random.random((2, 3, 4))
two_slices = (slice(None,None,None), 0, slice(None,None,None))
assert (ttb.tensor.from_data(three_way_data)[two_slices].double() == three_way_data[two_slices]).all()
# Case 1 Subtensor
assert (tensorInstance[np.array([0, 1]), :].data == tensorInstance.data[[0, 1], :]).all()
# Case 1 Subtensor
Expand Down

0 comments on commit 2b259e6

Please sign in to comment.