Skip to content

Commit

Permalink
Merge branch 'main' into sptensor_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
dmdunla authored Sep 16, 2023
2 parents 4d694fe + 7d2e5a9 commit 92b4281
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
13 changes: 10 additions & 3 deletions pyttb/cp_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ def cp_als( # noqa: PLR0912,PLR0913,PLR0915
factor_matrices[1] =
[[0.1467... 0.0923...]
[0.1862... 0.3455...]]
>>> print(output)
{'params': (0.0001, 1000, 1, [0, 1]), 'iters': 1, 'normresidual': ..., 'fit': ...}
>>> print(output["params"]) # doctest: +NORMALIZE_WHITESPACE
{'stoptol': 0.0001, 'maxiters': 1000, 'dimorder': [0, 1], 'printitn': 1,\
'fixsigns': True}
Example using "nvecs" initialization:
Expand Down Expand Up @@ -265,7 +266,13 @@ def cp_als( # noqa: PLR0912,PLR0913,PLR0915
print(f" Final f = {fit:e}")

output = {
"params": (stoptol, maxiters, printitn, dimorder),
"params": {
"stoptol": stoptol,
"maxiters": maxiters,
"dimorder": dimorder,
"printitn": printitn,
"fixsigns": fixsigns,
},
"iters": iteration,
"normresidual": normresidual,
"fit": fit,
Expand Down
19 changes: 19 additions & 0 deletions tests/test_cp_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,25 @@ def test_cp_als_sptensor_zeros(capsys):
assert output3["normresidual"] == 0


def test_cp_als_tensor_pass_params(capsys, sample_tensor):
_, T = sample_tensor
KInit = ttb.ktensor.from_function(np.random.random_sample, T.shape, 2)

_, _, output = ttb.cp_als(T, 2, init=KInit, maxiters=2)
capsys.readouterr()

# passing the same parameters back to the method will yield the exact same results
_, _, output1 = ttb.cp_als(T, 2, init=KInit, **output["params"])
capsys.readouterr()

# changing the order should also work
_, _, output2 = ttb.cp_als(T, 2, **output["params"], init=KInit)
capsys.readouterr()

assert output["params"] == output1["params"]
assert output["params"] == output2["params"]


def test_cp_als_tensor_printitn(capsys, sample_tensor):
_, T = sample_tensor

Expand Down

0 comments on commit 92b4281

Please sign in to comment.