Skip to content

Commit

Permalink
Merge pull request #38 from LSSTDESC/fitchromaticpsf1d
Browse files Browse the repository at this point in the history
Fitchromaticpsf1d
  • Loading branch information
jeremyneveu authored Sep 26, 2019
2 parents a9fa07e + 986548b commit a00d276
Show file tree
Hide file tree
Showing 13 changed files with 1,061 additions and 689 deletions.
4 changes: 2 additions & 2 deletions runFitter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from spectractor import parameters
from spectractor.fit.fitter import run_minimisation, SpectrogramFitWorkspace
from spectractor.fit.fit_spectrogram import SpectrogramFitWorkspace, run_spectrogram_minimisation
from spectractor.config import load_config


Expand Down Expand Up @@ -36,6 +36,6 @@

w = SpectrogramFitWorkspace(file_name, atmgrid_file_name=atmgrid_filename, nsteps=2000,
burnin=1000, nbins=10, verbose=0, plot=False, live_fit=False)
run_minimisation(w, method="newton")
run_spectrogram_minimisation(w, method="newton")
# run_emcee(fit_workspace)
# fit_workspace.analyze_chains()
2 changes: 1 addition & 1 deletion spectractor/extractor/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def extract_background_photutils(data, err, ws=(20, 30), mask_signal_region=True
ax[0].set_title(f'Data background: mean={np.nanmean(bgd_bands):.3f}, std={np.nanstd(bgd_bands):.3f}')
ax[0].set_xlabel('X [pixels]')
ax[0].set_ylabel('Y [pixels]')
bkg.plot_meshes(outlines=True, color='#1f77b4', ax=ax[0])
bkg.plot_meshes(outlines=True, color='#1f77b4', axes=ax[0])
b = bkg.background
im = ax[1].imshow(b, origin='lower', aspect="auto")
ax[1].set_xlabel('X [pixels]')
Expand Down
307 changes: 242 additions & 65 deletions spectractor/extractor/psf.py

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions spectractor/extractor/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,8 +910,9 @@ def shift_minimizer(params):
fwhm_func=fwhm_func, ax=None)
chisq += (shift * shift) / (parameters.PIXSHIFT_PRIOR / 2) ** 2
if parameters.DEBUG and parameters.DISPLAY:
spectrum.lambdas = lambdas_test
spectrum.plot_spectrum(live_fit=True, label=f'Order {spectrum.order:d} spectrum'
if parameters.LIVE_FIT:
spectrum.lambdas = lambdas_test
spectrum.plot_spectrum(live_fit=True, label=f'Order {spectrum.order:d} spectrum'
f'\nD={D:.2f}mm, shift={shift:.2f}pix')
return chisq

Expand Down Expand Up @@ -1051,7 +1052,7 @@ def extract_spectrum_from_image(image, spectrum, w=10, ws=(20, 30), right_edge=p
my_logger.info(f'\n\tStart PSF1D transverse fit...')
s = ChromaticPSF1D(Nx=Nx, Ny=Ny, deg=parameters.PSF_POLY_ORDER, saturation=image.saturation)
s.fit_transverse_PSF1D_profile(data, err, w, ws, pixel_step=1, sigma=5, bgd_model_func=bgd_model_func,
saturation=image.saturation, live_fit=parameters.DEBUG)
saturation=image.saturation, live_fit=False)

# Fill spectrum object
spectrum.pixels = np.arange(pixel_start, pixel_end, 1).astype(int)
Expand Down
404 changes: 404 additions & 0 deletions spectractor/fit/fit_spectrogram.py

Large diffs are not rendered by default.

138 changes: 138 additions & 0 deletions spectractor/fit/fit_spectrum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np

from spectractor import parameters
from spectractor.config import set_logger
from spectractor.simulation.simulator import SimulatorInit, SpectrumSimulation
from spectractor.simulation.atmosphere import Atmosphere, AtmosphereGrid
from spectractor.fit.fitter import FitWorkspace


class SpectrumFitWorkspace(FitWorkspace):

def __init__(self, file_name, atmgrid_file_name="", nwalkers=18, nsteps=1000, burnin=100, nbins=10,
verbose=0, plot=False, live_fit=False, truth=None):
FitWorkspace.__init__(self, file_name, nwalkers, nsteps, burnin, nbins, verbose, plot, live_fit, truth=truth)
self.my_logger = set_logger(self.__class__.__name__)
self.spectrum, self.telescope, self.disperser, self.target = SimulatorInit(file_name)
self.airmass = self.spectrum.header['AIRMASS']
self.pressure = self.spectrum.header['OUTPRESS']
self.temperature = self.spectrum.header['OUTTEMP']
if atmgrid_file_name == "":
self.atmosphere = Atmosphere(self.airmass, self.pressure, self.temperature)
else:
self.use_grid = True
self.atmosphere = AtmosphereGrid(file_name, atmgrid_file_name)
if parameters.VERBOSE:
self.my_logger.info(f'\n\tUse atmospheric grid models from file {atmgrid_file_name}. ')
self.lambdas = self.spectrum.lambdas
self.data = self.spectrum.data
self.err = self.spectrum.err
self.A1 = 1.0
self.A2 = 0.05
self.ozone = 300.
self.pwv = 3
self.aerosols = 0.03
self.reso = 1.5
self.D = self.spectrum.header['D2CCD']
self.shift = self.spectrum.header['PIXSHIFT']
self.p = np.array([self.A1, self.A2, self.ozone, self.pwv, self.aerosols, self.reso, self.D, self.shift])
self.ndim = len(self.p)
self.input_labels = ["A1", "A2", "ozone", "PWV", "VAOD", "reso [pix]", r"D_CCD [mm]",
r"alpha_pix [pix]"]
self.axis_names = ["$A_1$", "$A_2$", "ozone", "PWV", "VAOD", "reso [pix]", r"$D_{CCD}$ [mm]",
r"$\alpha_{\mathrm{pix}}$ [pix]"]
self.bounds = [(0, 2), (0, 0.5), (0, 800), (0, 10), (0, 1), (1, 10), (50, 60), (-20, 20)]
if atmgrid_filename != "":
self.bounds[2] = (min(self.atmosphere.OZ_Points), max(self.atmosphere.OZ_Points))
self.bounds[3] = (min(self.atmosphere.PWV_Points), max(self.atmosphere.PWV_Points))
self.bounds[4] = (min(self.atmosphere.AER_Points), max(self.atmosphere.AER_Points))
self.nwalkers = max(2 * self.ndim, nwalkers)
self.simulation = SpectrumSimulation(self.spectrum, self.atmosphere, self.telescope, self.disperser)
self.get_truth()

def get_truth(self):
if 'A1' in list(self.spectrum.header.keys()):
A1_truth = self.spectrum.header['A1']
A2_truth = self.spectrum.header['A2']
ozone_truth = self.spectrum.header['OZONE']
pwv_truth = self.spectrum.header['PWV']
aerosols_truth = self.spectrum.header['VAOD']
reso_truth = self.spectrum.header['RESO']
D_truth = self.spectrum.header['D2CCD']
shift_truth = self.spectrum.header['X0SHIFT']
self.truth = (A1_truth, A2_truth, ozone_truth, pwv_truth, aerosols_truth,
reso_truth, D_truth, shift_truth)
else:
self.truth = None

def plot_spectrum_comparison_simple(self, ax, title='', extent=None, size=0.4):
lambdas = self.spectrum.lambdas
sub = np.where((lambdas > parameters.LAMBDA_MIN) & (lambdas < parameters.LAMBDA_MAX))
if extent is not None:
sub = np.where((lambdas > extent[0]) & (lambdas < extent[1]))
self.spectrum.plot_spectrum_simple(ax, lambdas=lambdas)
p0 = ax.plot(lambdas, self.model(lambdas), label='model')
ax.fill_between(lambdas, self.model - self.model_err,
self.model(lambdas) + self.model_err, alpha=0.3, color=p0[0].get_color())
# ax.plot(self.lambdas, self.model_noconv, label='before conv')
if title != '':
ax.set_title(title, fontsize=10)
ax.legend()
divider = make_axes_locatable(ax)
ax2 = divider.append_axes("bottom", size=size, pad=0)
ax.figure.add_axes(ax2)
residuals = (self.spectrum.data - self.model(lambdas)) / self.model(lambdas)
residuals_err = self.spectrum.err / self.model(lambdas)
ax2.errorbar(lambdas, residuals, yerr=residuals_err, fmt='ro', markersize=2)
ax2.axhline(0, color=p0[0].get_color())
ax2.grid(True)
residuals_model = self.model_err / self.model(lambdas)
ax2.fill_between(lambdas, -residuals_model, residuals_model, alpha=0.3, color=p0[0].get_color())
std = np.std(residuals[sub])
ax2.set_ylim([-2. * std, 2. * std])
ax2.set_xlabel(ax.get_xlabel())
ax2.set_ylabel('(data-fit)/fit')
ax2.set_xlim((lambdas[sub][0], lambdas[sub][-1]))
ax.set_xlim((lambdas[sub][0], lambdas[sub][-1]))
ax.set_ylim((0.9 * np.min(self.spectrum.data[sub]), 1.1 * np.max(self.spectrum.data[sub])))
ax.set_xticks(ax2.get_xticks()[1:-1])
ax.get_yaxis().set_label_coords(-0.15, 0.6)
ax2.get_yaxis().set_label_coords(-0.15, 0.5)

def simulate(self, A1, A2, ozone, pwv, aerosols, reso, D, shift):
lambdas, model, model_err = self.simulation.simulate(A1, A2, ozone, pwv, aerosols, reso, D, shift)
# if self.live_fit:
# self.plot_fit()
self.model = model
self.model_err = model_err
return model, model_err

def plot_fit(self):
fig = plt.figure(figsize=(12, 6))
ax1 = plt.subplot(222)
ax2 = plt.subplot(224)
ax3 = plt.subplot(121)
A1, A2, ozone, pwv, aerosols, reso, D, shift = self.p
self.title = f'A1={A1:.3f}, A2={A2:.3f}, PWV={pwv:.3f}, OZ={ozone:.3g}, VAOD={aerosols:.3f},\n ' \
f'reso={reso:.2f}pix, D={D:.2f}mm, shift={shift:.2f}pix '
# main plot
self.plot_spectrum_comparison_simple(ax3, title=self.title, size=0.8)
# zoom O2
self.plot_spectrum_comparison_simple(ax2, extent=[730, 800], title='Zoom $O_2$', size=0.8)
# zoom H2O
self.plot_spectrum_comparison_simple(ax1, extent=[870, 1000], title='Zoom $H_2 O$', size=0.8)
fig.tight_layout()
if self.live_fit:
plt.draw()
plt.pause(1e-8)
plt.close()
else:
if parameters.DISPLAY and parameters.VERBOSE:
plt.show()
if parameters.SAVE:
figname = self.filename.replace('.fits', '_bestfit.pdf')
self.my_logger.info(f'Save figure {figname}.')
fig.savefig(figname, dpi=100, bbox_inches='tight')

Loading

0 comments on commit a00d276

Please sign in to comment.