Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update sol save/load; add test #198

Merged
merged 9 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion frank/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def load_sol(sol_file):
(see frank.radial_fitters.FrankFitter)
"""

sol = np.load(sol_file, allow_pickle=True)
with open(sol_file, 'rb') as f:
sol = pickle.load(f)

return sol

Expand Down
8 changes: 4 additions & 4 deletions frank/statistical_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def __init__(self, DHT, geometry,
if scale_height is None:
raise ValueError('You requested a model with a non-zero scale height'
' but did not specify H(R) (scale_height=None)')
self._scale_height = scale_height
self._H2 = 0.5*(2*np.pi*scale_height(self.r) / rad_to_arcsec)**2
self._scale_height = scale_height(self.r)
self._H2 = 0.5*(2*np.pi*self._scale_height / rad_to_arcsec)**2

if self._verbose:
logging.info(' Assuming an optically thin model but geometrically: '
Expand Down Expand Up @@ -278,7 +278,7 @@ def check_hash(self, hash, multi_freq=False, geometry=None):
if hash[4] is None:
return False
else:
return np.alltrue(self._scale_height(self.r) == hash[4](self.r))
return np.alltrue(self._scale_height == hash[4])


def predict_visibilities(self, I, q, k=None, geometry=None):
Expand Down Expand Up @@ -500,7 +500,7 @@ def size(self):
def scale_height(self):
"Vertial thickness of the disc, unit = arcsec"
if self._scale_height is not None:
return self._scale_height(self.r)
return self._scale_height
else:
return None

Expand Down
34 changes: 31 additions & 3 deletions frank/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
from frank.constants import rad_to_arcsec
from frank.hankel import DiscreteHankelTransform
from frank.radial_fitters import FourierBesselFitter, FrankFitter
from frank.debris_fitters import FrankDebrisFitter
from frank.geometry import (
FixedGeometry, FitGeometryGaussian, FitGeometryFourierBessel
)
from frank.constants import deg_to_rad
from frank.utilities import UVDataBinner, generic_dht
from frank.io import load_uvtable, save_uvtable
from frank.io import load_uvtable, save_uvtable, load_sol, save_fit
from frank.statistical_models import VisibilityMapping
from frank import fit

Expand Down Expand Up @@ -157,7 +157,6 @@ def load_AS209(uv_cut=None):

return uv_AS209_DSHARP, geometry


def test_fit_geometry():
"""Check the geometry fit on a subset of the AS209 data"""
AS209, _ = load_AS209()
Expand Down Expand Up @@ -426,6 +425,35 @@ def test_uvbin():
np.testing.assert_allclose(w, uvbin.weights[i])
np.testing.assert_allclose(len(widx), uvbin.bin_counts[i])

def test_save_load_sol():
"""Check saving/loading a frank 'sol' object"""
AS209, AS209_geometry = load_AS209(uv_cut=1e6)
u, v, vis, weights = [AS209[k][::100] for k in ['u', 'v', 'V', 'weights']]
Rmax, N = 1.6, 20

# generate a sol from a standard frank fit
FF = FrankFitter(Rmax, N, AS209_geometry, alpha=1.05, weights_smooth=1e-2)
sol = FF.fit(u, v, vis, weights)

# and from a frank debris fit (has additional keys over a standard fit sol)
FF_deb = FrankDebrisFitter(Rmax, N, AS209_geometry, lambda x : 0.05 * x,
alpha=1.05, weights_smooth=1e-2)
sol_deb = FF_deb.fit(u, v, vis, weights)

tmp_dir = '/tmp/frank/tests'
os.makedirs(tmp_dir, exist_ok=True)

save_prefix = [os.path.join(tmp_dir, 'standard'), os.path.join(tmp_dir, 'debris')]
sols = [sol, sol_deb]

for ii, jj in enumerate(save_prefix):
# save the 'sol' object
save_fit(u, v, vis, weights, sols[ii], prefix=jj,
save_profile_fit=False, save_vis_fit=False, save_uvtables=False
)
# load it
load_sol(jj + '_frank_sol.obj')


def _run_pipeline(geometry='gaussian', fit_phase_offset=True,
fit_inc_pa=True, make_figs=False,
Expand Down