consistent default kwargs
manuelgloeckler authored and michaeldeistler committed Aug 27, 2024
1 parent cc40464 commit fc52585
Showing 4 changed files with 26 additions and 24 deletions.
27 changes: 8 additions & 19 deletions sbi/neural_nets/estimators/score_estimator.py
@@ -190,33 +190,22 @@ def loss(
# Compute MSE loss between network output and true score.
loss = torch.sum((score_pred - score_target) ** 2.0, dim=-1)

# For times -> 0 this loss has high variance, hence it can help a lot to add
# a control variate, i.e. a term that has zero expectation but is strongly
# correlated with our objective. To do so, notice that if we perform a
# 0th order Taylor expansion of the score network around the mean, i.e.
# s(input_noised) = s(mean) + O(std), we get
# E_eps[||-eps/std - s(mean)||^2]
# = E_eps[||eps||^2/std^2 + 2<s(mean), eps/std> + ||s(mean)||^2]
# = E_eps[||eps||^2]/std^2 + 2<s(mean), E_eps[eps]/std> + C
# Notice that we can calculate this expectation analytically, and it evaluates
# to D/std^2.
# This allows us to add a control variate to the loss that is zero in
# expectation, i.e. (||eps||^2/std^2 + 2<s(mean), eps>/std - D/std^2).
# For small std the Taylor expansion will be good and the control variate will
# strongly correlate with the objective, hence reducing the variance of the
# estimator.
# For large std the Taylor expansion will be bad and the control variate will
# only loosely correlate with the objective, hence it could potentially increase
# the variance of the estimator.
# Proposed in https://arxiv.org/pdf/2101.03288 (works very well for small std).
# For times -> 0 this loss has high variance. A standard method to reduce the
# variance is to use a control variate, i.e. a term that has zero expectation
# but is strongly correlated with our objective.
# Such a term can be derived by performing a 0th order Taylor expansion of the
# score network around the mean (see https://arxiv.org/pdf/2101.03288 for details).
# NOTE: As it is a Taylor expansion, it will only work well for small std.

if control_variate:
    D = input.shape[-1]
    score_mean_pred = self.forward(mean, condition, times)
    s = torch.squeeze(std, -1)

    # Loss terms that depend on eps
    term1 = 2 / s * torch.sum(eps * score_mean_pred, dim=-1)
    term2 = torch.sum(eps**2, dim=-1) / s**2
    # This term is the analytical expectation of the above term
    term3 = D / s**2

    control_variate = term3 - term1 - term2
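The derivation above can be sanity-checked numerically. Below is a minimal standalone sketch (not part of this commit) that draws Gaussian noise and verifies that term3 - term1 - term2 averages to roughly zero; s_mean stands in for the network output at the mean, and D and std are illustrative values.

# Standalone sketch (not part of the diff): Monte Carlo check that the control
# variate has zero expectation over eps ~ N(0, I_D). The values of D, std, and
# s_mean are illustrative stand-ins, not taken from the library.
import torch

torch.manual_seed(0)
D = 4                    # event dimension
std = 0.1                # small noise scale, where the control variate helps most
s_mean = torch.randn(D)  # stands in for self.forward(mean, condition, times)

eps = torch.randn(100_000, D)                    # noise samples
term1 = 2 / std * torch.sum(eps * s_mean, dim=-1)
term2 = torch.sum(eps**2, dim=-1) / std**2
term3 = D / std**2                               # analytical expectation of term2

control_variate = term3 - term1 - term2
print(control_variate.mean())  # close to 0, compared to term3 = D/std**2 = 400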
4 changes: 2 additions & 2 deletions sbi/neural_nets/score_nets.py
@@ -101,9 +101,9 @@ def build_score_estimator(
score_net: Optional[Union[str, nn.Module]] = "mlp",
z_score_x: Optional[str] = "independent",
z_score_y: Optional[str] = "independent",
t_embedding_dim: int = 32,
t_embedding_dim: int = 16,
num_layers: int = 3,
hidden_features: int = 100,
hidden_features: int = 50,
embedding_net_x: nn.Module = nn.Identity(),
embedding_net_y: nn.Module = nn.Identity(),
**kwargs,
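The default-kwarg change in this file halves the time-embedding dimension and the hidden width. A hedged usage sketch follows; the builder's leading positional arguments (batches of parameters and observations) and their shapes are assumptions, since only the keyword arguments appear in this hunk.

# Hedged sketch: only the keyword arguments and their new defaults are
# confirmed by the diff; the two leading positional batches and their shapes
# are illustrative assumptions.
import torch
from sbi.neural_nets.score_nets import build_score_estimator

batch_theta = torch.randn(100, 3)  # hypothetical batch of parameters
batch_x = torch.randn(100, 2)      # hypothetical batch of observations

# Picks up the new, smaller defaults: t_embedding_dim=16, hidden_features=50.
estimator = build_score_estimator(batch_theta, batch_x, score_net="mlp")

# The previous behaviour is still available by passing the old values explicitly.
estimator_old = build_score_estimator(
    batch_theta, batch_x, score_net="mlp", t_embedding_dim=32, hidden_features=100
)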
6 changes: 4 additions & 2 deletions sbi/samplers/score/score.py
@@ -36,8 +36,10 @@ def __init__(
score_based_potential_gradient: A time-dependent score-based potential.
predictor: A predictor to propagate samples forward in time.
corrector (optional): A corrector to refine the samples. Defaults to None.
predictor_params (optional): _description_. Defaults to None.
corrector_params (optional): _description_. Defaults to None.
predictor_params (optional): Parameters passed to the predictor, if it is
given as a string. Defaults to None.
corrector_params (optional): Parameters passed to the corrector, if it is
given as a string. Defaults to None.
"""
# Set predictor and corrector
self.set_predictor(predictor, score_based_potential, **(predictor_params or {}))
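The clarified docstring says the `*_params` dicts are only forwarded when the predictor or corrector is selected by name, and the constructor unpacks them safely with `**(predictor_params or {})` so that None also works. A small self-contained sketch of that forwarding pattern follows; the `set_predictor` signature, the `step_size` keyword, and the predictor name used here are hypothetical, chosen only to illustrate the pattern.

# Self-contained sketch of the kwargs-forwarding pattern used above; the
# function name mirrors the diff, but its signature, the "step_size" keyword,
# and the predictor name are hypothetical.
def set_predictor(predictor, potential, step_size=0.01):
    print(f"predictor={predictor!r}, step_size={step_size}")

predictor_params = None                    # nothing passed: defaults apply
set_predictor("euler_maruyama", None, **(predictor_params or {}))

predictor_params = {"step_size": 0.5}      # only used when given as a string
set_predictor("euler_maruyama", None, **(predictor_params or {}))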
13 changes: 12 additions & 1 deletion tests/score_estimator_test.py
@@ -17,12 +17,14 @@
@pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
@pytest.mark.parametrize("condition_event_shape", ((1,), (7,)))
@pytest.mark.parametrize("batch_dim", (1, 10))
@pytest.mark.parametrize("score_net", ["mlp", "ada_mlp"])
def test_score_estimator_loss_shapes(
sde_type,
input_sample_dim,
input_event_shape,
condition_event_shape,
batch_dim,
score_net,
):
"""Test whether `loss` of DensityEstimators follow the shape convention."""
score_estimator, inputs, conditions = _build_score_estimator_and_tensors(
@@ -31,6 +33,7 @@ def test_score_estimator_loss_shapes(
condition_event_shape,
batch_dim,
input_sample_dim,
score_net=score_net,
)

losses = score_estimator.loss(inputs[0], condition=conditions)
@@ -65,8 +68,14 @@ def test_score_estimator_on_device(sde_type, device):
@pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
@pytest.mark.parametrize("condition_event_shape", ((1,), (7,)))
@pytest.mark.parametrize("batch_dim", (1, 10))
@pytest.mark.parametrize("score_net", ["mlp", "ada_mlp"])
def test_score_estimator_forward_shapes(
sde_type, input_sample_dim, input_event_shape, condition_event_shape, batch_dim
sde_type,
input_sample_dim,
input_event_shape,
condition_event_shape,
batch_dim,
score_net,
):
"""Test whether `forward` of DensityEstimators follow the shape convention."""
score_estimator, inputs, conditions = _build_score_estimator_and_tensors(
@@ -75,6 +84,7 @@ def test_score_estimator_forward_shapes(
condition_event_shape,
batch_dim,
input_sample_dim,
score_net=score_net,
)
# Batched times
times = torch.rand((batch_dim,))
@@ -119,6 +129,7 @@ def _build_score_estimator_and_tensors(
sde_type=sde_type,
embedding_net_x=embedding_net_x,
embedding_net_y=embedding_net_y,
**kwargs,
)

inputs = building_thetas[:batch_dim]