diff --git a/grainlearning/bayesian_calibration.py b/grainlearning/bayesian_calibration.py index d35e7fc..1aad184 100644 --- a/grainlearning/bayesian_calibration.py +++ b/grainlearning/bayesian_calibration.py @@ -1,12 +1,18 @@ """ This module contains the Bayesian calibration class. """ -from typing import Type, Dict import os -from numpy import argmax +from typing import Dict, Type + from grainlearning.dynamic_systems import DynamicSystem, IODynamicSystem from grainlearning.iterative_bayesian_filter import IterativeBayesianFilter -from grainlearning.tools import plot_param_stats, plot_posterior, plot_param_data, plot_obs_and_sim +from grainlearning.tools import ( + plot_obs_and_sim, + plot_param_data, + plot_param_stats, + plot_posterior, +) +from numpy import argmax class BayesianCalibration: @@ -66,30 +72,31 @@ class BayesianCalibration: :param curr_iter: Current iteration step :param save_fig: Flag for skipping (-1), showing (0), or saving (1) the figures """ - #: Dynamic system whose parameters or hidden states are being inferred - system: Type["DynamicSystem"] - - #: Calibration method (e.g, Iterative Bayesian Filter) - calibration: Type["IterativeBayesianFilter"] - - #: Number of iterations - num_iter: int - - #: Current calibration step - curr_iter: int = 0 - - #: Flag to save figures - save_fig: int = -1 def __init__( self, system: Type["DynamicSystem"], calibration: Type["IterativeBayesianFilter"], - num_iter: int, - curr_iter: int, - save_fig: int + num_iter: int = 10, + curr_iter: int = 0, + save_fig: int = -1 ): - """Initialize the Bayesian calibration class""" + """Initialize a Bayesian calibration object + + + Parameters + ---------- + system : Type["DynamicSystem"] + Dynamic system whose parameters or hidden states are being inferred + calibration : Type["IterativeBayesianFilter"] + Calibration method (e.g, Iterative Bayesian Filter) + num_iter : int + Number of iterations + curr_iter : int + Current calibration step + save_fig : int + Flag to save figures + """ self.system = system self.calibration = calibration @@ -100,6 +107,8 @@ def __init__( self.save_fig = save_fig + + def run(self): """ This is the main calibration loop which does the following steps 1. First iteration of Bayesian calibration starts with a Halton sequence @@ -290,4 +299,4 @@ def from_dict( num_iter=obj["num_iter"], curr_iter=obj.get("curr_iter", 0), save_fig=obj.get("save_fig", -1) - ) + ) \ No newline at end of file