# Licensed under a 3-clause BSD style license - see LICENSE.rst
from multiprocessing import Pool

import astropy.units as u
import numpy as np
from astropy import log
from emcee import autocorr

from .extern.validator import validate_array
from .utils import sed_conversion, validate_data_table

# Public plotting API of this module.
__all__ = ["plot_chain", "plot_fit", "plot_data", "plot_blob", "plot_corner"]

# Marker styles cycled through (by group index) when plotting data groups.
marker_cycle = ["o", "s", "d", "p", "*"]
# RGB colors cycled through (by group index) when plotting data groups.
# From seaborn: sns.color_palette('deep', 6)
color_cycle = [
    (0.2980392156862745, 0.4470588235294118, 0.6901960784313725),
    (0.3333333333333333, 0.6588235294117647, 0.40784313725490196),
    (0.7686274509803922, 0.3058823529411765, 0.3215686274509804),
    (0.5058823529411764, 0.4470588235294118, 0.6980392156862745),
    (0.8, 0.7254901960784313, 0.4549019607843137),
    (0.39215686274509803, 0.7098039215686275, 0.803921568627451),
]

def plot_chain(sampler, p=None, **kwargs):
    """Generate a diagnostic plot of the sampler chains.

    Parameters
    ----------
    sampler : `emcee.EnsembleSampler`
        Sampler containing the chains to be plotted.
    p : int (optional)
        Index of the parameter to plot. If omitted, all chains are plotted.
    last_step : bool (optional)
        Whether to plot the last step of the chain or the complete chain
        (default).

    Returns
    -------
    figure : `matplotlib.figure.Figure` or None
        Figure for parameter ``p``. When ``p`` is omitted one figure is
        created per parameter and `None` is returned.
    """
    if p is not None:
        return _plot_chain_func(sampler, p, **kwargs)

    # No parameter selected: one diagnostic figure per parameter.
    npars = sampler.get_chain().shape[-1]
    for pp in range(npars):
        _plot_chain_func(sampler, pp, **kwargs)
    return None
def _latex_float(f, format=".3g"):
    """Format a number as a LaTeX string.

    Scientific notation such as ``1.2e+05`` is rendered as
    ``1.2\\times 10^{5}``; other values are returned as formatted.

    Parameters
    ----------
    f : float
        Number to format.
    format : str, optional
        Format specification passed to `str.format`. Default is ``".3g"``.

    Returns
    -------
    str
        LaTeX representation of ``f`` (without enclosing ``$``).
    """
    float_str = "{{0:{0}}}".format(format).format(f)
    if "e" in float_str:
        base, exponent = float_str.split("e")
        return r"{0}\times 10^{{{1}}}".format(base, int(exponent))
    return float_str


def round2(x, n):
    """Round ``x`` to ``n`` decimals and return it as a string.

    For ``n < 1`` the rounded integer is returned; otherwise trailing zeroes
    are preserved (e.g. ``round2(1.5, 2) == "1.50"``).
    """
    if n < 1:
        return str(int(round(x, n)))
    # preserve trailing zeroes
    return ("{{0:.{0}f}}".format(n)).format(x)


def _latex_value_error(val, elo, ehi=0, tol=0.25):
    """Format a value with (possibly asymmetric) uncertainties in LaTeX.

    Parameters
    ----------
    val : float
        Central value.
    elo : float
        Lower uncertainty (positive).
    ehi : float, optional
        Upper uncertainty. If 0 (default) a symmetric ``\\pm elo`` is shown.
    tol : float, optional
        Relative difference between ``elo`` and ``ehi`` above which the
        uncertainties are shown as asymmetric ``^{+ehi}_{-elo}``.

    Returns
    -------
    str
        LaTeX string enclosed in ``$``. Values whose decimal order is beyond
        +/-2 are factored out as ``(...)\\times10^{order}``.
    """
    # A zero central value has no decimal order (log10(0) is -inf).
    order = int(np.log10(abs(val))) if val != 0 else 0
    if order > 2 or order < -2:
        val /= 10**order
        elo /= 10**order
        ehi /= 10**order
    else:
        order = 0

    # Decimals showing the leading significant digit of elo, plus one more
    # when that digit is 1.
    nlo = -int(np.floor(np.log10(elo)))
    if elo * 10**nlo < 2:
        nlo += 1

    if ehi:
        nhi = -int(np.floor(np.log10(ehi)))
        if ehi * 10**nhi < 2:
            nhi += 1
        if np.abs(elo - ehi) / ((elo + ehi) / 2.0) > tol:
            # Errors differ noticeably: show them separately.
            n = max(nlo, nhi)
            string = "{0}^{{+{1}}}_{{-{2}}}".format(
                *[round2(x, nn) for x, nn in zip([val, ehi, elo], [n, nhi, nlo])]
            )
        else:
            # Errors are similar: show their mean as a symmetric error.
            e = (elo + ehi) / 2.0
            n = -int(np.floor(np.log10(e)))
            if e * 10**n < 2:
                n += 1
            string = "{0} \\pm {1}".format(*[round2(x, n) for x in [val, e]])
    else:
        string = "{0} \\pm {1}".format(*[round2(x, nlo) for x in [val, elo]])

    if order != 0:
        string = "(" + string + r")\times10^{{{0}}}".format(order)

    return "$" + string + "$"


