Skip to content

Commit

Permalink
Migrating unit tests to pytest (Part 5) (pybamm-team#4333)
Browse files Browse the repository at this point in the history
* Migrating unit tests to pytest (Part 5)

Signed-off-by: Pradyot Ranjan <[email protected]>

* style: pre-commit fixes

* using is instead of ==

Signed-off-by: Pradyot Ranjan <[email protected]>

* using is instead of ==

Signed-off-by: Pradyot Ranjan <[email protected]>

* Update tests/unit/test_discretisations/test_discretisation.py

Co-authored-by: Saransh Chopra <[email protected]>

* Update tests/unit/test_discretisations/test_discretisation.py

Co-authored-by: Saransh Chopra <[email protected]>

* Update tests/unit/test_discretisations/test_discretisation.py

Co-authored-by: Saransh Chopra <[email protected]>

* Update tests/unit/test_discretisations/test_discretisation.py

Co-authored-by: Saransh Chopra <[email protected]>

* Update tests/unit/test_batch_study.py

Co-authored-by: Saransh Chopra <[email protected]>

* Update tests/unit/test_expression_tree/test_operations/test_evaluate_python.py

Co-authored-by: Saransh Chopra <[email protected]>

* Adding suggestions

Signed-off-by: Pradyot Ranjan <[email protected]>

---------

Signed-off-by: Pradyot Ranjan <[email protected]>
Co-authored-by: Pradyot Ranjan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Saransh Chopra <[email protected]>
  • Loading branch information
4 people authored Aug 12, 2024
1 parent b1fc595 commit 83115e8
Show file tree
Hide file tree
Showing 20 changed files with 1,406 additions and 1,705 deletions.
40 changes: 15 additions & 25 deletions tests/unit/test_batch_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
Tests for the batch_study.py
"""

import pytest
import os
import pybamm
import unittest
from tempfile import TemporaryDirectory


class TestBatchStudy(unittest.TestCase):
class TestBatchStudy:
def test_solve(self):
spm = pybamm.lithium_ion.SPM()
spm_uniform = pybamm.lithium_ion.SPM({"particle": "uniform profile"})
Expand Down Expand Up @@ -41,62 +41,62 @@ def test_solve(self):

# Tests for exceptions
for name in pybamm.BatchStudy.INPUT_LIST:
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
pybamm.BatchStudy(
models={"SPM": spm, "SPM uniform": spm_uniform}, **{name: {None}}
)

# Tests for None when only models are given with permutations=False
bs_false_only_models.solve(t_eval=[0, 3600])
self.assertEqual(2, len(bs_false_only_models.sims))
assert len(bs_false_only_models.sims) == 2

# Tests for None when only models are given with permutations=True
bs_true_only_models.solve(t_eval=[0, 3600])
self.assertEqual(2, len(bs_true_only_models.sims))
assert 2 == len(bs_true_only_models.sims)

# Tests for BatchStudy when permutations=False
bs_false.solve()
bs_false.plot(show_plot=False)
self.assertEqual(2, len(bs_false.sims))
assert len(bs_false.sims) == 2
for num in range(len(bs_false.sims)):
output_model = bs_false.sims[num].model.name
models_list = [model.name for model in bs_false.models.values()]
self.assertIn(output_model, models_list)
assert output_model in models_list

output_solver = bs_false.sims[num].solver.name
solvers_list = [solver.name for solver in bs_false.solvers.values()]
self.assertIn(output_solver, solvers_list)
assert output_solver in solvers_list

output_experiment = bs_false.sims[num].experiment.steps
experiments_list = [
experiment.steps for experiment in bs_false.experiments.values()
]
self.assertIn(output_experiment, experiments_list)
assert output_experiment in experiments_list

# Tests for BatchStudy when permutations=True
bs_true.solve()
bs_true.plot(show_plot=False)
self.assertEqual(4, len(bs_true.sims))
assert len(bs_true.sims) == 4
for num in range(len(bs_true.sims)):
output_model = bs_true.sims[num].model.name
models_list = [model.name for model in bs_true.models.values()]
self.assertIn(output_model, models_list)
assert output_model in models_list

output_solver = bs_true.sims[num].solver.name
solvers_list = [solver.name for solver in bs_true.solvers.values()]
self.assertIn(output_solver, solvers_list)
assert output_solver in solvers_list

output_experiment = bs_true.sims[num].experiment.steps
experiments_list = [
experiment.steps for experiment in bs_true.experiments.values()
]
self.assertIn(output_experiment, experiments_list)
assert output_experiment in experiments_list

def test_create_gif(self):
with TemporaryDirectory() as dir_name:
bs = pybamm.BatchStudy({"spm": pybamm.lithium_ion.SPM()})
with self.assertRaisesRegex(
ValueError, "The simulations have not been solved yet."
with pytest.raises(
ValueError, match="The simulations have not been solved yet."
):
pybamm.BatchStudy(
models={
Expand All @@ -117,13 +117,3 @@ def test_create_gif(self):
# create a GIF after calling the plot method
bs.plot(show_plot=False)
bs.create_gif(number_of_images=3, duration=1, output_filename=test_file)


if __name__ == "__main__":
print("Add -v for more debug output")
import sys

if "-v" in sys.argv:
debug = True
pybamm.settings.debug_mode = True
unittest.main()
57 changes: 24 additions & 33 deletions tests/unit/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#
# Tests the citations class.
#
import pytest

import pybamm
import unittest
import os
from pybamm import callbacks

Expand All @@ -18,7 +18,8 @@ def on_experiment_end(self, logs):
print(self.name, file=f)


class TestCallbacks(unittest.TestCase):
class TestCallbacks:
@pytest.fixture(autouse=True)
def tearDown(self):
# Remove any test log files that were created, even if the test fails
for logfile in ["test_callback.log", "test_callback_2.log"]:
Expand All @@ -32,22 +33,22 @@ def tearDown(self):
def test_setup_callbacks(self):
# No callbacks, LoggingCallback should be added
callbacks = pybamm.callbacks.setup_callbacks(None)
self.assertIsInstance(callbacks, pybamm.callbacks.CallbackList)
self.assertEqual(len(callbacks), 1)
self.assertIsInstance(callbacks[0], pybamm.callbacks.LoggingCallback)
assert isinstance(callbacks, pybamm.callbacks.CallbackList)
assert len(callbacks) == 1
assert isinstance(callbacks[0], pybamm.callbacks.LoggingCallback)

# Single object, transformed to list
callbacks = pybamm.callbacks.setup_callbacks(1)
self.assertIsInstance(callbacks, pybamm.callbacks.CallbackList)
self.assertEqual(len(callbacks), 2)
self.assertEqual(callbacks.callbacks[0], 1)
self.assertIsInstance(callbacks[-1], pybamm.callbacks.LoggingCallback)
assert isinstance(callbacks, pybamm.callbacks.CallbackList)
assert len(callbacks) == 2
assert callbacks.callbacks[0] == 1
assert isinstance(callbacks[-1], pybamm.callbacks.LoggingCallback)

# List
callbacks = pybamm.callbacks.setup_callbacks([1, 2, 3])
self.assertIsInstance(callbacks, pybamm.callbacks.CallbackList)
self.assertEqual(callbacks.callbacks[:3], [1, 2, 3])
self.assertIsInstance(callbacks[-1], pybamm.callbacks.LoggingCallback)
assert isinstance(callbacks, pybamm.callbacks.CallbackList)
assert callbacks.callbacks[:3] == [1, 2, 3]
assert isinstance(callbacks[-1], pybamm.callbacks.LoggingCallback)

def test_callback_list(self):
"Tests multiple callbacks in a list"
Expand All @@ -64,18 +65,18 @@ def test_callback_list(self):
)
callback.on_experiment_end(None)
with open("test_callback.log") as f:
self.assertEqual(f.read(), "first\n")
assert f.read() == "first\n"
with open("test_callback_2.log") as f:
self.assertEqual(f.read(), "second\n")
assert f.read() == "second\n"

def test_logging_callback(self):
# No argument, should use pybamm's logger
callback = pybamm.callbacks.LoggingCallback()
self.assertEqual(callback.logger, pybamm.logger)
assert callback.logger == pybamm.logger

pybamm.set_logging_level("NOTICE")
callback = pybamm.callbacks.LoggingCallback("test_callback.log")
self.assertEqual(callback.logfile, "test_callback.log")
assert callback.logfile == "test_callback.log"

logs = {
"cycle number": (5, 12),
Expand All @@ -87,39 +88,29 @@ def test_logging_callback(self):
}
callback.on_experiment_start(logs)
with open("test_callback.log") as f:
self.assertEqual(f.read(), "")
assert f.read() == ""

callback.on_cycle_start(logs)
with open("test_callback.log") as f:
self.assertIn("Cycle 5/12", f.read())
assert "Cycle 5/12" in f.read()

callback.on_step_start(logs)
with open("test_callback.log") as f:
self.assertIn("Cycle 5/12, step 1/4", f.read())
assert "Cycle 5/12, step 1/4" in f.read()

callback.on_experiment_infeasible_event(logs)
with open("test_callback.log") as f:
self.assertIn("Experiment is infeasible: 'event'", f.read())
assert "Experiment is infeasible: 'event'" in f.read()

callback.on_experiment_infeasible_time(logs)
with open("test_callback.log") as f:
self.assertIn("Experiment is infeasible: default duration", f.read())
assert "Experiment is infeasible: default duration" in f.read()

callback.on_experiment_end(logs)
with open("test_callback.log") as f:
self.assertIn("took 0.45", f.read())
assert "took 0.45" in f.read()

# Calling start again should clear the log
callback.on_experiment_start(logs)
with open("test_callback.log") as f:
self.assertEqual(f.read(), "")


if __name__ == "__main__":
print("Add -v for more debug output")
import sys

if "-v" in sys.argv:
debug = True
pybamm.settings.debug_mode = True
unittest.main()
assert f.read() == ""
Loading

0 comments on commit 83115e8

Please sign in to comment.