diff --git a/frank/io.py b/frank/io.py index ab65b837..320d7572 100644 --- a/frank/io.py +++ b/frank/io.py @@ -131,7 +131,8 @@ def load_sol(sol_file): (see frank.radial_fitters.FrankFitter) """ - sol = np.load(sol_file, allow_pickle=True) + with open(sol_file, 'rb') as f: + sol = pickle.load(f) return sol diff --git a/frank/statistical_models.py b/frank/statistical_models.py index e2b98628..981625c7 100644 --- a/frank/statistical_models.py +++ b/frank/statistical_models.py @@ -102,8 +102,8 @@ def __init__(self, DHT, geometry, if scale_height is None: raise ValueError('You requested a model with a non-zero scale height' ' but did not specify H(R) (scale_height=None)') - self._scale_height = scale_height - self._H2 = 0.5*(2*np.pi*scale_height(self.r) / rad_to_arcsec)**2 + self._scale_height = scale_height(self.r) + self._H2 = 0.5*(2*np.pi*self._scale_height / rad_to_arcsec)**2 if self._verbose: logging.info(' Assuming an optically thin model but geometrically: ' @@ -278,7 +278,7 @@ def check_hash(self, hash, multi_freq=False, geometry=None): if hash[4] is None: return False else: - return np.alltrue(self._scale_height(self.r) == hash[4](self.r)) + return np.alltrue(self._scale_height == hash[4]) def predict_visibilities(self, I, q, k=None, geometry=None): @@ -500,7 +500,7 @@ def size(self): def scale_height(self): "Vertial thickness of the disc, unit = arcsec" if self._scale_height is not None: - return self._scale_height(self.r) + return self._scale_height else: return None diff --git a/frank/tests.py b/frank/tests.py index 0a98a94d..a9091cc7 100644 --- a/frank/tests.py +++ b/frank/tests.py @@ -25,12 +25,12 @@ from frank.constants import rad_to_arcsec from frank.hankel import DiscreteHankelTransform from frank.radial_fitters import FourierBesselFitter, FrankFitter +from frank.debris_fitters import FrankDebrisFitter from frank.geometry import ( FixedGeometry, FitGeometryGaussian, FitGeometryFourierBessel ) -from frank.constants import deg_to_rad from frank.utilities import UVDataBinner, generic_dht -from frank.io import load_uvtable, save_uvtable +from frank.io import load_uvtable, save_uvtable, load_sol, save_fit from frank.statistical_models import VisibilityMapping from frank import fit @@ -157,7 +157,6 @@ def load_AS209(uv_cut=None): return uv_AS209_DSHARP, geometry - def test_fit_geometry(): """Check the geometry fit on a subset of the AS209 data""" AS209, _ = load_AS209() @@ -426,6 +425,35 @@ def test_uvbin(): np.testing.assert_allclose(w, uvbin.weights[i]) np.testing.assert_allclose(len(widx), uvbin.bin_counts[i]) +def test_save_load_sol(): + """Check saving/loading a frank 'sol' object""" + AS209, AS209_geometry = load_AS209(uv_cut=1e6) + u, v, vis, weights = [AS209[k][::100] for k in ['u', 'v', 'V', 'weights']] + Rmax, N = 1.6, 20 + + # generate a sol from a standard frank fit + FF = FrankFitter(Rmax, N, AS209_geometry, alpha=1.05, weights_smooth=1e-2) + sol = FF.fit(u, v, vis, weights) + + # and from a frank debris fit (has additional keys over a standard fit sol) + FF_deb = FrankDebrisFitter(Rmax, N, AS209_geometry, lambda x : 0.05 * x, + alpha=1.05, weights_smooth=1e-2) + sol_deb = FF_deb.fit(u, v, vis, weights) + + tmp_dir = '/tmp/frank/tests' + os.makedirs(tmp_dir, exist_ok=True) + + save_prefix = [os.path.join(tmp_dir, 'standard'), os.path.join(tmp_dir, 'debris')] + sols = [sol, sol_deb] + + for ii, jj in enumerate(save_prefix): + # save the 'sol' object + save_fit(u, v, vis, weights, sols[ii], prefix=jj, + save_profile_fit=False, save_vis_fit=False, save_uvtables=False + ) + # load it + load_sol(jj + '_frank_sol.obj') + def _run_pipeline(geometry='gaussian', fit_phase_offset=True, fit_inc_pa=True, make_figs=False,