def _plot_chain_func(sampler, p, last_step=False):
    """Plot walker traces and the posterior histogram of parameter ``p``.

    Returns the figure, or None (after logging a warning) when the sampler
    chain does not have more than two dimensions (i.e. it is a flat chain).
    """
    chain = sampler.get_chain()
    label = sampler.labels[p]

    import matplotlib.pyplot as plt
    from scipy import stats

    if len(chain.shape) <= 2:
        log.warning("we need the full chain to plot the traces, not a flatchain!")
        return None

    # transpose from (step, walker) to (walker, step)
    traces = chain[:, :, p].T
    if last_step:
        # keep only last step
        dist = traces[:, -1]
    else:
        # convert chain to flatchain
        dist = traces.flatten()

    nwalkers = traces.shape[0]
    nsteps = traces.shape[1]

    f = plt.figure()

    ax1 = f.add_subplot(221)
    ax2 = f.add_subplot(122)

    f.subplots_adjust(left=0.1, bottom=0.15, right=0.95, top=0.9)

    # plot five percent of the traces darker (at least three for small
    # ensembles)
    if nwalkers < 60:
        thresh = 1 - 3.0 / nwalkers
    else:
        thresh = 0.95
    red = np.arange(nwalkers) / float(nwalkers) >= thresh

    ax1.set_rasterization_zorder(1)
    for t in traces[~red]:
        ax1.plot(t, color=(0.1,) * 3, lw=1.0, alpha=0.25, zorder=0)
    for t in traces[red]:
        ax1.plot(t, color=color_cycle[0], lw=1.5, alpha=0.75, zorder=0)
    ax1.set_xlabel("step number")
    ax1.set_ylabel(label)
    ax1.yaxis.set_label_coords(-0.15, 0.5)
    ax1.set_title("Walker traces")

    nbins = min(max(25, int(len(dist) / 100.0)), 100)
    xlabel = label
    n, x, _ = ax2.hist(
        dist,
        nbins,
        histtype="stepfilled",
        color=color_cycle[0],
        lw=0,
        density=True,
    )

    kde = stats.gaussian_kde(dist)
    ax2.plot(x, kde(x), color="k", label="KDE")
    quant = [16, 50, 84]
    xquant = np.percentile(dist, quant)
    quantiles = dict(zip(quant, xquant))

    ax2.axvline(
        quantiles[50],
        ls="--",
        color="k",
        alpha=0.5,
        lw=2,
        label="50% quantile",
    )
    ax2.axvspan(
        quantiles[16],
        quantiles[84],
        color=(0.5,) * 3,
        alpha=0.25,
        label="68% CI",
        lw=0,
    )
    for xticklabel in ax2.get_xticklabels():
        xticklabel.set_rotation(45)
    ax2.set_xlabel(xlabel)
    ax2.xaxis.set_label_coords(0.5, -0.1)
    ax2.set_title("posterior distribution")
    ax2.set_ylim(top=n.max() * 1.05)

    # Print distribution parameters on lower-left

    try:
        autocorr_message = "{0:.1f}".format(autocorr.integrated_time(chain)[p])
    except autocorr.AutocorrError:
        # Raised when chain is too short for meaningful auto-correlation
        # estimation
        autocorr_message = None

    clen = "last ensemble" if last_step else "whole chain"

    chain_props = "Walkers: {0} \nSteps in chain: {1} \n".format(nwalkers, nsteps)
    if autocorr_message is not None:
        chain_props += "Autocorrelation time: {0}\n".format(autocorr_message)
    chain_props += "Mean acceptance fraction: {0:.3f}\n".format(
        np.mean(sampler.acceptance_fraction)
    ) + (
        "Distribution properties for the {clen}:\n"
        "    $-$ median: ${median}$, std: ${std}$ \n"
        "    $-$ median with uncertainties based on \n"
        "      the 16th and 84th percentiles ($\\sim$1$\\sigma$):\n"
    ).format(
        median=_latex_float(quantiles[50]),
        std=_latex_float(np.std(dist)),
        clen=clen,
    )

    info_line = (
        " " * 10
        + label
        + " = "
        + _latex_value_error(
            quantiles[50],
            quantiles[50] - quantiles[16],
            quantiles[84] - quantiles[50],
        )
    )

    chain_props += info_line

    # For parameters named "log10(x)" or "log(x)", also report x itself.
    ltype = label.split("(")[0]
    if "(" in label and ltype in ("log10", "log"):
        nlabel = label.split("(")[-1].split(")")[0]
        if ltype == "log10":
            new_dist = 10**dist
        else:
            new_dist = np.exp(dist)

        quant = [16, 50, 84]
        quantiles = dict(zip(quant, np.percentile(new_dist, quant)))

        label_template = "\n" + " " * 10 + "{{label:>{0}}}".format(len(label))

        new_line = label_template.format(label=nlabel)
        new_line += " = " + _latex_value_error(
            quantiles[50],
            quantiles[50] - quantiles[16],
            quantiles[84] - quantiles[50],
        )

        chain_props += new_line
        info_line += new_line

    log.info("{0:-^50}\n".format(label) + info_line)
    f.text(0.05, 0.45, chain_props, ha="left", va="top")

    return f


def _process_blob(sampler, modelidx, last_step=False, energy=None):
    """
    Process binary blob in sampler. If blob in position modelidx is:

    - a Quantity array of len(blob[i]) == len(data['energy']): use blob as
      model, data['energy'] as modelx
    - a tuple: use first item as modelx, second as model
    - a Quantity scalar (or plain scalar): return array of scalars,
      modelx is None

    ``sampler`` may also be a plain list of per-sample blobs (as built by
    `_read_or_calc_samples` and `find_ML`); in that case ``energy`` must be
    given and all the blobs are used.

    Raises
    ------
    TypeError
        If the blob does not have one of the formats above.
    """
    # Allow process blob to be used by _calc_samples and _calc_ML by sending
    # only blobs, not full sampler
    try:
        blobs = sampler.get_blobs()
        blob0 = blobs[-1][0][modelidx]
        energy = sampler.data["energy"]
    except AttributeError:
        blobs = [sampler]
        blob0 = sampler[0][modelidx]
        last_step = True

    def _collect(item=None):
        """Gather blob ``modelidx`` (or its element ``item``) over samples."""
        steps = [blobs[-1]] if last_step else blobs
        values = []
        for step in steps:
            for walkerblob in step:
                value = walkerblob[modelidx]
                values.append(value if item is None else value[item])
        return u.Quantity(values)

    wrong_format = TypeError("Model {0} has wrong blob format".format(modelidx))

    if isinstance(blob0, u.Quantity):
        if blob0.size == energy.size:
            # Energy array for blob is not provided, use data['energy']
            modelx = energy
        elif blob0.size == 1:
            modelx = None
        else:
            raise wrong_format
        model = _collect()
    elif np.isscalar(blob0):
        modelx = None
        model = _collect()
    elif isinstance(blob0, (list, tuple)):
        if not (
            len(blob0) == 2
            and isinstance(blob0[0], u.Quantity)
            and isinstance(blob0[1], u.Quantity)
        ):
            raise wrong_format
        # Energy array for model is item 0 in blob, model flux is item 1
        modelx = blob0[0]
        model = _collect(item=1)
    else:
        raise wrong_format

    return modelx, model


