-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add StandaloneProblem * Make separate test_standalone * Update problem evaluate and evaluateS1 * Update CHANGELOG.md * Increase test coverage * Update checks for multiple signals
- Loading branch information
1 parent
a7515a7
commit 2e35380
Showing
7 changed files
with
204 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import numpy as np | ||
from pybop._problem import BaseProblem | ||
|
||
|
||
class StandaloneProblem(BaseProblem): | ||
""" | ||
Defines an example standalone problem without a Model. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
parameters, | ||
dataset, | ||
model=None, | ||
check_model=True, | ||
signal=None, | ||
init_soc=None, | ||
x0=None, | ||
): | ||
super().__init__(parameters, model, check_model, signal, init_soc, x0) | ||
self._dataset = dataset.data | ||
|
||
# Check that the dataset contains time and current | ||
for name in ["Time [s]"] + self.signal: | ||
if name not in self._dataset: | ||
raise ValueError(f"expected {name} in list of dataset") | ||
|
||
self._time_data = self._dataset["Time [s]"] | ||
self.n_time_data = len(self._time_data) | ||
if np.any(self._time_data < 0): | ||
raise ValueError("Times can not be negative.") | ||
if np.any(self._time_data[:-1] >= self._time_data[1:]): | ||
raise ValueError("Times must be increasing.") | ||
|
||
for signal in self.signal: | ||
if len(self._dataset[signal]) != self.n_time_data: | ||
raise ValueError( | ||
f"Time data and {signal} data must be the same length." | ||
) | ||
target = [self._dataset[signal] for signal in self.signal] | ||
self._target = np.vstack(target).T | ||
|
||
def evaluate(self, x): | ||
""" | ||
Evaluate the model with the given parameters and return the signal. | ||
Parameters | ||
---------- | ||
x : np.ndarray | ||
Parameter values to evaluate the model at. | ||
Returns | ||
------- | ||
y : np.ndarray | ||
The model output y(t) simulated with inputs x. | ||
""" | ||
|
||
return x[0] * self._time_data + x[1] | ||
|
||
def evaluateS1(self, x): | ||
""" | ||
Evaluate the model with the given parameters and return the signal and its derivatives. | ||
Parameters | ||
---------- | ||
x : np.ndarray | ||
Parameter values to evaluate the model at. | ||
Returns | ||
------- | ||
tuple | ||
A tuple containing the simulation result y(t) and the sensitivities dy/dx(t) evaluated | ||
with given inputs x. | ||
""" | ||
|
||
y = x[0] * self._time_data + x[1] | ||
|
||
dy = np.dstack([self._time_data, np.zeros(self._time_data.shape)]) | ||
|
||
return (np.asarray(y), np.asarray(dy)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import pytest | ||
import pybop | ||
import numpy as np | ||
from examples.standalone.cost import StandaloneCost | ||
from examples.standalone.problem import StandaloneProblem | ||
|
||
|
||
class TestStandalone: | ||
""" | ||
Class for testing stanadalone components. | ||
""" | ||
|
||
@pytest.mark.unit | ||
def test_standalone(self): | ||
# Build an Optimisation problem with a StandaloneCost | ||
cost = StandaloneCost() | ||
opt = pybop.Optimisation(cost=cost, optimiser=pybop.NLoptOptimize) | ||
x, final_cost = opt.run() | ||
|
||
assert len(opt.x0) == opt.n_parameters | ||
np.testing.assert_allclose(x, 0, atol=1e-2) | ||
np.testing.assert_allclose(final_cost, 42, atol=1e-2) | ||
|
||
@pytest.mark.unit | ||
def test_standalone_problem(self): | ||
# Define parameters to estimate | ||
parameters = [ | ||
pybop.Parameter( | ||
"Gradient", | ||
prior=pybop.Gaussian(4.2, 0.02), | ||
bounds=[-1, 10], | ||
), | ||
pybop.Parameter( | ||
"Intercept", | ||
prior=pybop.Gaussian(3.3, 0.02), | ||
bounds=[-1, 10], | ||
), | ||
] | ||
|
||
# Define target data | ||
t_eval = np.linspace(0, 1, 100) | ||
x0 = np.array([3, 4]) | ||
dataset = pybop.Dataset( | ||
{ | ||
"Time [s]": t_eval, | ||
"Output": x0[0] * t_eval + x0[1], | ||
} | ||
) | ||
signal = "Output" | ||
|
||
# Define a Problem without a Model | ||
problem = StandaloneProblem(parameters, dataset, signal=signal) | ||
|
||
# Test the Problem with a Cost | ||
rmse_cost = pybop.RootMeanSquaredError(problem) | ||
x = rmse_cost([1, 2]) | ||
|
||
np.testing.assert_allclose(x, 3.138, atol=1e-2) | ||
|
||
# Test the sensitivities | ||
sums_cost = pybop.SumSquaredError(problem) | ||
sums_cost.evaluateS1([1, 2]) | ||
|
||
# Test incorrect number of initial parameter values | ||
with pytest.raises(ValueError): | ||
StandaloneProblem(parameters, dataset, signal=signal, x0=np.array([])) | ||
|
||
# Test problem construction errors | ||
for bad_dataset in [ | ||
pybop.Dataset({"Time [s]": np.array([0])}), | ||
pybop.Dataset( | ||
{ | ||
"Time [s]": np.array([-1]), | ||
"Output": np.array([0]), | ||
} | ||
), | ||
pybop.Dataset( | ||
{ | ||
"Time [s]": np.array([1, 0]), | ||
"Output": np.array([0, 0]), | ||
} | ||
), | ||
pybop.Dataset( | ||
{ | ||
"Time [s]": np.array([0]), | ||
"Output": np.array([0, 0]), | ||
} | ||
), | ||
pybop.Dataset( | ||
{ | ||
"Time [s]": np.array([[0], [0]]), | ||
"Output": np.array([0, 0]), | ||
} | ||
), | ||
]: | ||
with pytest.raises(ValueError): | ||
StandaloneProblem(parameters, bad_dataset, signal=signal) |