Skip to content

Commit

Permalink
Default behavior of divmod() is different from MATLAB mod(), on which… (
Browse files Browse the repository at this point in the history
#237)

* Default behavior of divmod() is different from MATLAB mod(), on which the original logic was based. The prior logic threw a ZeroDivisionError if printitn == 0. Instead, this fix avoids this error by testing that printitn > 0. Tests added for various values of printitn.

* Update tests/test_cp_als.py

Co-authored-by: Nick <[email protected]>

* Update tests/test_cp_als.py

Co-authored-by: Nick <[email protected]>

* Applying Nick's suggesetions to remove mark, change maxiters to save CPU cycles. Also removing output since nothing is actually checked.

* Removes mark

* Closes #235
---------

Co-authored-by: Jeremy Myers <[email protected]>
Co-authored-by: Nick <[email protected]>
  • Loading branch information
3 people authored Sep 16, 2023
1 parent f995ca9 commit efaf16c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyttb/cp_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def cp_als( # noqa: PLR0912,PLR0913,PLR0915
else:
flag = 1

if (divmod(iteration, printitn)[1] == 0) or (printitn > 0 and flag == 0):
if (printitn > 0) and ((divmod(iteration, printitn)[1] == 0) or (flag == 0)):
print(f" Iter {iteration}: f = {fit:e} f-delta = {fitchange:7.1e}")

# Check for convergence
Expand Down
20 changes: 20 additions & 0 deletions tests/test_cp_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,23 @@ def test_cp_als_sptensor_zeros(capsys):
capsys.readouterr()
assert pytest.approx(output3["fit"], 1) == 0
assert output3["normresidual"] == 0


def test_cp_als_tensor_printitn(capsys, sample_tensor):
_, T = sample_tensor

# default printitn
ttb.cp_als(T, 2, printitn=1, maxiters=2)
capsys.readouterr()

# zero printitn
ttb.cp_als(T, 2, printitn=0, maxiters=2)
capsys.readouterr()

# negative printitn
ttb.cp_als(T, 2, printitn=-1, maxiters=2)
capsys.readouterr()

# float printitn
ttb.cp_als(T, 2, printitn=1.5, maxiters=2)
capsys.readouterr()

0 comments on commit efaf16c

Please sign in to comment.