def _bogus_data(sampler, e_range, e_npoints):
    """Build a data dict with a log-spaced energy grid over ``e_range``.

    The flux column is all zeros in the units of ``sampler.data["flux"]``;
    it only exists so that ``sampler.modelfn`` can be evaluated on the grid.
    """
    e_range = validate_array("e_range", u.Quantity(e_range), physical_type="energy")
    e_unit = e_range.unit
    energy = (
        np.logspace(
            np.log10(e_range[0].value),
            np.log10(e_range[1].value),
            e_npoints,
        )
        * e_unit
    )
    return {
        "energy": energy,
        "flux": np.zeros(energy.shape) * sampler.data["flux"].unit,
    }


def _read_or_calc_samples(
    sampler,
    modelidx=0,
    n_samples=100,
    last_step=False,
    e_range=None,
    e_npoints=100,
    threads=None,
):
    """Get samples from blob or compute them from chain and sampler.modelfn.

    If ``e_range`` is None the model samples stored in the sampler blobs are
    returned. Otherwise ``n_samples`` parameter vectors are drawn at random
    from the chain (only the last step if ``last_step``) and the model is
    evaluated on ``e_npoints`` log-spaced energies using a process pool of
    ``threads`` workers.
    """
    if e_range is None:
        # return the results saved in blobs
        return _process_blob(sampler, modelidx, last_step=last_step)

    data = _bogus_data(sampler, e_range, e_npoints)

    # init pool and select parameters
    chain = sampler.get_chain()[-1] if last_step else sampler.get_chain(flat=True)
    pars = chain[np.random.randint(len(chain), size=n_samples)]
    args = ((p, data) for p in pars)
    with Pool(threads) as pool:
        modelouts = pool.starmap(sampler.modelfn, args)

    # A bare array output is a single model; wrap it so it is blob index 0.
    blobs = [
        [modelout] if isinstance(modelout, np.ndarray) else modelout
        for modelout in modelouts
    ]

    return _process_blob(blobs, modelidx=modelidx, energy=data["energy"])


def _calc_ML(sampler, modelidx=0, e_range=None, e_npoints=100):
    """Get ML model from blob or compute them from chain and sampler.modelfn.

    Returns the same tuple as `find_ML`; when ``e_range`` is given the ML
    model is recomputed on ``e_npoints`` log-spaced energies.
    """
    ML, MLp, MLerr, ML_model = find_ML(sampler, modelidx)

    if e_range is not None:
        data = _bogus_data(sampler, e_range, e_npoints)
        modelout = sampler.modelfn(MLp, data)

        if isinstance(modelout, np.ndarray):
            blob = modelout
        else:
            blob = modelout[modelidx]

        if isinstance(blob, u.Quantity):
            modelx = data["energy"].copy()
            model_ML = blob.copy()
        elif len(blob) == 2:
            modelx = blob[0].copy()
            model_ML = blob[1].copy()
        else:
            raise TypeError("Model {0} has wrong blob format".format(modelidx))

        ML_model = (modelx, model_ML)

    return ML, MLp, MLerr, ML_model


def _calc_CI(
    sampler,
    modelidx=0,
    confs=[3, 1],
    last_step=False,
    e_range=None,
    e_npoints=100,
    threads=None,
):
    """Calculate confidence intervals.

    Returns ``modelx`` and a list with one ``(ymin, ymax)`` Quantity pair per
    entry of ``confs`` (in sigma), taken from the sorted model samples.
    """
    from scipy import stats

    # If we are computing the samples for the confidence intervals, we need at
    # least one sample to constrain the highest confidence band
    # 1 sigma -> 6 samples
    # 2 sigma -> 43 samples
    # 3 sigma -> 740 samples
    # 4 sigma -> 31574 samples
    # 5 sigma -> 3488555 samples
    # We limit it to 1000 samples and warn that it might not be enough
    if e_range is not None:
        maxconf = np.max(confs)
        # At least 100 samples, more if the widest band needs them.
        minsamples = max(100, int(1 / stats.norm.cdf(-maxconf) + 1))
        if minsamples > 1000:
            log.warning(
                "In order to sample the confidence band for {0} sigma,"
                " {1} new samples need to be computed, but we are limiting"
                " it to 1000 samples, so the confidence band might not be"
                " well constrained."
                " Consider reducing the maximum"
                " confidence significance or using the samples stored in"
                " the sampler by setting e_range"
                " to None".format(maxconf, minsamples)
            )
            minsamples = 1000
    else:
        minsamples = None

    modelx, model = _read_or_calc_samples(
        sampler,
        modelidx,
        last_step=last_step,
        e_range=e_range,
        e_npoints=e_npoints,
        n_samples=minsamples,
        threads=threads,
    )

    # highest valid index into the samples of each energy bin
    nwalkers = len(model) - 1
    # sort the samples of every energy bin once, reused for all levels
    sorted_model = np.sort(model, axis=0)
    CI = []
    for conf in confs:
        nmin = int(stats.norm.cdf(-conf) * nwalkers)
        nmax = int(stats.norm.cdf(conf) * nwalkers)
        CI.append((sorted_model[nmin], sorted_model[nmax]))

    return modelx, CI


def _plot_MLmodel(ax, sampler, modelidx, e_range, e_npoints, e_unit, sed):
    """compute and plot ML model"""
    _, _, _, ML_model = _calc_ML(
        sampler, modelidx, e_range=e_range, e_npoints=e_npoints
    )
    f_unit, sedf = sed_conversion(ML_model[0], ML_model[1].unit, sed)

    ax.loglog(
        ML_model[0].to(e_unit).value,
        (ML_model[1] * sedf).to(f_unit).value,
        color="k",
        lw=2,
        alpha=0.8,
    )


