
xarray serialization error when using observed custom distribution #256

Open
dgegen opened this issue Apr 12, 2022 · 11 comments
Labels
bug Something isn't working


@dgegen

dgegen commented Apr 12, 2022

Describe the bug
I am unable to save traces to a file, which is essential when running code on a cluster.

To Reproduce
Follow the case study Fitting TESS data (https://gallery.exoplanet.codes/tutorials/tess/), except use
lc = lc_file.remove_nans().remove_outliers().normalize() instead of
lc = lc_file.remove_nans().normalize().remove_outliers(), because the tutorial's order of transformations raised an unrelated error in my case.
After sampling, try to save the trace the way arviz.InferenceData objects are usually saved: trace.to_netcdf('results').
This raises the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-23-c0e5828e59ee> in <module>
----> 1 trace.to_netcdf('results')

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
    390                 if compress:
    391                     kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
--> 392                 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
    393                 data.close()
    394                 mode = "a"

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
   1900         from ..backends.api import to_netcdf
   1901
-> 1902         return to_netcdf(
   1903             self,
   1904             path,

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
   1070         # TODO: allow this work (setting up the file for writing array data)
   1071         # to be parallelized with dask
-> 1072         dump_to_store(
   1073             dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
   1074         )

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
   1117         variables, attrs = encoder(variables, attrs)
   1118
-> 1119     store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
   1120
   1121

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
    259             writer = ArrayWriter()
    260
--> 261         variables, attributes = self.encode(variables, attributes)
    262
    263         self.set_attributes(attributes)

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
    348         # All NetCDF files get CF encoded by default, without this attempting
    349         # to write times, for example, would fail.
--> 350         variables, attributes = cf_encoder(variables, attributes)
    351         variables = {k: self.encode_variable(v) for k, v in variables.items()}
    352         attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
    853     _update_bounds_encoding(variables)
    854
--> 855     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    856
    857     # Remove attrs from bounds variables (issue #2921)

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/conventions.py in <dictcomp>(.0)
    853     _update_bounds_encoding(variables)
    854
--> 855     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    856
    857     # Remove attrs from bounds variables (issue #2921)

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/conventions.py in encode_cf_variable(var, needs_copy, name)
    273     var = maybe_default_fill_value(var)
    274     var = maybe_encode_bools(var)
--> 275     var = ensure_dtype_not_object(var, name=name)
    276
    277     for attr_name in CF_RELATED_DATA:

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/conventions.py in ensure_dtype_not_object(var, name)
    231             data[missing] = fill_value
    232         else:
--> 233             data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))
    234
    235         assert data.dtype.kind != "O" or data.dtype.metadata

/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/conventions.py in _infer_dtype(array, name)
    165         return dtype
    166
--> 167     raise ValueError(
    168         "unable to infer dtype on variable {!r}; xarray "
    169         "cannot serialize arbitrary Python objects".format(name)

ValueError: unable to infer dtype on variable 'ecc_prior'; xarray cannot serialize arbitrary Python objects

Expected behavior
I expect the trace to be saved as usual. I can save the following groups individually:

trace.posterior.to_netcdf('posterior')
trace.log_likelihood.to_netcdf('log_likelihood')
trace.sample_stats.to_netcdf('sample_stats')

However, the same error is raised when I try to save
trace.observed_data.to_netcdf('observed_data')

My setup

  • Version of exoplanet: 0.5.2
  • Operating system: I reproduced this error once on Linux and once on macOS 10.15
  • Python version: Python 3.9.9
  • Installation method: pip install -U "exoplanet[extras]"
  • Version of arviz: 0.11.4
  • Version of pymc3: 3.11.4

Has anyone encountered this problem before, knows how to solve it, or has a suggestion for a workaround?
Thank you very much in advance.

dgegen added the bug label on Apr 12, 2022
@dfm
Member

dfm commented Apr 12, 2022

I have seen this before, and I don't totally understand why this happens. In the short term, I'd recommend using the groups argument to to_netcdf (care of @avivajpeyi):

trace.to_netcdf(
    filename=...,
    groups=["posterior", "log_likelihood", "sample_stats"],
)
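Rather than hard-coding the group names, the list passed to groups could be built by skipping any group that contains object-dtype variables, which is what xarray refuses to write. This is a minimal sketch, assuming only that the trace exposes arviz's InferenceData.groups() method and that each group is an xarray Dataset whose variables carry a NumPy dtype; the helper name serializable_groups is hypothetical.

```python
def serializable_groups(idata):
    """Return the names of groups in `idata` that hold no object-dtype variables.

    Assumes the arviz InferenceData interface: idata.groups() lists the group
    names and getattr(idata, name) returns an xarray Dataset.
    """
    keep = []
    for name in idata.groups():
        dataset = getattr(idata, name)
        # dtype.kind == "O" marks arrays of arbitrary Python objects, which
        # xarray's netCDF encoder cannot serialize. Checking .variables
        # (not just data_vars) also catches object-dtype coordinates.
        has_objects = any(var.dtype.kind == "O" for var in dataset.variables.values())
        if not has_objects:
            keep.append(name)
    return keep
```

Usage would then be trace.to_netcdf("results", groups=serializable_groups(trace)). Note that this drops a whole group if any one of its variables is unserializable, so numeric variables sharing a group with the offending one are lost as well.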

Here's a simpler snippet that fails with the same issue:

import pymc3 as pm
import exoplanet as xo

with pm.Model() as model:
    ecc = pm.Uniform("ecc")
    xo.eccentricity.kipping13("ecc_prior", fixed=True, observed=ecc)
    trace = pm.sample(return_inferencedata=True)
    trace.to_netcdf("test")

This could be used to debug and find a longer term solution.

dfm changed the title from "Saving the trace" to "xarray serialization error when using observed custom distribution" on Apr 12, 2022
@dgegen
Author

dgegen commented Apr 12, 2022

@dfm Thank you very much for the quick reply, the boiled-down version and the workaround. I would still be curious to know whether trace.observed_data can also be saved, or how the custom exoplanet distributions would need to be changed.
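One partial option for trace.observed_data is to drop only the object-dtype variables from the group and write whatever remains to its own file. This is a sketch under the assumption that the group is an xarray Dataset (Dataset.variables and Dataset.drop_vars are xarray API); the helper name drop_object_vars is hypothetical. The dropped variable (ecc_prior in the example above) is simply lost, so this does not address how the distributions themselves would need to change.

```python
def drop_object_vars(dataset):
    """Return (cleaned_dataset, dropped_names) for an xarray Dataset.

    Every variable whose dtype kind is "O" (arbitrary Python objects, which
    xarray cannot write to netCDF) is removed; the names are returned so the
    caller can see exactly what was left out.
    """
    dropped = [
        name
        for name, var in dataset.variables.items()
        if var.dtype.kind == "O"
    ]
    return dataset.drop_vars(dropped), dropped
```

For example, clean, dropped = drop_object_vars(trace.observed_data) followed by clean.to_netcdf("observed_data") would save only the serializable part of the group.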

@dfm
Member

dfm commented Apr 12, 2022

The issue is somehow related to the fact that observed_data is required to be fixed (i.e., the same for every sample), but here we're overloading the observed argument to be a tensor that depends on the parameters. We could hack it to use pm.Potential or something, but then it would no longer work as a prior:

ecc = xo.eccentricity.kipping13("ecc_prior", fixed=True)

Perhaps there's some other PyMC3 trickery that we could use, but I don't know what it would look like!

@dgegen
Author

dgegen commented Apr 13, 2022

I see. I played around some more and found that the xarray serialization problem is not limited to exoplanet's custom distributions; it also occurs with celerite2.theano.GaussianProcess objects under certain circumstances:

import numpy as np
import exoplanet as xo
import pymc3 as pm
import aesara_theano_fallback.tensor as tt
from celerite2.theano import terms, GaussianProcess

x = np.linspace(-1, 1, 1000)
texp = x[1]-x[0]
true_orbit = xo.orbits.KeplerianOrbit(period=4, t0=0)
y = xo.LimbDarkLightCurve([0, 0]).get_light_curve(orbit=true_orbit, t=x, r= 0.1, texp=texp).eval()
y = y.reshape(-1) + np.random.normal(loc=0.0, scale=0.001, size=len(x))

with pm.Model() as model:
    t0 = pm.Normal("t0", mu=0.1, sd=0.5)
    orbit = xo.orbits.KeplerianOrbit(period=3.99, t0=t0)

    light_curves = xo.LimbDarkLightCurve([0, 0]).get_light_curve(orbit=orbit, r=0.1, t=x, texp=texp)
    resid = y - tt.sum(light_curves, axis=-1) 

    kernel = terms.SHOTerm(sigma=np.std(y), rho=np.std(y), Q=1/np.sqrt(2))
    gp = GaussianProcess(kernel, t=x, yerr=np.std(y))
    gp.marginal("gp", observed=resid)  # no error if resid is replaced by y

    # Sample
    trace = pm.sample(return_inferencedata=True, tune=100, draws=200)
    trace.to_netcdf("test")

raises the error

ValueError: unable to infer dtype on variable 'gp'; xarray cannot serialize arbitrary Python objects

However, there is no error if gp.marginal("gp", observed=resid) is replaced with gp.marginal("gp", observed=y), i.e. if we marginalise on a quantity that is independent of all random variables.

@dfm
Member

dfm commented Apr 13, 2022

Yes. This is exactly the same issue. The "observed data" isn't a constant if it depends on the model!

@dgegen
Author

dgegen commented Apr 13, 2022

You are right, of course; I just wanted to point out that the problem is not limited to exoplanet's custom distributions.

@dfm
Member

dfm commented Apr 13, 2022

Haha, good point. The celerite2 design choices (mistakes?) are the same, and also my own :D

@dgegen
Author

dgegen commented Apr 13, 2022

Somewhat off-topic:
Well, thank you very much for your great work. For tackling scientific problems, I always find it very pleasant to use high-level languages like Python while still getting the performance of lower-level languages like C++. Accordingly, I am grateful that packages like celerite2 exist. I have thought a few times that, in the long run, it might be worthwhile to consider translating exoplanet to Julia. This could avoid the current dependence on several other languages and might also improve performance. Needless to say, such an undertaking would be very labour-intensive and would take quite some time to realise.

On-topic:
I have now asked on the PyMC forum how this issue is usually handled and will let you know if anyone comes up with a neat solution.

@dfm
Member

dfm commented Apr 13, 2022

Thanks!

If you're interested in a Julia version, there has been some work on implementing at least some parts over in: https://github.com/JuliaAstro/Transits.jl
And we have this old implementation of celerite in Julia (I'm not sure how compatible it'll be with recent versions of Julia): https://github.com/ericagol/celerite.jl

@dgegen
Author

dgegen commented Apr 13, 2022

Thanks for pointing out these projects, I will have a look at them!

@ericagol
Collaborator

> If you're interested in a Julia version, there has been some work on implementing at least some parts over in: https://github.com/JuliaAstro/Transits.jl

Also https://github.com/rodluger/Limbdark.jl
