Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add color customization to plot_credible_intervals #414

Merged
merged 27 commits into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
48de030
Add color customization to plot_credible_intervals
FelipeR888 Feb 15, 2021
569b25b
Merge branch 'develop' into add_vis_colorcoding
yannikschaelte Feb 15, 2021
8cedb47
Add different variables for different colors
FelipeR888 Feb 16, 2021
9889068
Merge branch 'add_vis_colorcoding' of https://github.com/ICB-DCM/pyAB…
FelipeR888 Feb 16, 2021
aaccc7b
Merge branch 'develop' into add_vis_colorcoding
yannikschaelte Feb 16, 2021
1692e63
Add separate color variables
FelipeR888 Feb 16, 2021
8b09ffa
Merge branch 'add_vis_colorcoding' of https://github.com/ICB-DCM/pyAB…
FelipeR888 Feb 16, 2021
d3e5802
Merge branch 'develop' into add_vis_colorcoding
FelipeR888 Feb 22, 2021
68eea61
Improved readability
FelipeR888 Feb 22, 2021
907b34f
Merge branch 'add_vis_colorcoding' of https://github.com/ICB-DCM/pyAB…
FelipeR888 Feb 22, 2021
c23058f
tox flake fixes
FelipeR888 Feb 22, 2021
a46fed4
add functions to reweight with opt ESS
FelipeR888 Feb 24, 2021
6679f89
Add functions for ESS maximizing reweighting
FelipeR888 Feb 24, 2021
0ee9f77
Merge branch 'develop' into add_vis_colorcoding
yannikschaelte Mar 9, 2021
804989e
Update pyabc/sampler/redis_eps/sampler.py
FelipeR888 Mar 9, 2021
9906ddc
Update pyabc/visualization/credible.py
FelipeR888 Mar 9, 2021
1f75654
Update pyabc/visualization/credible.py
FelipeR888 Mar 9, 2021
f638509
Update pyabc/visualization/credible.py
FelipeR888 Mar 9, 2021
9a947a4
Fix rescaling & add analytical solution
FelipeR888 Mar 9, 2021
b19caeb
Updated structure and added test
FelipeR888 Mar 23, 2021
729ddb6
Add graph to test
FelipeR888 Mar 23, 2021
92818df
Fix analytical solution
FelipeR888 Mar 30, 2021
db800e4
Merge branch 'develop' into add_vis_colorcoding
yannikschaelte May 10, 2021
ef3e968
tidy up
yannikschaelte May 10, 2021
cc34e3c
tidy up
yannikschaelte May 10, 2021
d829d80
tidy up
yannikschaelte May 10, 2021
7c4a288
tidy up
yannikschaelte May 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 50 additions & 29 deletions pyabc/sampler/redis_eps/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ class RedisSamplerBase(Sampler):
logging module.
"""

def __init__(self,
host: str = "localhost",
port: int = 6379,
password: str = None,
log_file: str = None):
def __init__(
self,
host: str = "localhost",
port: int = 6379,
password: str = None,
log_file: str = None,
):
super().__init__()
logger.debug(
f"Redis sampler: host={host} port={port}")
Expand Down Expand Up @@ -175,16 +177,18 @@ class RedisEvalParallelSampler(RedisSamplerBase):
logging module.
"""

def __init__(self,
host: str = "localhost",
port: int = 6379,
password: str = None,
batch_size: int = 1,
look_ahead: bool = False,
look_ahead_delay_evaluation: bool = True,
max_n_eval_look_ahead_factor: float = 10.,
wait_for_all_samples: bool = False,
log_file: str = None):
def __init__(
self,
host: str = "localhost",
port: int = 6379,
password: str = None,
batch_size: int = 1,
look_ahead: bool = False,
look_ahead_delay_evaluation: bool = True,
max_n_eval_look_ahead_factor: float = 10.,
wait_for_all_samples: bool = False,
log_file: str = None,
):
super().__init__(
host=host, port=port, password=password, log_file=log_file)
self.batch_size: int = batch_size
Expand Down Expand Up @@ -315,8 +319,13 @@ def get_int(var: str):
return sample