def plot_CI(
    ax,
    sampler,
    modelidx=0,
    sed=True,
    confs=[3, 1, 0.5],
    e_unit=u.eV,
    label=None,
    e_range=None,
    e_npoints=100,
    threads=None,
    last_step=False,
):
    """Plot confidence interval.

    Parameters
    ----------
    ax : `matplotlib.Axes`
        Axes to plot on.
    sampler : `emcee.EnsembleSampler`
        Sampler
    modelidx : int, optional
        Model index. Default is 0
    sed : bool, optional
        Whether to plot SED or differential spectrum. If `None`, the units of
        the observed spectrum will be used.
    confs : list, optional
        List of confidence levels (in sigma) to use for generating the
        confidence intervals. Default is `[3,1,0.5]`. The list passed in is
        not modified.
    e_unit : :class:`~astropy.units.Unit` or str parseable to unit
        Unit in which to plot energy axis.
    label : str, optional
        Label for the y axis; the flux unit is appended in brackets.
    e_range : list of `~astropy.units.Quantity`, length 2, optional
        Limits in energy for the computation of the model samples and ML
        model. Setting it means the samples are recomputed, which may be slow.
    e_npoints : int, optional
        How many points to compute for the model samples and ML model if
        `e_range` is set.
    threads : int, optional
        How many parallel processing threads to use when computing the
        samples. Defaults to the number of available cores.
    last_step : bool, optional
        Whether to only use the positions in the final step of the run (True)
        or the whole chain (False, default).
    """
    # Sort a copy: sorting in place would modify the caller's list.
    confs = sorted(confs, reverse=True)

    modelx, CI = _calc_CI(
        sampler,
        modelidx=modelidx,
        confs=confs,
        e_range=e_range,
        e_npoints=e_npoints,
        last_step=last_step,
        threads=threads,
    )
    # pick first confidence interval curve for units
    f_unit, sedf = sed_conversion(modelx, CI[0][0].unit, sed)

    for (ymin, ymax), conf in zip(CI, confs):
        # wider bands are drawn in lighter grey
        color = np.log(conf) / np.log(20) + 0.4
        ax.fill_between(
            modelx.to(e_unit).value,
            (ymax * sedf).to(f_unit).value,
            (ymin * sedf).to(f_unit).value,
            lw=0.001,
            color=(color,) * 3,
            alpha=0.6,
            zorder=-10,
        )

    _plot_MLmodel(ax, sampler, modelidx, e_range, e_npoints, e_unit, sed)

    if label is not None:
        ax.set_ylabel("{0} [{1}]".format(label, f_unit.to_string("latex_inline")))


def plot_samples(
    ax,
    sampler,
    modelidx=0,
    sed=True,
    n_samples=100,
    e_unit=u.eV,
    e_range=None,
    e_npoints=100,
    threads=None,
    label=None,
    last_step=False,
):
    """Plot a number of samples from the sampler chain.

    Parameters
    ----------
    ax : `matplotlib.Axes`
        Axes to plot on.
    sampler : `emcee.EnsembleSampler`
        Sampler
    modelidx : int, optional
        Model index. Default is 0
    sed : bool, optional
        Whether to plot SED or differential spectrum. If `None`, the units of
        the observed spectrum will be used.
    n_samples : int, optional
        Number of samples to plot. Default is 100.
    e_unit : :class:`~astropy.units.Unit` or str parseable to unit
        Unit in which to plot energy axis.
    e_range : list of `~astropy.units.Quantity`, length 2, optional
        Limits in energy for the computation of the model samples and ML model.
        Note that setting this parameter will mean that the samples for the
        model are recomputed and depending on the model speed might be quite
        slow.
    e_npoints : int, optional
        How many points to compute for the model samples and ML model if
        `e_range` is set.
    threads : int, optional
        How many parallel processing threads to use when computing the
        samples. Defaults to the number of available cores.
    label : str, optional
        Label for the y axis; the flux unit is appended in brackets.
    last_step : bool, optional
        Whether to only use the positions in the final step of the run (True)
        or the whole chain (False, default).
    """
    modelx, model = _read_or_calc_samples(
        sampler,
        modelidx,
        last_step=last_step,
        e_range=e_range,
        e_npoints=e_npoints,
        threads=threads,
    )
    # pick first model sample for units
    f_unit, sedf = sed_conversion(modelx, model[0].unit, sed)

    sample_alpha = min(5.0 / n_samples, 0.5)
    for my in model[np.random.randint(len(model), size=n_samples)]:
        ax.loglog(
            modelx.to(e_unit).value,
            (my * sedf).to(f_unit).value,
            color=(0.1,) * 3,
            alpha=sample_alpha,
            lw=1.0,
        )

    _plot_MLmodel(ax, sampler, modelidx, e_range, e_npoints, e_unit, sed)

    if label is not None:
        ax.set_ylabel("{0} [{1}]".format(label, f_unit.to_string("latex_inline")))


def find_ML(sampler, modelidx):
    """
    Find Maximum Likelihood parameters as those in the chain with a highest
    log probability.

    Parameters
    ----------
    sampler : `emcee.EnsembleSampler`
        Sampler with a stored chain.
    modelidx : int or None
        Index of the model blob to return for the ML parameters. If None, no
        model is returned.

    Returns
    -------
    ML : float
        Highest log probability in the chain.
    MLp : array
        Parameter vector with the highest log probability.
    MLerr : list of float
        For each parameter, half the distance between the 16th and 84th
        percentiles of the flattened chain.
    (modelx, model_ML) : tuple
        Model x values and model for ``MLp``, or ``(None, None)`` when no
        model can be obtained.
    """
    lnprobability = sampler.get_log_prob()
    index = np.unravel_index(np.argmax(lnprobability), lnprobability.shape)
    MLp = sampler.get_chain()[index]

    blobs = sampler.get_blobs()
    if modelidx is not None and blobs is not None:
        blob = blobs[index][modelidx]
        if isinstance(blob, u.Quantity):
            modelx = sampler.data["energy"].copy()
            model_ML = blob.copy()
        elif len(blob) == 2:
            modelx = blob[0].copy()
            model_ML = blob[1].copy()
        else:
            raise TypeError("Model {0} has wrong blob format".format(modelidx))
    elif modelidx is not None and hasattr(sampler, "modelfn"):
        # No stored blobs: evaluate the model at the ML parameters.
        blob = _process_blob(
            [sampler.modelfn(MLp, sampler.data)],
            modelidx,
            energy=sampler.data["energy"],
        )
        modelx, model_ML = blob[0], blob[1][0]
    else:
        modelx, model_ML = None, None

    MLerr = []
    for dist in sampler.get_chain(flat=True).T:
        hilo = np.percentile(dist, [16.0, 84.0])
        MLerr.append((hilo[1] - hilo[0]) / 2.0)
    ML = lnprobability[index]

    return ML, MLp, MLerr, (modelx, model_ML)
