# Licensed under a 3-clause BSD style license - see LICENSE.rst
import astropy.units as u
import numpy as np
from .core import _prefit, lnprobmodel
from .extern.validator import validate_array
from .plot import _plot_data_to_ax, color_cycle
from .utils import sed_conversion, validate_data_table
__all__ = ["InteractiveModelFitter"]
def _process_model(model):
if (
isinstance(model, tuple) or isinstance(model, list)
) and not isinstance(model, np.ndarray):
return model[0]
else:
return model
[docs]class InteractiveModelFitter:
"""
Interactive model fitter using matplotlib widgets
Parameters
----------
modelfn : function
A function that takes a vector in the parameter space and the data
table, and returns the expected fluxes at the energies in the
spectrum.
p0 : array
Initial position vector.
data : `~astropy.table.Table` or list of `~astropy.table.Table`
Table or tables with observed spectra. Must follow the format described
in `naima.run_sampler`.
e_range : list of `~astropy.units.Quantity`, length 2, optional
Limits in energy for the computation of the model.
Note that setting this parameter will mean that the model output is
computed twice when `data` is provided: once for display using
`e_range` and once for computation of the log-probability using the
energy values of the spectra.
e_npoints : int, optional
How many points to compute for the model if `e_range` is set. Default
is 100.
labels : iterable of strings, optional
Labels for the parameters included in the position vector ``p0``. If
not provided ``['par1','par2', ... ,'parN']`` will be used.
sed : bool, optional
Whether to plot SED or differential spectrum.
auto_update : bool, optional
Whether to update the model plot when parameter sliders are changed.
Default is True and can also be changed through the GUI.
"""
def __init__(
self,
modelfn,
p0,
data=None,
e_range=None,
e_npoints=100,
labels=None,
sed=True,
auto_update=True,
):
import matplotlib.pyplot as plt
from matplotlib.widgets import Button, CheckButtons, Slider
self.pars = p0
self.P0_IS_ML = False
npars = len(p0)
if labels is None:
labels = ["par{0}".format(i) for i in range(npars)]
elif len(labels) < npars:
labels += ["par{0}".format(i) for i in range(len(labels), npars)]
self.hasdata = data is not None
self.data = None
if self.hasdata:
self.data = validate_data_table(data, sed=sed)
self.modelfn = modelfn
self.fig = plt.figure()
modelax = plt.subplot2grid((10, 4), (0, 0), rowspan=4, colspan=4)
if e_range is not None:
e_range = validate_array(
"e_range", u.Quantity(e_range), physical_type="energy"
)
energy = (
np.logspace(
np.log10(e_range[0].value),
np.log10(e_range[1].value),
e_npoints,
)
* e_range.unit
)
if self.hasdata:
energy = energy.to(self.data["energy"].unit)
else:
e_unit = e_range.unit
else:
energy = np.logspace(-4, 2, e_npoints) * u.TeV
e_unit = u.TeV
# Bogus flux array to send to model if not using data
if sed:
flux = np.zeros(e_npoints) * u.Unit("erg/(cm2 s)")
else:
flux = np.zeros(e_npoints) * u.Unit("1/(TeV cm2 s)")
if self.hasdata:
e_unit = self.data["energy"].unit
_plot_data_to_ax(self.data, modelax, sed=sed, e_unit=e_unit)
if e_range is None:
# use data for model
energy = self.data["energy"]
flux = self.data["flux"]
self.data_for_model = {"energy": energy, "flux": flux}
model = _process_model(self.modelfn(p0, self.data_for_model))
if self.hasdata:
if not np.array_equal(
self.data_for_model["energy"].to(u.eV).value,
self.data["energy"].to(u.eV).value,
):
# this will be slow, maybe interpolate already computed model?
model_for_lnprob = _process_model(
self.modelfn(self.pars, self.data)
)
else:
model_for_lnprob = model
lnprob = lnprobmodel(model_for_lnprob, self.data)
if isinstance(lnprob, u.Quantity):
lnprob = lnprob.decompose().value
self.lnprobtxt = modelax.text(
0.05,
0.05,
r"",
ha="left",
va="bottom",
transform=modelax.transAxes,
size=20,
)
self.lnprobtxt.set_text(
r"$\ln\mathcal{{L}} = {0:.1f}$".format(lnprob)
)
self.f_unit, self.sedf = sed_conversion(energy, model.unit, sed)
if self.hasdata:
datamin = (
self.data["energy"][0] - self.data["energy_error_lo"][0]
).to(e_unit).value / 3
xmin = min(datamin, energy[0].to(e_unit).value)
datamax = (
self.data["energy"][-1] + self.data["energy_error_hi"][-1]
).to(e_unit).value * 3
xmax = max(datamax, energy[-1].to(e_unit).value)
modelax.set_xlim(xmin, xmax)
else:
# plot_data_to_ax has not set ylabel
unit = self.f_unit.to_string("latex_inline")
if sed:
modelax.set_ylabel(r"$E^2 dN/dE$ [{0}]".format(unit))
else:
modelax.set_ylabel(r"$dN/dE$ [{0}]".format(unit))
modelax.set_xlim(energy[0].value, energy[-1].value)
(self.line,) = modelax.loglog(
energy.to(e_unit),
(model * self.sedf).to(self.f_unit),
lw=2,
c="k",
zorder=10,
)
modelax.set_xlabel(
"Energy [{0}]".format(energy.unit.to_string("latex_inline"))
)
paraxes = []
for n in range(npars):
paraxes.append(
plt.subplot2grid((2 * npars, 10), (npars + n, 0), colspan=7)
)
self.parsliders = []
slider_props = {"facecolor": color_cycle[-1], "alpha": 0.5}
for label, parax, valinit in zip(labels, paraxes, p0):
# Attempt to estimate reasonable parameter ranges from label
pmin, pmax = valinit / 10, valinit * 3
if "log" in label:
span = 2
if "norm" in label or "amplitude" in label:
# give more range for normalization parameters
span *= 2
pmin, pmax = valinit - span, valinit + span
elif ("index" in label) or ("alpha" in label):
if valinit > 0.0:
pmin, pmax = 0, 5
else:
pmin, pmax = -5, 0
elif "norm" in label or "amplitude" in label:
# norm without log, it will not be pretty because sliders are
# only linear
pmin, pmax = valinit / 100, valinit * 100
slider = Slider(
parax,
label,
pmin,
pmax,
valinit=valinit,
valfmt="%.4g",
**slider_props
)
slider.on_changed(self.update_if_auto)
self.parsliders.append(slider)
autoupdateax = plt.subplot2grid((8, 4), (4, 3), colspan=1, rowspan=1)
auto_update_check = CheckButtons(
autoupdateax, ("Auto update",), (auto_update,)
)
auto_update_check.on_clicked(self.update_autoupdate)
self.autoupdate = auto_update
updateax = plt.subplot2grid((8, 4), (5, 3), colspan=1, rowspan=1)
update_button = Button(updateax, "Update model")
update_button.on_clicked(self.update)
if self.hasdata:
fitax = plt.subplot2grid((8, 4), (6, 3), colspan=1, rowspan=1)
fit_button = Button(fitax, "Do Nelder-Mead fit")
fit_button.on_clicked(self.do_fit)
closeax = plt.subplot2grid((8, 4), (7, 3), colspan=1, rowspan=1)
close_button = Button(closeax, "Close window")
close_button.on_clicked(self.close_fig)
self.fig.subplots_adjust(top=0.98, right=0.98, bottom=0.02, hspace=0.2)
plt.show()
def update_autoupdate(self, label):
self.autoupdate = not self.autoupdate
def update_if_auto(self, val):
if self.autoupdate:
self.update(val)
def update(self, event):
# If we update, values have changed and P0 is not ML anymore
self.P0_IS_ML = False
self.pars = [slider.val for slider in self.parsliders]
model = _process_model(self.modelfn(self.pars, self.data_for_model))
self.line.set_ydata((model * self.sedf).to(self.f_unit))
if self.hasdata:
if not np.array_equal(
self.data_for_model["energy"].to(u.eV).value,
self.data["energy"].to(u.eV).value,
):
# this will be slow, maybe interpolate already computed model?
model = _process_model(self.modelfn(self.pars, self.data))
lnprob = lnprobmodel(model, self.data)
if isinstance(lnprob, u.Quantity):
lnprob = lnprob.decompose().value
self.lnprobtxt.set_text(
r"$\ln\mathcal{{L}} = {0:.1f}$".format(lnprob)
)
self.fig.canvas.draw_idle()
def do_fit(self, event):
self.pars = [slider.val for slider in self.parsliders]
self.pars, P0_IS_ML = _prefit(self.pars, self.data, self.modelfn, None)
autoupdate = self.autoupdate
self.autoupdate = False
if P0_IS_ML:
for slider, val in zip(self.parsliders, self.pars):
slider.set_val(val)
self.update("after_fit")
self.autoupdate = autoupdate
self.P0_IS_ML = P0_IS_ML
def close_fig(self, event):
import matplotlib.pyplot as plt
plt.close(self.fig)