Skip to content

Commit

Permalink
fixing #115
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Oct 20, 2020
1 parent 115d6ee commit 4bac68e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/exoplanet/theano_ops/starry.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def __init__(self):
self.ld = driver.LimbDark()
super().__init__()

def __getstate__(self):
return {}

def __setstate__(self, data):
self.ld = driver.LimbDark()

def make_node(self, *inputs):
in_args = [as_tensor_variable(i) for i in inputs]
if any(i.dtype != "float64" for i in in_args):
Expand Down
8 changes: 7 additions & 1 deletion tests/theano_ops/starry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import theano.tensor as tt
from theano.tests import unittest_tools as utt

from exoplanet.theano_ops.driver import SimpleLimbDark
from exoplanet.theano_ops.starry import (
GetCl,
GetClRev,
LimbDark,
RadiusFromOccArea,
)
from exoplanet.theano_ops.driver import SimpleLimbDark


class TestGetCl(utt.InferShapeTester):
Expand Down Expand Up @@ -104,6 +104,12 @@ def test_grad(self):
func = lambda *args: self.op(*args)[0] # NOQA
utt.verify_grad(func, in_args)

def test_pickle(self):
f, _, in_args = self.get_args()
data = pickle.dumps(self.op, -1)
new_op = pickle.loads(data)
utt.assert_allclose(f(*in_args), new_op(*in_args)[0].eval())


class TestRadiusFromOccArea(utt.InferShapeTester):
def setUp(self):
Expand Down

0 comments on commit 4bac68e

Please sign in to comment.