def plot_blob(sampler, blobidx=0, label=None, last_step=False, figure=None, **kwargs):
    """
    Plot a metadata blob as a fit to spectral data or value distribution

    Additional ``kwargs`` are passed to `plot_fit`.

    Parameters
    ----------
    sampler : `emcee.EnsembleSampler`
        Sampler with a stored chain.
    blobidx : int, optional
        Metadata blob index to plot.
    label : str, optional
        Label for the value distribution. Labels for the fit plot can be passed
        as ``xlabel`` and ``ylabel`` and will be passed to `plot_fit`.
        Defaults to ``"Model output <blobidx>"``.
    last_step : bool, optional
        Whether to use only the samples of the last step in the run.
    figure : `matplotlib.figure.Figure`, optional
        `matplotlib` figure to plot on. If omitted a new one will be generated.

    Returns
    -------
    figure : `matplotlib.pyplot.Figure`
        `matplotlib` figure instance containing the plot.
    """
    modelx, model = _process_blob(sampler, blobidx, last_step)
    if label is None:
        label = "Model output {0}".format(blobidx)

    if modelx is not None:
        # Blob is a spectrum: plot it as a fit.
        return plot_fit(
            sampler,
            modelidx=blobidx,
            last_step=last_step,
            label=label,
            figure=figure,
            **kwargs,
        )

    # Blob is scalar, plot distribution
    return plot_distribution(model, label, figure=figure)
def plot_fit(
    sampler,
    modelidx=0,
    label=None,
    sed=True,
    last_step=False,
    n_samples=100,
    confs=None,
    ML_info=False,
    figure=None,
    plotdata=None,
    plotresiduals=None,
    e_unit=None,
    e_range=None,
    e_npoints=100,
    threads=None,
    xlabel=None,
    ylabel=None,
    ulim_opts={},
    errorbar_opts={},
):
    """
    Plot data with fit confidence regions.

    Parameters
    ----------
    sampler : `emcee.EnsembleSampler`
        Sampler with a stored chain.
    modelidx : int, optional
        Model index to plot.
    label : str, optional
        Label for the title of the plot.
    sed : bool, optional
        Whether to plot SED or differential spectrum.
    last_step : bool, optional
        Whether to use only the samples of the last step in the run when
        showing either the model samples or the confidence intervals.
    n_samples : int, optional
        Number of sample models to plot when ``confs`` is not given. If
        ``confs`` is None and ``n_samples`` is 0 or None, only the maximum
        likelihood model is plotted. Default is 100.
    confs : list, optional
        List of confidence levels (in sigma) to use for generating the
        confidence intervals. Takes precedence over ``n_samples``. Default is
        to plot sample models instead of confidence bands.
    ML_info : bool, optional
        Whether to plot information about the maximum likelihood parameters
        and the standard deviation of their distributions. Default is False.
    figure : `matplotlib.figure.Figure`, optional
        `matplotlib` figure to plot on. If omitted a new one will be generated.
    plotdata : bool, optional
        Whether to plot data on top of model confidence intervals. Default is
        True if the physical types of the data and the model match.
    plotresiduals : bool, optional
        Whether to plot the residuals with respect to the maximum likelihood
        model. Default is True if ``plotdata`` is True and either ``confs`` or
        ``n_samples`` are set.
    e_unit : `~astropy.units.Unit`, optional
        Units for the energy axis of the plot. The default is to use the units
        of the energy array of the observed data.
    e_range : list of `~astropy.units.Quantity`, length 2, optional
        Limits in energy for the computation of the model samples and ML model.
        Note that setting this parameter will mean that the samples for the
        model are recomputed and depending on the model speed might be quite
        slow.
    e_npoints : int, optional
        How many points to compute for the model samples and ML model if
        `e_range` is set.
    threads : int, optional
        How many parallel processing threads to use when computing the
        samples. Defaults to the number of available cores.
    xlabel : str, optional
        Label for the ``x`` axis of the plot.
    ylabel : str, optional
        Label for the ``y`` axis of the plot.
    ulim_opts : dict
        Option for upper-limit plotting. Available options are capsize (arrow
        width) and height_fraction (arrow length in fraction of flux value).
    errorbar_opts : dict
        Additional options to pass to `matplotlib.plt.errorbar` for plotting
        the spectral flux points.

    Returns
    -------
    figure : `matplotlib.figure.Figure`
        Figure containing the plot.
    """
    import matplotlib.pyplot as plt

    ML, MLp, MLerr, model_ML = find_ML(sampler, modelidx)
    infostr = "Maximum log probability: {0:.3g}\n".format(ML)
    infostr += "Maximum Likelihood values:\n"
    maxlen = np.max([len(ilabel) for ilabel in sampler.labels])
    vartemplate = "{{2:>{0}}}: {{0:>8.3g}} +/- {{1:<8.3g}}\n".format(maxlen)
    for p, v, ilabel in zip(MLp, MLerr, sampler.labels):
        infostr += vartemplate.format(p, v, ilabel)

    data = sampler.data

    if e_range is None and not hasattr(sampler, "blobs"):
        # No stored model samples: compute them over a range extending a
        # factor 3 beyond the data.
        e_range = data["energy"][[0, -1]] * np.array((1.0 / 3.0, 3.0))

    if plotdata is None and len(model_ML[0]) == len(data["energy"]):
        model_unit, _ = sed_conversion(model_ML[0], model_ML[1].unit, sed)
        data_unit, _ = sed_conversion(data["energy"], data["flux"].unit, sed)
        plotdata = model_unit.is_equivalent(data_unit)
    elif plotdata is None:
        plotdata = False

    if plotresiduals is None and plotdata and (confs is not None or n_samples):
        plotresiduals = True

    if confs is None and not n_samples and plotdata and not plotresiduals:
        # We actually only want to plot the data, so let's go there
        return plot_data(
            sampler.data,
            xlabel=xlabel,
            ylabel=ylabel,
            sed=sed,
            figure=figure,
            e_unit=e_unit,
            ulim_opts=ulim_opts,
            errorbar_opts=errorbar_opts,
        )

    f = plt.figure() if figure is None else figure

    if plotdata and plotresiduals:
        # Create both axes on ``f`` explicitly; without ``fig`` they would be
        # placed on the current pyplot figure, which may not be ``f``.
        ax1 = plt.subplot2grid((4, 1), (0, 0), rowspan=3, fig=f)
        ax2 = plt.subplot2grid((4, 1), (3, 0), sharex=ax1, fig=f)
    else:
        ax1 = f.add_subplot(111)

    if e_unit is None:
        e_unit = data["energy"].unit

    if confs is not None:
        plot_CI(
            ax1,
            sampler,
            modelidx,
            sed=sed,
            confs=confs,
            e_unit=e_unit,
            label=label,
            e_range=e_range,
            e_npoints=e_npoints,
            last_step=last_step,
            threads=threads,
        )
    elif n_samples:
        plot_samples(
            ax1,
            sampler,
            modelidx,
            sed=sed,
            n_samples=n_samples,
            e_unit=e_unit,
            label=label,
            e_range=e_range,
            e_npoints=e_npoints,
            last_step=last_step,
            threads=threads,
        )
    else:
        # plot only ML model
        _plot_MLmodel(ax1, sampler, modelidx, e_range, e_npoints, e_unit, sed)

    xlaxis = ax1
    if plotdata:
        _plot_data_to_ax(
            data,
            ax1,
            e_unit=e_unit,
            sed=sed,
            ylabel=ylabel,
            ulim_opts=ulim_opts,
            errorbar_opts=errorbar_opts,
        )
        if plotresiduals:
            _plot_residuals_to_ax(
                data,
                model_ML,
                ax2,
                e_unit=e_unit,
                sed=sed,
                errorbar_opts=errorbar_opts,
            )
            xlaxis = ax2
            for tl in ax1.get_xticklabels():
                tl.set_visible(False)
        # x limits: whole decades enclosing all data points and their errors
        xmin = 10 ** np.floor(
            np.log10(np.min(data["energy"] - data["energy_error_lo"]).to(e_unit).value)
        )
        xmax = 10 ** np.ceil(
            np.log10(np.max(data["energy"] + data["energy_error_hi"]).to(e_unit).value)
        )
        ax1.set_xlim(xmin, xmax)
    else:
        ax1.set_xscale("log")
        ax1.set_yscale("log")
        ndecades = 10 if sed else 20
        # restrict y axis to ndecades to avoid autoscaling deep exponentials
        xmin, xmax, ymin, ymax = ax1.axis()
        ymin = max(ymin, ymax / 10**ndecades)
        ax1.set_ylim(bottom=ymin)
        # scale x axis to largest model_ML x point within ndecades decades of
        # maximum
        f_unit, sedf = sed_conversion(model_ML[0], model_ML[1].unit, sed)
        hi = np.where((model_ML[1] * sedf).to(f_unit).value > ymin)
        xmax = np.max(model_ML[0][hi])
        ax1.set_xlim(right=10 ** np.ceil(np.log10(xmax.to(e_unit).value)))

    if e_range is not None:
        # ensure that xmin/xmax contains e_range
        xmin, xmax, ymin, ymax = ax1.axis()
        xmin = min(xmin, e_range[0].to(e_unit).value)
        xmax = max(xmax, e_range[1].to(e_unit).value)
        ax1.set_xlim(xmin, xmax)

    if ML_info and (confs is not None or n_samples):
        ax1.text(
            0.05,
            0.05,
            infostr,
            ha="left",
            va="bottom",
            transform=ax1.transAxes,
            family="monospace",
        )

    if label is not None:
        ax1.set_title(label)

    if xlabel is None:
        xlaxis.set_xlabel("Energy [{0}]".format(e_unit.to_string("latex_inline")))
    else:
        xlaxis.set_xlabel(xlabel)

    f.subplots_adjust(hspace=0)

    return f
