Skip to content

Commit

Permalink
docs: add fmpe to tutorials, fix docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 27, 2024
1 parent d18798e commit bc00bd5
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 50 deletions.
9 changes: 4 additions & 5 deletions sbi/neural_nets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def flowmatching_nn(
be used for Flow Matching. The returned function is to be passed to the
Args:
model: The type of density estimator that will be created. One of [`mdn`,
`made`, `maf`, `maf_rqs`, `nsf`].
model: the type of regression network to learn the vector field. One of ['mlp',
'resnet'].
z_score_theta: Whether to z-score parameters $\theta$ before passing them into
the network, can take one of the following:
- `none`, or None: do not z-score.
Expand All @@ -239,9 +239,8 @@ def flowmatching_nn(
density estimator is a normalizing flow (i.e. currently either a `maf` or a
`nsf`). Ignored if density estimator is a `mdn` or `made`.
num_blocks: Number of blocks if a ResNet is used.
embedding_net: Optional embedding network for x.
num_components: Number of mixture components for a mixture of Gaussians.
Ignored if density estimator is not an mdn.
num_frequencies: Number of frequencies for the time embedding.
embedding_net: Optional embedding network for the condition.
kwargs: additional custom arguments passed to downstream build functions.
"""
implemented_models = ["mlp", "resnet"]
Expand Down
Loading

0 comments on commit bc00bd5

Please sign in to comment.