Adding Latent SDE #104

Open
wants to merge 12 commits into
base: main

anh-tong

Hi Patrick,

Following up on our last discussion, I have created a pull request containing:

  • A small change to diffrax.misc.sde_kl_divergence, i.e., handling the context and computing the KL divergence
  • Add a new notebook of Latent SDEs as a new file examples/neural_sde_vae.ipynb
  • Rename examples/neural_sde.ipynb to examples/neural_sde_gan.ipynb (fix link in the description as well)
  • Update mkdocs.yml

anh-tong added 3 commits May 11, 2022 17:33
+ add Latent SDE (notebook, mkdocs)
+ change neural_sde.ipynd to neural_sde_gan.ipynb
+ fix doc links according to the change
Owner

patrick-kidger left a comment

Do you want to bump the version number? If we update the docs we should do a new release so people can use the updated sde_kl_divergence functionality.

inv_diffusion = jnp.linalg.pinv(diffusion)
scale = inv_diffusion @ (drift1 - drift2)
if diffusion.ndim == 1:
scale = (drift1 - drift2) / diffusion
Owner

So my original code here in sde_kl_divergence was pretty hacky and not library-ready, and I think it'll still need some more work to get ready.

In particular I think it would make most sense to operate at the level of terms. This would allow for abstracting over the kind of diffusion used -- e.g. ControlTerm versus WeaklyDiagonalControlTerm etc. -- rather than the current vector-field-based approach.

Author

I think we may need to bump the version number. Here, I have changed the sde_kl_divergence API from taking drift functions and a diffusion function to taking two MultiTerms. Although the control term is duplicated, since both SDEs share the same one, this feels more natural as we are comparing two SDEs.

@@ -23,7 +26,7 @@ class _AugDrift(eqx.Module):
     def __call__(self, t, y, args):
         y, _ = y
         context = self.context(t)
-        aug_y = jnp.concatenate([y, context], axis=-1)
+        aug_y = jnp.concatenate([y, context], axis=-1) if context is not None else y
Owner

Nit: flipping the if and else branches allows for switching if context is not None down to just if context is None.
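To make the nit concrete, here is a minimal sketch on plain arrays. The helper names `augment_before` and `augment_after` are hypothetical and not part of the PR; they only isolate the one expression under discussion.

```python
import jax.numpy as jnp


def augment_before(y, context):
    # Spelling from the diff: the negated condition comes first.
    return jnp.concatenate([y, context], axis=-1) if context is not None else y


def augment_after(y, context):
    # Flipped branches: the simpler `is None` test comes first.
    return y if context is None else jnp.concatenate([y, context], axis=-1)


y = jnp.ones((3,))
context = jnp.zeros((2,))
```

Both functions return the same result for every input; only the readability of the condition changes.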

@@ -725,7 +728,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.7"
+   "version": "3.10.4"
Owner

Can the spurious changes to this file be removed? You don't need to actually re-run a notebook when you just make changes to the documentation, as it's just a big JSON file you can edit.

"source": [
"# Neural SDE (VAE)\n",
"\n",
"This implementation is based on the Pytorch version of Latent SDE from [`torchsde`](https://github.com/google-research/torchsde/blob/master/examples/latent_sde.py) library. \n",
Owner

Nit: put this at the bottom of this introduction; the first line is the most important and this isn't the most important piece of information.

Honestly, the theory of latent SDEs is pretty nontrivial, and the remainder of this section pretty impenetrable, so I'd start off with a sentence here that just says something very simple to the effect of "this is a VAE".

If you want to give folks a readable reference for this topic then I'd recommend also adding a link to the appropriate section of On Neural Differential Equations. (I'm biased I suppose, but I definitely didn't find the original paper that clear on this front.)

"from diffrax import (MultiTerm, ODETerm, ControlTerm,\n",
" diffeqsolve, Euler,\n",
" SaveAt, VirtualBrownianTree)\n",
"from diffrax.misc import sde_kl_divergence\n",
Owner

So anything not imported as diffrax.* is considered private API. If we're going to expose this publicly then sde_kl_divergence should be offered as diffrax.sde_kl_divergence.

"from diffrax.misc import sde_kl_divergence\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
Owner

To minimise dependencies, can we do without seaborn?

" maxval=1.6,\n",
" shape=(16,),\n",
" key=key))\n",
" ys = jnp.sin(ts * 2 * 3.14)[:, None] * 0.8\n",
Owner

jnp.pi instead of 3.14.

@patrick-kidger
Owner

Looks like the formatting is failing. Have a look at CONTRIBUTING.md.

I've not really gone through most of the example yet; I'll leave a proper review of that once everything so far has been organised.

I will say that I don't think I really believe what's happening here, mathematically. In the infinite-training limit you're just matching the SDE against a single trajectory, so it collapses to an ODE (zero noise). Have a look at the Lorenz example in torchsde for a more convincing (to me) example of training a latent SDE as a generative model, rather than this case which I think is pretty much just supervised learning.

(The real giveaway here is that you're using context=None in sde_kl_divergence. Whilst you don't have to have a context for the abstract notion of "KL divergence between two SDEs", you absolutely need one to have a meaningful latent SDE.)

@anh-tong
Author

Thanks for the detailed review. I will get back on this after a few days :)

@anh-tong
Author

anh-tong commented Jun 13, 2022

Sorry for taking so long.

In the recent commits, I have changed diffrax.misc.sde_kl_divergence so that the vector field is constructed from the control term of the input SDE, which is passed as a MultiTerm. I have also added a simple unit test for this.

I've implemented the Latent SDE notebook on Lorenz data, as you suggested, to make it more like a VAE than just supervised learning. It took me some time to get it running. (It seems KL annealing is the trick to training this model.)

Your other comments are addressed in the recent changes as well.

Owner

patrick-kidger left a comment

Okay, so -- sorry for taking so long to do a review! Thanks again for this implementation, which I think is now nearly there.

As you can see I've got one big comment against the KL divergence implementation, but I've also provided a possible solution. So I think tweaking that file should be straightforward.

Apparently the VAE code is too large for me to leave a diff against it line-by-line, so some comments here instead:

In generate_lorenz:

  1. Use jnp.stack([foo, bar]) rather than jnp.concatenate([foo[None], bar[None]]).
  2. typo: normialize -> normalize
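As a quick check of the first point, the sketch below compares the two spellings on small placeholder arrays (`foo` and `bar` are the placeholder names from the comment, not variables from the notebook).

```python
import jax.numpy as jnp

foo = jnp.arange(3.0)
bar = jnp.arange(3.0) + 10.0

# Adds a leading axis to each array by hand, then joins along it.
via_concatenate = jnp.concatenate([foo[None], bar[None]])
# Creates the new leading axis directly.
via_stack = jnp.stack([foo, bar])
```

Both results have shape (2, 3) and identical contents; `jnp.stack` just states the intent more directly.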

In the modules:

  1. The super().__init__() is essentially unnecessary. (eqx.Module.__init__ does nothing). Good practice in Python is either (a) not to include super().__init__, and treat the class as final (meaning "not subclassable"), or (b) to include super().__init__ but also accept **kwargs in the __init__ and then forward them on as super().__init__(**kwargs); this is known as co-operative multiple inheritance.
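Option (b) can be sketched with plain Python classes as follows. The class names are hypothetical and deliberately unrelated to eqx.Module, whose dataclass-style construction works differently; the point is only how **kwargs travel along the method resolution order.

```python
class Base:
    def __init__(self, **kwargs):
        # Forward anything we do not consume to the next class in the MRO.
        super().__init__(**kwargs)


class Scaled(Base):
    def __init__(self, *, scale, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale


class Shifted(Base):
    def __init__(self, *, shift, **kwargs):
        super().__init__(**kwargs)
        self.shift = shift


class ScaledAndShifted(Scaled, Shifted):
    # No __init__ needed: each parent picks out its own keyword argument.
    def __call__(self, x):
        return self.scale * x + self.shift


f = ScaledAndShifted(scale=2.0, shift=1.0)
```

Because every `__init__` forwards the keywords it does not use, `Scaled` and `Shifted` can be combined without either knowing about the other.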

In the training:

  1. I think we could probably train for less time; the samples get good enough about halfway through I think.

I really like the visualisation throughout training; that looks really cool.

In passing, it's interesting to note just how small the diffusion is in the learnt model; much smaller than the dataset. This has always been a big weakness of latent SDEs. I feel like there's probably a way to tweak the loss function to try and fix that somehow. (I'm just musing about an open research question here though.)

Overall I like both the sde_kl_divergence implementation in terms of terms, and the new example showing it off.


class _AugControlTerm(ControlTerm):
Owner

I think this should inherit from AbstractTerm rather than ControlTerm. At the moment you're using both inheritance (from ControlTerm) and composition (passing in a ControlTerm instance as an argument); almost always you only ever need one of these approaches.

In this case I think composition is most natural, since the "base" ControlTerm already exists.
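A minimal sketch of the composition-only pattern, using hypothetical stand-in classes (`BaseTerm`, `AugmentedTerm`) rather than the real Diffrax terms: the wrapper stores an existing term and delegates to it, and does not subclass it.

```python
class BaseTerm:
    """Hypothetical stand-in for an existing term with a vector field."""

    def __init__(self, vector_field):
        self.vector_field = vector_field

    def vf(self, t, y, args):
        return self.vector_field(t, y, args)


class AugmentedTerm:
    """Wraps an existing term (composition) instead of subclassing it."""

    def __init__(self, term):
        self.term = term  # the already-constructed base term

    def vf(self, t, y, args):
        # The state is (y, running_value); only y is passed to the wrapped term.
        y, _ = y
        return self.term.vf(t, y, args)


base = BaseTerm(lambda t, y, args: -y)
aug = AugmentedTerm(base)
```

In the PR itself the wrapper would presumably subclass AbstractTerm so that the solver accepts it; that part is omitted here to keep the sketch free of library assumptions.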

Author

Ah you're right. I'll update it.

sde1: MultiTerm
sde2: MultiTerm
context: callable
kl: callable
Owner

I think we should require an explicit drift1, drift2, and diffusion here. (Rather than wrapping them in MultiTerms.)

I'm also thinking we can probably just remove context altogether? This isn't always used -- if you just want to compute the KL divergence between fixed SDEs -- and in the latent SDE case then the context can be handled via the args that are passed through. So better to have a simpler API I think.

Author

Yes. I just realized that we can check the class of the diffusion. I will go back to the API with drift1, drift2, and diffusion.

kl: callable

def __init__(self, sde1, sde2, context) -> None:
super().__init__(sde1.terms[0].vector_field)
Owner

Here also -- pick only one of composition or inheritance.

diffusion = self.sde1.terms[1].vf(t, y, args)
kl_divergence = jax.tree_map(self.kl, drift1, drift2, diffusion)
kl_divergence = jax.tree_util.tree_reduce(operator.add, kl_divergence)
return drift1, kl_divergence
Owner

So unfortunately this approach is going to break whenever a custom diffusion term is used.

At the moment, the current implementation assumes that the diffusion is either a ControlTerm -- which produces a diffusion matrix, and does a diffusion-control product as a matrix-vector multiply -- or a WeaklyDiagonalControlTerm -- which produces a diagonal diffusion matrix, and does a diffusion-control product as a (diagonal-matrix)-vector multiply.
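The two built-in cases can be sketched on plain arrays as below. The function names `kl_dense` and `kl_diagonal` are hypothetical, and the 0.5 * ||u||^2 integrand is the usual form from the latent SDE literature, assumed here rather than copied from the PR; only the two `scale` expressions appear in the diff under review.

```python
import jax.numpy as jnp


def kl_dense(drift1, drift2, diffusion):
    # diffusion is a (state_dim, noise_dim) matrix; solve for u in
    # diffusion @ u = drift1 - drift2 via the pseudoinverse.
    scale = jnp.linalg.pinv(diffusion) @ (drift1 - drift2)
    return 0.5 * jnp.sum(scale**2)


def kl_diagonal(drift1, drift2, diffusion):
    # diffusion holds the (state_dim,) diagonal, so inversion is elementwise.
    scale = (drift1 - drift2) / diffusion
    return 0.5 * jnp.sum(scale**2)


drift1 = jnp.array([1.0, 2.0])
drift2 = jnp.array([0.0, 0.0])
diag = jnp.array([2.0, 4.0])
```

Calling `kl_diagonal(drift1, drift2, diag)` and `kl_dense(drift1, drift2, jnp.diag(diag))` should agree up to floating-point error, which is exactly the equivalence the implementation relies on when it branches on the type of the control term.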

But the API around terms is specifically chosen so that the output of .vf(...) can really be anything (dense matrix; diagonal matrix; ... someone may also wish to write something special for tridiagonal matrices, sparse matrices etc.) so that diffusion could really be an arbitrarily-structured PyTree, for which .prod(...) is the only thing that knows how to consume it.

I'd need to think a lot harder about what the general case really is here. I'd welcome any thoughts on how that might be done, but if that seems more complicated than you really want to get into right now, a simple-but-inefficient approach is to forcibly materialise the diffusion as a matrix, ignoring any custom (diagonal/sparse/whatever) structure. This won't be efficient with user-specified control terms, but at the very least won't break.

Untested, but I think something like the following would work as an implementation.

def materialise_vf(t, y, args, contr, vf_prod):
    # Only used for its shape/dtype/structure; value is irrelevant
    control = contr(t, t)

    y_size = sum(np.size(yi) for yi in jax.tree_leaves(y))
    control_size = sum(np.size(ci) for ci in jax.tree_leaves(control))
    if y_size > control_size:
        make_jac = jax.jacfwd
    else:
        make_jac = jax.jacrev

    # Find the tree structure of vf_prod by smuggling it out as an additional
    # result from the Jacobian calculation.
    sentinel = vf_prod_tree = object()
    control_tree = jax.tree_structure(control)

    def _fn(_control):
        _out = vf_prod(t, y, args, _control)
        nonlocal vf_prod_tree
        structure = jax.tree_structure(_out)
        if vf_prod_tree is sentinel:
            vf_prod_tree = structure
        else:
            assert vf_prod_tree == structure
        return _out

    jac = make_jac(_fn)(control)
    assert vf_prod_tree is not sentinel
    if jax.tree_structure(None) in (vf_prod_tree, control_tree):
        # An unusual/not-useful edge case to handle.
        raise NotImplementedError(
            "`materialise_vf` not implemented for `None` controls or states."
        )
    return jax.tree_transpose(vf_prod_tree, control_tree, jac)

def _assert_array(x):
    if not isinstance(x, jnp.ndarray):
        raise NotImplementedError("`sde_kl_divergence` can only handle array-valued drifts and diffusions")

class _AugDrift(AbstractTerm):
    drift1: ODETerm
    drift2: ODETerm
    diffusion: AbstractTerm

    def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
        y, _ = y
        drift1 = self.drift1.vf(t, y, args)
        drift2 = self.drift2.vf(t, y, args)
        _assert_array(drift1)
        _assert_array(drift2)
        # Ugly hack special-casing built-in control terms.
        if isinstance(self.diffusion, WeaklyDiagonalControlTerm):
            diffusion = self.diffusion.vf(t, y, args)
            _assert_array(diffusion)
            kl_divergence = _kl_diagonal(drift1, drift2, diffusion)
        elif isinstance(self.diffusion, ControlTerm):
            diffusion = self.diffusion.vf(t, y, args)
            _assert_array(diffusion)
            kl_divergence = _kl(drift1, drift2, diffusion)
        else:
            # TODO: think about how to handle arbitrary control terms here, without forcibly
            # materialising the whole diffusion matrix. It'll require analysing `self.diffusion.prod` or 
            # `self.diffusion.vf_prod` and looking at its structure, I think? Or possibly extending the
            # `AbstractTerm` api to require specifying how to invert things?
            warnings.warn("`sde_kl_divergence` may be slow when used with custom diffusion terms")
            diffusion = materialise_vf(t, y, args, self.diffusion.contr, self.diffusion.vf_prod)
            _assert_array(diffusion)
            kl_divergence = _kl_general(drift1, drift2, diffusion)
        kl_divergence = jax.tree_util.tree_reduce(operator.add, kl_divergence)
        return drift1, kl_divergence

    @staticmethod
    def contr(t0: Scalar, t1: Scalar) -> Scalar:
        return t1 - t0

    @staticmethod
    def prod(vf: PyTree, control: Scalar) -> PyTree:
        return jax.tree_map(lambda v: control * v, vf)

This approach works, but falls short in the general case in two main respects. As already discussed the first is handling general AbstractTerms for the diffusion.

The second is more subtle, and is the reason for the _assert_array statements. The current approach of tree-map'ing isn't actually mathematically correct. By way of example, suppose we chose to represent the state/drift as a list-of-scalars (rather than a one-dimensional array), and the diffusion as a list-of-lists-of-scalars instead of as a matrix. Then the tree-map'ing would unpack the first list in the diffusion term, but leave the second list in place. In _kl we'd then try to compute jnp.linalg.pinv(...list of scalars...). Obviously that isn't programmatically defined, but more importantly: regardless of how we adjust our implementation of _kl we could never compute the thing desired, as we need the whole diffusion to do the inversion, and right now we only have a single column.

Once again this is something I'd need to think hard about how to handle efficiently in the general case. (As an inefficient general-case implementation you could use jax.flatten_util.ravel_pytree, though -- I'd be happy to have that in there with a warnings.warn if that branch is taken, just like the above case.)
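For reference, a small sketch of what `ravel_pytree` does to a PyTree-valued drift. The dictionary below is an invented example, not a structure used in the PR.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# State stored as a PyTree rather than as one array.
drift = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

# Returns a 1-D array of all leaves plus a function that undoes the flattening.
flat_drift, unravel = ravel_pytree(drift)
restored = unravel(flat_drift)
```

Once both drifts are flattened this way and the diffusion is materialised as a single matrix, the array-only KL computation applies, at the cost of discarding any structure in the PyTree.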

Phew! As you can tell, all of this gets nontrivial fast.

By the way: this approach of materialising the diffusion matrix is something that's come up before in other contexts. I copied the code for doing that from AdjointTerm. If you decide to include materialise_vf in your implementation then do factor it out and use it in both places.
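The core idea behind materialising the diffusion can be shown on arrays alone: because the diffusion-control product is linear in the control, its Jacobian with respect to the control is the diffusion matrix. In this sketch `vf_prod` is a hypothetical stand-in that takes only the control, with t, y and args assumed fixed.

```python
import jax
import jax.numpy as jnp


def vf_prod(control):
    # Stand-in for a term's vf_prod at fixed (t, y, args): a diagonal
    # diffusion applied to the control without ever forming a matrix.
    diagonal = jnp.array([2.0, 3.0])
    return diagonal * control


# The value of the control is irrelevant because vf_prod is linear in it.
control = jnp.zeros((2,))
matrix = jax.jacfwd(vf_prod)(control)
```

Here `matrix` is the dense 2x2 matrix with 2 and 3 on the diagonal. The function in the review comment additionally chooses between `jacfwd` and `jacrev` by size and handles PyTree-valued states and controls.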

Author

Thanks for such an insightful comment! I understand better how diffrax works now.

I like the idea of extending AbstractTerm to invert matrices (vectors). I will try to go in this direction with your suggested code here.

context: callable,
y0: PyTree,
bm: AbstractBrownianPath,
*, sde1: MultiTerm, sde2: MultiTerm, context: callable, y0: PyTree
Owner

I've mentioned it earlier, but just to reiterate since this is the public API: I'd make this API accept drift1: ODETerm, drift2: ODETerm, diffusion: AbstractTerm, since that's what we actually need.

(e.g. a MultiTerm could include an arbitrary number of terms)

@harrisonzhu508

harrisonzhu508 commented Sep 9, 2022

Hi all, thank you for the great work! I was just trying to run the code from the pull request and encountered this error

NotImplementedError                       Traceback (most recent call last)
Cell In [13], line 4
      1 while iter < train_iters:
      2     # optimizing
      3     _, training_key = jrandom.split(training_key)
----> 4     loss, grads = make_step(latent_sde)
      5     loss = loss.item()
      6     updates, opt_state = optim.update(grads, opt_state)

File ~/diffrax/.env/lib/python3.8/site-packages/equinox/jit.py:95, in _JitWrapper.__call__(_JitWrapper__self, *args, **kwargs)
     94 def __call__(__self, *args, **kwargs):
---> 95     return __self._fun_wrapper(False, args, kwargs)

File ~/diffrax/.env/lib/python3.8/site-packages/equinox/jit.py:91, in _JitWrapper._fun_wrapper(self, is_lower, args, kwargs)
     89     return self._cached.lower(dynamic, static)
     90 else:
---> 91     dynamic_out, static_out = self._cached(dynamic, static)
     92     return combine(dynamic_out, static_out.value)

    [... skipping hidden 11 frame]

File ~/diffrax/.env/lib/python3.8/site-packages/jax/experimental/host_callback.py:1806, in <lambda>(j)
   1803 id_p.def_abstract_eval(lambda *args: args)
   1804 xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args)
...
   1687           )))
   1688 else:
-> 1689   raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")

NotImplementedError: outfeed rewrite closed_call

Was not sure if there was something I missed? I didn't change anything in the code. This occurs when I call the make_step() function. The library versions are

equinox             0.7.1
jax                 0.3.17
jaxlib              0.3.15+cuda11.cudnn82
optax               0.1.3

Thanks in advance!

@patrick-kidger
Owner

patrick-kidger commented Sep 9, 2022

Hmm. This looks like a bug in core JAX -- closed_call looks like a new JAX primitive, which jax.experimental.host_callback.call hasn't been updated to be able to handle.

I'm not sure when closed_call is used, but you should be able to construct a MWE by using jax.experimental.host_callback.call inside a function called by whatever-it-is that generates closed_call. (In particular this should be doable without Diffrax.)

@harrisonzhu508

Thanks for the quick reply Patrick! I see, this makes sense. It's strange that the other examples in Diffrax do still seem to work, I'll investigate this a bit more.

@anh-tong
Author

@patrick-kidger Sorry for taking so long. I will try to get back to this pull request in a couple of days. What I can think of for now is to handle the case where the diffusion matrices are diagonal.

@harrisonzhu508 I will take a look at the bug. If other examples do not have the problem, it may be because of the current implementation of latent SDE.

@anh-tong
Author

Hi @harrisonzhu508, you can run the current code with jax=0.3.15 (while waiting for further update). The latest version of equinox=0.7.1 works fine.

As Patrick mentioned, it must be something to do with JAX core. I also find that the bug occurs when we use eqx.filter_value_and_grad (eqx.filter_jit still works). This is related to this pull request (jax-ml/jax#10711) and host_callback does not handle this yet.

@anh-tong
Author

anh-tong commented Sep 21, 2022

Hi @patrick-kidger ,

I tried to make sde_kl_divergence handle more general cases, but I do not have a complete solution.

If I understand correctly, the goal of materialise_vf is to find the PyTree structure of the output of vf_prod. However, that output should agree with the PyTree structure of the output of the drift (an ODETerm), which we always have access to. Therefore, we may not need to materialise vf_prod. We also may not need to convert everything into arrays.

If we restrict ourselves to the case where the drift and diffusion have the same PyTree structure (and the leaf nodes are jnp.ndarray), we can simply handle block-diagonal diffusions using tree_map and checking the shape of the array at each leaf node.

The current implementation can handle block-diagonal diffusion matrices with PyTrees such as

drift = {
        "block1": jnp.zeros((2,)),
        "block2": jnp.zeros((2,)),
        "block3": jnp.zeros((3,)),
    }
diffusion = {
        "block1": jnp.ones((2,)),
        "block2": jnp.ones((2, 3)),
        "block3": jnp.ones((3, 4)),
    }

The first block corresponds to WeaklyDiagonalControlTerm. The remaining ones correspond to the general ControlTerm. I have not run any experiments on this part, but I have added a unit test for it.

I also pass the context via args in vf(t, y, args), but this may break the API, since args should be a PyTree while the context is a function.

The difficulty I encounter when handling the more general case can be described in this code.

import jax.tree_util as jtu
import jax.numpy as jnp

vf_prod = {'block1': jnp.ones((2,)), 'block2': jnp.ones((1))}
diffusion = {'block1': jnp.ones((2,)), "block2": [[1., 1., 1.]]}

# vf_prod_tree obtained either from `materialise_vf` or input `drift`
vf_prod_tree = jtu.tree_structure(vf_prod) # PyTreeDef({'block1': *, 'block2': *})
diffusion_tree = jtu.tree_structure(diffusion) # PyTreeDef({'block1': *, 'block2': [[*, *, *]]})

transposed = jtu.tree_map(lambda *xs: list(xs), *[vf_prod, diffusion])
# PyTreeDef({'block1': [*, *], 'block2': [*, [[*, *, *]]]})

# next step is to convert the diffusion part of `block2` to array. But we don't know how
# maybe can use `is_leaf` in `jtu.tree_map`. But what is the condition to decide a leaf?

@harrisonzhu508

Thanks a lot!

@harrisonzhu508

Hi @anh-tong, thanks a lot for the very clean implementation again! I was trying to reproduce an example that is very similar to https://github.com/google-research/torchsde/blob/master/examples/latent_sde.py. Running the latter script yields the attached. But using your notebook implementation, I've noticed that the posterior sample paths seem to collapse to a deterministic function (even in intervals where there's no data), I was wondering if you noticed something similar too? Thanks a lot!

[Image: global_step_950]

@anh-tong
Author

anh-tong commented Sep 27, 2022

Hi, I guess this happens because the current setting kl_anneal_iters = 1000 may not be suitable for the data in the plot. Please try kl_anneal_iters = 100 instead (as in torchsde).

kl_anneal_iters helps training find a good set of parameters early on by prioritizing the likelihood over the KL divergence. Fitting the data in the figure is relatively simple, so it may not take long to reach a reasonable likelihood. The collapse is explained in this paper (see Section 5) as a consequence of the model learning only via the likelihood.
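As a rough illustration of what such an annealing setting controls, here is a sketch of a linear KL ramp. The function names and the exact schedule are assumptions for illustration; the notebook and torchsde may implement the ramp differently.

```python
def kl_coefficient(step, kl_anneal_iters, maxval=1.0):
    # Ramp linearly from maxval / kl_anneal_iters up to maxval, then hold.
    return min(maxval, maxval * (step + 1) / kl_anneal_iters)


def annealed_loss(neg_log_likelihood, kl, step, kl_anneal_iters):
    # Early steps are dominated by the likelihood term; the KL penalty
    # reaches full weight after kl_anneal_iters steps.
    return neg_log_likelihood + kl_coefficient(step, kl_anneal_iters) * kl
```

With a ramp like this, a large kl_anneal_iters keeps the KL weight small for longer, so the model spends more of training fitting the likelihood alone.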

@harrisonzhu508

That makes sense, thanks for the explanation! I haven't got it working (I'm training on samples from a stochastic process) but I'll try and play around with the KL annealing!