def _plot_ulims(ax, x, y, xerr, color, capsize=5, height_fraction=0.25, elinewidth=2):
    """
    Plot upper limits as arrows with cap at value of upper limit.

    The first call draws the horizontal (energy) error bars, the second the
    downward arrows of length ``height_fraction * y``.

    uplim behaviour has been fixed in matplotlib 1.4
    """
    ax.errorbar(x, y, xerr=xerr, ls="", color=color, elinewidth=elinewidth, capsize=0)
    ax.errorbar(
        x,
        y,
        yerr=height_fraction * y,
        ls="",
        uplims=True,
        color=color,
        elinewidth=elinewidth,
        capsize=capsize,
        zorder=10,
    )


def _plot_data_to_ax(
    data_all,
    ax1,
    e_unit=None,
    sed=True,
    ylabel=None,
    ulim_opts={},
    errorbar_opts={},
):
    """Plots data errorbars and upper limits onto ax.

    X label is left to plot_data and plot_fit because they depend on whether
    residuals are plotted.

    Note that a ``group`` column of zeros is added to ``data_all`` if missing.
    """
    if e_unit is None:
        e_unit = data_all["energy"].unit

    f_unit, sedf = sed_conversion(data_all["energy"], data_all["flux"].unit, sed)

    # Work on a copy: the default dict is shared between calls and the
    # caller's dict must not be modified.
    ulim_opts = dict(ulim_opts)
    if "elinewidth" in errorbar_opts:
        ulim_opts["elinewidth"] = errorbar_opts["elinewidth"]

    if "group" not in data_all.keys():
        data_all["group"] = np.zeros(len(data_all))

    groups = np.unique(data_all["group"])

    for g in groups:
        data = data_all[np.where(data_all["group"] == g)]
        _, sedfg = sed_conversion(data["energy"], data["flux"].unit, sed)
        # wrap around color and marker cycles
        color = color_cycle[int(g) % len(color_cycle)]
        marker = marker_cycle[int(g) % len(marker_cycle)]

        ul = data["ul"]
        notul = ~ul

        # Hack to show y errors compatible with 0 in loglog plot
        yerr_lo = data["flux_error_lo"][notul]
        y = data["flux"][notul].to(yerr_lo.unit)
        bad_err = np.where((y - yerr_lo) <= 0.0)
        yerr_lo[bad_err] = y[bad_err] * (1.0 - 1e-7)
        yerr = u.Quantity((yerr_lo, data["flux_error_hi"][notul]))
        xerr = u.Quantity((data["energy_error_lo"], data["energy_error_hi"]))

        opts = dict(
            zorder=100,
            marker=marker,
            ls="",
            elinewidth=2,
            capsize=0,
            mec=color,
            mew=0.1,
            ms=5,
            color=color,
        )
        opts.update(errorbar_opts)

        ax1.errorbar(
            data["energy"][notul].to(e_unit).value,
            (data["flux"][notul] * sedfg[notul]).to(f_unit).value,
            yerr=(yerr * sedfg[notul]).to(f_unit).value,
            xerr=xerr[:, notul].to(e_unit).value,
            **opts,
        )

        if np.any(ul):
            _plot_ulims(
                ax1,
                data["energy"][ul].to(e_unit).value,
                (data["flux"][ul] * sedfg[ul]).to(f_unit).value,
                (xerr[:, ul]).to(e_unit).value,
                color,
                **ulim_opts,
            )

    ax1.set_xscale("log")
    ax1.set_yscale("log")
    # x limits: whole decades enclosing the points of *all* groups
    xmin = 10 ** np.floor(
        np.log10(
            np.min(data_all["energy"] - data_all["energy_error_lo"]).to(e_unit).value
        )
    )
    xmax = 10 ** np.ceil(
        np.log10(
            np.max(data_all["energy"] + data_all["energy_error_hi"]).to(e_unit).value
        )
    )
    ax1.set_xlim(xmin, xmax)
    # avoid autoscaling to errorbars to 0
    notul = ~data_all["ul"]
    if np.any(data_all["flux_error_lo"][notul] >= data_all["flux"][notul]):
        elo = (data_all["flux"][notul] * sedf[notul]).to(f_unit).value - (
            data_all["flux_error_lo"][notul] * sedf[notul]
        ).to(f_unit).value
        gooderr = np.where(data_all["flux_error_lo"][notul] < data_all["flux"][notul])
        ymin = 10 ** np.floor(np.log10(np.min(elo[gooderr])))
        ax1.set_ylim(bottom=ymin)

    if ylabel is not None:
        ax1.set_ylabel(ylabel)
    elif sed:
        ax1.set_ylabel(
            r"$E^2\mathrm{{d}}N/\mathrm{{d}}E$"
            " [{0}]".format(u.Unit(f_unit).to_string("latex_inline"))
        )
    else:
        ax1.set_ylabel(
            r"$\mathrm{{d}}N/\mathrm{{d}}E$"
            " [{0}]".format(u.Unit(f_unit).to_string("latex_inline"))
        )


