Skip to content

Commit

Permalink
refactor: methods for SimulateEIS, updt. fitting_example
Browse files Browse the repository at this point in the history
  • Loading branch information
BradyPlanden committed Jul 11, 2024
1 parent e754d7d commit f1bd83d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 56 deletions.
33 changes: 5 additions & 28 deletions examples/scripts/eis_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# Define model
parameter_set = pybop.ParameterSet.pybamm("Chen2020")
model = pybop.lithium_ion.DFN(
model = pybop.lithium_ion.SPM(
parameter_set=parameter_set, options={"surface form": "differential"}
)

Expand Down Expand Up @@ -39,33 +39,10 @@
signal = ["Impedance"]
# Generate problem, cost function, and optimisation class
problem = pybop.EISProblem(model, parameters, dataset, signal=signal)
prediction_1 = problem.evaluate(np.array([1.0, 60e-6]))
prediction_2 = problem.evaluate(np.array([10.0, 40e-6]))
prediction_1 = problem.evaluate(np.array([0.1, 50e-6]))
prediction_2 = problem.evaluate(np.array([10, 70e-6]))

# Plot
fig = px.scatter(x=prediction_1["Impedance"].real, y=-prediction_1["Impedance"].imag)
fig.add_scatter(x=prediction_2["Impedance"].real, y=-prediction_2["Impedance"].imag)
fig.show()
# cost = pybop.SumSquaredError(problem)
# optim = pybop.CMAES(cost, max_iterations=100)

# # Run the optimisation
# x, final_cost = optim.run()
# print("True parameters:", parameters.true_value())
# print("Estimated parameters:", x)

# # Plot the time series
# pybop.plot_dataset(dataset)

# # Plot the timeseries output
# pybop.quick_plot(problem, problem_inputs=x, title="Optimised Comparison")

# # Plot convergence
# pybop.plot_convergence(optim)

# # Plot the parameter traces
# pybop.plot_parameters(optim)

# # Plot the cost landscape
# pybop.plot2d(cost, steps=15)

# # Plot the cost landscape with optimisation path
# pybop.plot2d(optim, steps=15)
48 changes: 20 additions & 28 deletions pybop/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def build(
else:
if not self.pybamm_model._built:
self.pybamm_model.build_model()
self.set_params(eis=self.eis)

self.set_params(eis=self.eis)
self._mesh = pybamm.Mesh(self.geometry, self.submesh_types, self.var_pts)
self._disc = pybamm.Discretisation(
mesh=self.mesh,
Expand Down Expand Up @@ -460,8 +460,8 @@ def simulateEIS(self, inputs: Inputs, f_eval: list) -> dict[str, np.ndarray]:
inputs : dict or array-like
The input parameters for the simulation. If array-like, it will be
converted to a dictionary using the model's fit keys.
t_eval : array-like
An array of time points at which to evaluate the solution.
f_eval : array-like
An array of frequency points at which to evaluate the solution.
Returns
-------
Expand Down Expand Up @@ -492,31 +492,25 @@ def simulateEIS(self, inputs: Inputs, f_eval: list) -> dict[str, np.ndarray]:
zs = [self.calculate_impedance(frequency) for frequency in f_eval]
return {"Impedance": np.asarray(zs) * self.z_scale}

def initialise_eis_simulation(self, inputs: Inputs = None):
# Get the mass matrix
def initialise_eis_simulation(self, inputs: Optional[Inputs] = None):
# Set mass matrix, and solver
self.M = self._built_model.mass_matrix.entries
self._solver.set_up(self._built_model, inputs=inputs)

if inputs is not None:
casadi_inputs = (
casadi.vertcat(*[x for x in inputs.values()])
if self._built_model.convert_to_format == "casadi"
else inputs
)

# Set up the solver for new inputs
self._solver.set_up(self._built_model, inputs=inputs)
# Convert inputs to casadi format if needed
casadi_inputs = (
casadi.vertcat(*inputs.values())
if inputs is not None and self._built_model.convert_to_format == "casadi"
else inputs or []
)

# Extract necessary attributes from the model
self.y0 = self._built_model.concatenated_initial_conditions.evaluate(
0, inputs=inputs
)
self.J = self._built_model.jac_rhs_algebraic_eval(
0, self.y0, casadi_inputs
).sparse()
else:
# Extract necessary attributes from the model
self.y0 = self._built_model.concatenated_initial_conditions.entries
self.J = self._built_model.jac_rhs_algebraic_eval(0, self.y0, []).sparse()
# Extract necessary attributes from the model
self.y0 = self._built_model.concatenated_initial_conditions.evaluate(
0, inputs=inputs
)
self.J = self._built_model.jac_rhs_algebraic_eval(
0, self.y0, casadi_inputs
).sparse()

# Convert to Compressed Sparse Column format
self.M = csc_matrix(self.M)
Expand Down Expand Up @@ -679,9 +673,7 @@ def predict(
parameter_values=parameter_set,
).solve(initial_soc=init_soc)
else:
raise ValueError(
"This sim method currently only supports PyBaMM models"
)
raise ValueError("This method currently only supports PyBaMM models")

else:
return [np.inf]
Expand Down

0 comments on commit f1bd83d

Please sign in to comment.