Skip to content

Commit

Permalink
SPTENSOR: Infer shape from data (#243)
Browse files Browse the repository at this point in the history
Co-authored-by: Danny Dunlavy <[email protected]>
  • Loading branch information
ntjohnson1 and dmdunla authored Sep 16, 2023
1 parent 7d2e5a9 commit 8c9bdbb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
5 changes: 4 additions & 1 deletion pyttb/sptensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,14 @@ def __init__(
raise ValueError(f"Invalid shape provided: {shape}")
self.shape = tuple(shape)
return
if subs is None or vals is None or shape is None:
if subs is None or vals is None:
raise ValueError(
"For non-empty sptensors subs, vals, and shape must be provided"
)

if shape is None:
shape = tuple(np.max(subs, axis=0) + 1)

if subs.size > 0:
assert subs.shape[1] == len(shape) and np.all(
(np.max(subs, axis=0) + 1) <= shape
Expand Down
5 changes: 3 additions & 2 deletions tests/test_sptensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def test_sptensor_initialization_from_data(sample_sptensor):
assert np.array_equal(sptensorInstance.vals, data["vals"])
assert sptensorInstance.shape == data["shape"]

with pytest.raises(ValueError):
ttb.sptensor(data["subs"], data["vals"])
# Infer shape from data
another_sptensor = ttb.sptensor(data["subs"], data["vals"])
assert another_sptensor.isequal(sptensorInstance)

with pytest.raises(AssertionError):
shape = (3, 3, 1)
Expand Down

0 comments on commit 8c9bdbb

Please sign in to comment.