def _plot_residuals_to_ax(
    data_all, model_ML, ax, e_unit=u.eV, sed=True, errorbar_opts={}
):
    """Function to compute and plot residuals in units of the uncertainty.

    Residuals are ``(data - model) / sigma`` with ``sigma`` the mean of the
    lower and upper flux errors; upper limits are skipped. When the model and
    data energies differ, the model is linearly interpolated to the data
    energies.
    """
    if "group" not in data_all.keys():
        data_all["group"] = np.zeros(len(data_all))

    groups = np.unique(data_all["group"])

    MLf_unit, MLsedf = sed_conversion(model_ML[0], model_ML[1].unit, sed)
    MLene = model_ML[0].to(e_unit)
    MLflux = (model_ML[1] * MLsedf).to(MLf_unit)

    ax.axhline(0, color="k", lw=1, ls="--")

    interp = False
    if data_all["energy"].size != MLene.size or not np.allclose(
        data_all["energy"].value, MLene.value
    ):
        interp = True
        from scipy.interpolate import interp1d

        modelfunc = interp1d(MLene.value, MLflux.value, bounds_error=False)

    for g in groups:
        groupidx = np.where(data_all["group"] == g)

        data = data_all[groupidx]

        notul = ~data["ul"]
        df_unit, dsedf = sed_conversion(data["energy"], data["flux"].unit, sed)
        ene = data["energy"].to(e_unit)
        xerr = u.Quantity((data["energy_error_lo"], data["energy_error_hi"]))
        flux = (data["flux"] * dsedf).to(df_unit)
        dflux = (data["flux_error_lo"] + data["flux_error_hi"]) / 2.0
        dflux = (dflux * dsedf).to(df_unit)[notul]

        if interp:
            difference = flux[notul] - modelfunc(ene[notul].value) * flux.unit
        else:
            difference = flux[notul] - MLflux[groupidx][notul]

        # wrap around color and marker cycles
        color = color_cycle[int(g) % len(color_cycle)]
        marker = marker_cycle[int(g) % len(marker_cycle)]

        opts = dict(
            zorder=100,
            marker=marker,
            ls="",
            elinewidth=2,
            capsize=0,
            mec=color,
            mew=0.1,
            ms=6,
            color=color,
        )
        opts.update(errorbar_opts)

        ax.errorbar(
            ene[notul].value,
            (difference / dflux).decompose().value,
            yerr=(dflux / dflux).decompose().value,
            xerr=xerr[:, notul].to(e_unit).value,
            **opts,
        )

    from matplotlib.ticker import MaxNLocator

    ax.yaxis.set_major_locator(
        MaxNLocator(5, integer=True, prune="upper", symmetric=True)
    )

    ax.set_ylabel(r"$\Delta\sigma$")
    ax.set_xscale("log")