def start_generation_t(
self, n: int, t: int, simulate_one: Callable, all_accepted: bool,
is_look_ahead: bool, max_n_eval_look_ahead: float = np.inf,
self,
n: int,
t: int,
simulate_one: Callable,
all_accepted: bool,
is_look_ahead: bool,
max_n_eval_look_ahead: float = np.inf,
) -> None:
"""Start generation `t`."""
ana_id = self.analysis_id
Expand Down Expand Up @@ -385,8 +394,13 @@ def clear_generation_t(self, t: int) -> None:
.execute())

def maybe_start_next_generation(
self, t: int, n: int, id_results: List, all_accepted: bool,
ana_vars: AnalysisVars) -> None:
self,
t: int,
n: int,
id_results: List,
all_accepted: bool,
ana_vars: AnalysisVars,
) -> None:
"""Start the next generation already, if that looks reasonable.

Parameters
Expand Down Expand Up @@ -445,7 +459,7 @@ def maybe_start_next_generation(

# create a preliminary simulate_one function
simulate_one_prel = create_preliminary_simulate_one(
t=t+1, population=population,
t=t + 1, population=population,
delay_evaluation=self.look_ahead_delay_evaluation,
ana_vars=ana_vars)

Expand All @@ -463,7 +477,7 @@ def maybe_start_next_generation(
# head-start the next generation
# all_accepted is most certainly False for t>0
self.start_generation_t(
n=n, t=t+1, simulate_one=simulate_one_prel,
n=n, t=t + 1, simulate_one=simulate_one_prel,
all_accepted=False, is_look_ahead=True,
max_n_eval_look_ahead=max_n_eval_look_ahead)

Expand All @@ -488,10 +502,12 @@ def create_sample(self, id_results: List[Tuple], n: int) -> Sample:
return sample

def check_analysis_variables(
self,
distance_function: Distance,
eps: Epsilon,
acceptor: Acceptor) -> None:
self,
distance_function: Distance,
eps: Epsilon,
acceptor: Acceptor,
) -> None:
""""Check analysis variables appropriateness for sampling."""
if self.look_ahead_delay_evaluation:
# nothing to be done
return
Expand All @@ -512,7 +528,7 @@ def check_bad(var):


def create_preliminary_simulate_one(
t, population, delay_evaluation: bool, ana_vars: AnalysisVars,
t, population, delay_evaluation: bool, ana_vars: AnalysisVars,
) -> Callable:
"""Create a preliminary simulate_one function for generation `t`.

Expand Down Expand Up @@ -555,8 +571,10 @@ def create_preliminary_simulate_one(
)


def post_check_acceptance(sample_with_id, ana_id, t, redis, ana_vars,
logger: RedisSamplerLogger) -> Tuple:
def post_check_acceptance(
sample_with_id, ana_id, t, redis, ana_vars,
logger: RedisSamplerLogger,
) -> Tuple:
"""Check whether the sample is really acceptable.

This is where evaluation of preliminary samples happens, using the analysis
Expand Down Expand Up @@ -667,7 +685,10 @@ def self_normalize_within_subpopulations(sample: Sample, n: int) -> Sample:


def _log_active_set(
redis: StrictRedis, ana_id: str, t: int, id_results: List[Tuple],
redis: StrictRedis,
ana_id: str,
t: int,
id_results: List[Tuple],
batch_size: int,
) -> None:
"""Log the status of active simulations after the first n acceptances."""
Expand Down
67 changes: 46 additions & 21 deletions pyabc/visualization/credible.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ def plot_credible_intervals(
ts: Union[List[int], int] = None,
par_names: List = None,
levels: List = None,
colors: List = None,
color_median: str = None,
show_mean: bool = False,
color_mean: str = None,
show_kde_max: bool = False,
color_kde_max: str = None,
show_kde_max_1d: bool = False,
color_kde_max_1d: str = None,
size: tuple = None,
refval: dict = None,
refval_color: str = 'C1',
Expand All @@ -31,38 +36,50 @@ def plot_credible_intervals(

Parameters
----------
history: History
history:
The history to extract data from.
m: int, optional (default = 0)
m:
The id of the model to plot for.
ts: Union[List[int], int], optional (default = all)
ts:
The time points to plot for.
par_names: List[str], optional
par_names:
The parameter to plot for. If None, then all parameters are used.
levels: List[float], optional (default = [0.95])
Confidence intervals to compute.
show_mean: bool, optional (default = False)
levels:
Confidence intervals to compute. Default is [0.95].
colors:
Colors to use for the errorbars.
color_median:
Color to use for the median line.
show_mean:
Whether to show the mean apart from the median as well.
show_kde_max: bool, optional (default = False)
color_mean:
Color to use for the mean.
show_kde_max:
Whether to show the one of the sampled points that gives the highest
KDE value for the specified KDE.
Note: It is not attemtped to find the overall hightest KDE value, but
rather the sampled point with the highest value is taken as an
approximation (of the MAP-value).
show_kde_max_1d: bool, optional (default = False)
Same as `show_kde_max`, but here the KDE is applied componentwise.
size: tuple of float
color_kde_max:
Color to use for KDE max value.
show_kde_max_1d:
Same as `show_kde_max`, but here the KDE is applied component-wise.
color_kde_max_1d:
Color to use for the KDE max value.
size:
Size of the plot.
refval: dict, optional (default = None)
refval:
A dictionary of reference parameter values to plot for each of
`par_names`.
refval_color: str, optional
refval_color:
Color to use for the reference value.
kde: Transition, optional (default = MultivariateNormalTransition)
kde:
The KDE to use for `show_kde_max`.
kde_1d: Transition, optional (default = MultivariateNormalTransition)
Defaults to :class:`pyabc.MultivariateNormalTransition`.
kde_1d:
The KDE to use for `show_kde_max_1d`.
arr_ax: List, optional
Defaults to :class:`pyabc.MultivariateNormalTransition`.
arr_ax:
Array of axes to use. Assumed to be a 1-dimensional list.

Returns
Expand All @@ -72,6 +89,10 @@ def plot_credible_intervals(
if levels is None:
levels = [0.95]
levels = sorted(levels)
if colors is None:
colors = [None for _ in range(len(levels))]
if color_median is None:
color_median = colors[0]
if par_names is None:
# extract all parameter names
df, _ = history.get_distribution(m=m)
Expand Down Expand Up @@ -144,22 +165,26 @@ def plot_credible_intervals(
y=median[i_par].flatten(),
yerr=[median[i_par] - cis[i_par, :, i_c],
cis[i_par, :, -1 - i_c] - median[i_par]],
color=color_median,
ecolor=colors[i_c],
capsize=(5.0 / n_confidence) * (i_c + 1),
label="{:.2f}".format(confidence))
ax.set_title(f"Parameter {par}")
# mean
if show_mean:
ax.plot(range(n_pop), mean[i_par], 'x-', label="Mean")
ax.plot(range(n_pop), mean[i_par], 'x-',
label="Mean", color=color_mean)
# kde max
if show_kde_max:
ax.plot(range(n_pop), kde_max[i_par], 'x-', label="Max KDE")
ax.plot(range(n_pop), kde_max[i_par], 'x-',
label="Max KDE", color=color_kde_max)
if show_kde_max_1d:
ax.plot(range(n_pop), kde_max_1d[i_par], 'x-',
label="Max KDE 1d")
label="Max KDE 1d", color=color_kde_max_1d)
# reference value
if refval is not None:
ax.hlines(refval[par], xmin=0, xmax=n_pop - 1, color=refval_color,
label="Reference value")
ax.hlines(refval[par], xmin=0, xmax=n_pop - 1,
color=refval_color, label="Reference value")
ax.set_xticks(range(n_pop))
ax.set_xticklabels(ts)
ax.set_ylabel(par)
Expand Down