def plot_data(
    input_data,
    xlabel=None,
    ylabel=None,
    sed=True,
    figure=None,
    e_unit=None,
    ulim_opts={},
    errorbar_opts={},
):
    """
    Plot spectral data.

    Parameters
    ----------
    input_data : `emcee.EnsembleSampler`, `astropy.table.Table`, or `dict`
        Spectral data to plot. Can be given as a data table, a dict generated
        with `validate_data_table` or a `emcee.EnsembleSampler` with a data
        property.
    xlabel : str, optional
        Label for the ``x`` axis of the plot.
    ylabel : str, optional
        Label for the ``y`` axis of the plot.
    sed : bool, optional
        Whether to plot SED or differential spectrum.
    figure : `matplotlib.figure.Figure`, optional
        `matplotlib` figure to plot on. If omitted a new one will be generated.
        If the figure already has axes, the first one is used.
    e_unit : `astropy.unit.Unit`, optional
        Units for energy axis. Defaults to the energy unit found in the x
        label of an existing plot in ``figure``, else those of the data.
    ulim_opts : dict
        Options for upper-limit plotting. Available options are capsize (arrow
        width) and height_fraction (arrow length in fraction of flux value).
    errorbar_opts : dict
        Additional options to pass to `matplotlib.plt.errorbar` for plotting
        the spectral flux points.

    Returns
    -------
    figure : `matplotlib.figure.Figure`
        Figure containing the plot.

    Raises
    ------
    TypeError
        Re-raised from `validate_data_table` when ``input_data`` is not a
        table, has no ``data`` attribute and is not a dict with an ``energy``
        key.
    """
    import matplotlib.pyplot as plt

    try:
        data = validate_data_table(input_data)
    except TypeError as exc:
        if hasattr(input_data, "data"):
            # e.g. a sampler carrying its data table
            data = input_data.data
        elif isinstance(input_data, dict) and "energy" in input_data.keys():
            data = input_data
        else:
            log.warning(
                "input_data format unknown, no plotting data! "
                "Data loading exception: {}".format(exc)
            )
            raise

    f = plt.figure() if figure is None else figure

    if len(f.axes) > 0:
        ax1 = f.axes[0]
    else:
        ax1 = f.add_subplot(111)

    # try to get units from previous plot in figure
    try:
        old_e_unit = u.Unit(ax1.get_xlabel().split("[")[-1].split("]")[0])
    except ValueError:
        old_e_unit = u.Unit("")

    if e_unit is None:
        if old_e_unit.physical_type == "energy":
            e_unit = old_e_unit
        else:
            e_unit = data["energy"].unit

    _plot_data_to_ax(
        data,
        ax1,
        e_unit=e_unit,
        sed=sed,
        ylabel=ylabel,
        ulim_opts=ulim_opts,
        errorbar_opts=errorbar_opts,
    )

    if xlabel is not None:
        ax1.set_xlabel(xlabel)
    elif ax1.get_xlabel() == "":
        ax1.set_xlabel(
            r"$\mathrm{Energy}$"
            + " [{0}]".format(e_unit.to_string("latex_inline"))
        )

    ax1.autoscale()

    return f
def plot_distribution(samples, label, figure=None):
    """Plot a distribution and print statistics about it.

    Parameters
    ----------
    samples : array or `~astropy.units.Quantity`
        Samples of the quantity to plot.
    label : str
        Name of the quantity, used in the title, text and x label.
    figure : `matplotlib.figure.Figure`, optional
        Figure to plot on. If omitted a new one will be generated.

    Returns
    -------
    figure : `matplotlib.figure.Figure`
        Figure with the histogram, a KDE, the median and the 16-84
        percentile band, plus a text summary below the axes.
    """
    import matplotlib.pyplot as plt
    from scipy import stats

    quant = [16, 50, 84]
    quantiles = dict(zip(quant, np.percentile(samples, quant)))
    std = np.std(samples)

    if isinstance(samples[0], u.Quantity):
        unit = samples[0].unit
        std = std.value
        quantiles = {k: v.value for k, v in quantiles.items()}
    else:
        unit = ""

    dist_props = (
        "{label} distribution properties:\n"
        "    $-$ median: ${median}$ {unit}, std: ${std}$ {unit}\n"
        "    $-$ Median with uncertainties based on \n"
        "      the 16th and 84th percentiles ($\\sim$1$\\sigma$):\n"
        "          {label} = {value_error} {unit}"
    ).format(
        label=label,
        median=_latex_float(quantiles[50]),
        std=_latex_float(std),
        value_error=_latex_value_error(
            quantiles[50],
            quantiles[50] - quantiles[16],
            quantiles[84] - quantiles[50],
        ),
        unit=unit,
    )

    f = plt.figure() if figure is None else figure

    ax = f.add_subplot(111)
    f.subplots_adjust(bottom=0.40, top=0.93, left=0.06, right=0.95)

    f.text(0.2, 0.27, dist_props, ha="left", va="top")

    histnbins = min(max(25, int(len(samples) / 100.0)), 100)
    xlabel = "" if label is None else label

    if isinstance(samples, u.Quantity):
        samples_nounit = samples.value
    else:
        samples_nounit = samples

    n, x, _ = ax.hist(
        samples_nounit,
        histnbins,
        histtype="stepfilled",
        color=color_cycle[0],
        lw=0,
        density=True,
    )

    kde = stats.gaussian_kde(samples_nounit)
    ax.plot(x, kde(x), color="k", label="KDE")

    ax.axvline(
        quantiles[50],
        ls="--",
        color="k",
        alpha=0.5,
        lw=2,
        label="50% quantile",
    )
    ax.axvspan(
        quantiles[16],
        quantiles[84],
        color=(0.5,) * 3,
        alpha=0.25,
        label="68% CI",
        lw=0,
    )
    for xticklabel in ax.get_xticklabels():
        xticklabel.set_rotation(45)
    if unit != "":
        xlabel += " [{0}]".format(unit)
    ax.set_xlabel(xlabel)
    ax.set_title("Posterior distribution of {0}".format(label))
    ax.set_ylim(top=n.max() * 1.05)

    return f
def plot_corner(sampler, show_ML=True, **kwargs):
    """
    A plot that summarizes the parameter samples by showing them as individual
    histograms and 2D histograms against each other. The maximum likelihood
    parameter vector is indicated by a cross.

    This function is a thin wrapper around `corner.corner`, found at
    https://github.com/dfm/corner.py. Additional ``kwargs`` override the
    options passed to it.

    Parameters
    ----------
    sampler : `emcee.EnsembleSampler`
        Sampler with a stored chain.
    show_ML : bool, optional
        Whether to show the maximum likelihood parameter vector as a cross on
        the 2D histograms.

    Returns
    -------
    figure : `matplotlib.figure.Figure` or None
        The corner plot, or None (with a warning) if the ``corner`` package
        is not installed.
    """
    import matplotlib.pyplot as plt

    try:
        from corner import corner
    except ImportError:
        log.warning("The corner package is not installed; corner plot not available")
        return None

    if show_ML:
        _, MLp, _, _ = find_ML(sampler, 0)
    else:
        MLp = None

    corner_opts = {
        "labels": sampler.labels,
        "truths": MLp,
        "quantiles": [0.16, 0.5, 0.84],
        "verbose": False,
        "truth_color": color_cycle[0],
    }

    corner_opts.update(kwargs)

    # Thinner lines for the corner plot only; always restore the user's
    # setting, even if corner() raises.
    oldlw = plt.rcParams["lines.linewidth"]
    plt.rcParams["lines.linewidth"] = 0.7
    try:
        f = corner(sampler.get_chain(flat=True), **corner_opts)
    finally:
        plt.rcParams["lines.linewidth"] = oldlw

    return f