# Copyright (c) 2021-2026 Cubillos & Blecic
# Pyrat Bay is open-source software under the GPL-2.0 license (see LICENSE)
# Public names exported by this module (e.g., via a star-import):
__all__ = [
    'Loglike',
    'weighted_to_equal',
    'posterior_snapshot',
    'get_multinest_map',
    'multinest_run',
    'posterior_post_processing',
]
import datetime
import os
import sys
import time
import pickle
import mc3
import numpy as np
from ..pyrat import Pyrat
from .. import constants as pc
from .. import plots as pp
from .mpi_tools import get_mpi_rank
from .tools import (
eta,
isfile,
)
[docs]
class Loglike():
    """
    Wrapper to compute the log(likelihood) for a Pyrat object.

    Heavily based on mc3.stats.Loglike class, but this one
    allows to dynamically modify the data and uncertainty values.
    """
    def __init__(self, pyrat):
        """
        Parameters
        ----------
        pyrat: Pyrat() object
            Object providing the observations (pyrat.obs), the model
            evaluation function (pyrat.eval), the retrieval parameters
            (pyrat.ret), and the snapshot interval in hours
            (pyrat.inputs.dt_retrieval_snapshot).

        Raises
        ------
        ValueError
            If there is no data (neither low- nor high-resolution),
            or if there are no free parameters (pstep > 0).
        """
        self.obs = pyrat.obs
        self.func = pyrat.eval
        # Note: this is a reference to (not a copy of) pyrat.ret.params,
        # so evaluations update the values in place.
        self.params = pyrat.ret.params
        self.pstep = pyrat.ret.pstep
        # Free parameters have pstep > 0, shared parameters have pstep < 0:
        self.ifree = self.pstep > 0
        self.ishare = np.where(self.pstep < 0)[0]

        if pyrat.obs.data is None and pyrat.obs.data_hires is None:
            raise ValueError(
                'Attempting to compute a log-likelihood for a model '
                'with no data'
            )
        if not np.any(self.ifree):
            raise ValueError(
                'Attempting to compute a log-likelihood for a model '
                'with no free parameters'
            )

        self.pnames = np.array(pyrat.ret.texnames)
        self.retrieval_file = pyrat.ret.retrieval_file
        # dt_snapshot hours to seconds
        self._dt_snapshot = pyrat.inputs.dt_retrieval_snapshot * 3600.0
        # The timer is only needed (and defined) when snapshots are enabled
        if self._dt_snapshot > 0:
            self.timer = time.time()

    def get_data(self):
        """
        Concatenate low- and high-resolution data into a single array

        Returns
        -------
        data: 1D float array
            The low-resolution data followed by the high-resolution data
            (or only the one that is defined).
        uncert: 1D float array
            The uncertainties matching the data array.
        """
        obs = self.obs
        if obs.data is not None and obs.data_hires is not None:
            data = np.concatenate((obs.data, obs.data_hires))
            uncert = np.concatenate((obs.uncert, obs.uncert_hires))
        elif obs.data is not None:
            data = obs.data
            uncert = obs.uncert
        else:
            data = obs.data_hires
            uncert = obs.uncert_hires
        return data, uncert

    def __call__(self, params):
        """
        Evaluate the log(likelihood) for the input set of parameters.

        Parameters
        ----------
        params: 1D float array
            Array of free parameters.

        Returns
        -------
        log_like: Float
            The Gaussian log(likelihood) of the model given the data.
            Non-finite values are replaced by -1.0e98.
        """
        # TMP plots
        # Check the (cheap) snapshot flag first, so that the MPI rank is
        # not queried on every evaluation when snapshots are disabled.
        if self._dt_snapshot > 0 and get_mpi_rank() == 0:
            now = time.time()
            if now - self.timer >= self._dt_snapshot:
                posterior_snapshot(
                    self.retrieval_file, self.pnames[self.ifree],
                )
                self.timer = now

        # Update the free and shared parameters:
        self.params[self.ifree] = params
        for s in self.ishare:
            # A shared parameter with pstep = -k copies the value of
            # the parameter at index k-1
            self.params[s] = self.params[-int(self.pstep[s])-1]

        # Evaluate model (and update data if necessary)
        model = self.func(self.params, retmodel=False)
        # Concatenate (low-res) data and high-res data arrays
        # (fetched after the evaluation, since it may modify the data)
        data, uncert = self.get_data()

        log_like = (
            -0.5*np.sum(((data - model) / uncert)**2.0)
            -0.5*np.sum(np.log(2.0*np.pi*uncert**2.0))
        )
        # Replace NaN/inf with a large negative (but finite) value
        if not np.isfinite(log_like):
            log_like = -1.0e98
        return log_like
[docs]
def weighted_to_equal(
        posterior_file, get_weighted=False, min_size=15000, seed=None,
    ):
    """
    Compute an equally-weighted sample from a weighted-probability sample
    read from a Multinest output.

    Parameters
    ----------
    posterior_file: String
        A MultiNest probability-weighted sample output.
    get_weighted: Bool
        If True, also return the weighted sample.
    min_size: Integer
        Set the minimum sample size for the equally weighted posterior.
    seed: Integer
        Seed for the random number generator used to draw the sample.
        If None (default), the draws are not reproducible.

    Returns
    -------
    equal_posterior: 2D float array
        An equally-weighted posterior sample with dimensions (nsamples, npars).
    weighted_posterior: 2D float array
        The Multinest probability-weighted sample with dimensions
        (nsamples, npars).  This is only returned if get_weighted is True.

    Raises
    ------
    ValueError
        If the file contains no samples or the sample probabilities
        do not add up to a positive finite value.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> posterior = pt.weighted_to_equal('multinest_output.txt')

    >>> # Get both equal and weighted samples:
    >>> posterior, weighted = pt.weighted_to_equal(
    >>>     'multinest_output.txt',
    >>>     get_weighted=True,
    >>> )
    """
    # MN columns have: sample probability, -2*loglikehood, parameter values
    # (ndmin=2 keeps a 2D shape even when the file has a single sample)
    data = np.loadtxt(posterior_file, ndmin=2)
    probability = data[:,0]
    weighted_posterior = data[:,2:]

    if probability.size == 0:
        raise ValueError(
            f"There are no samples in posterior file: '{posterior_file}'"
        )

    # Generate PDF from weights CDF (see Numerical Recipes Sec. 7.3.2)
    # This accomplishes the same as, e.g., dynesty.utils.resample_equal()
    cdf = np.cumsum(probability)
    total = cdf[-1]
    if not np.isfinite(total) or total <= 0.0:
        raise ValueError(
            'The sample probabilities do not add up to a positive value '
            f"in posterior file: '{posterior_file}'"
        )
    cdf /= total

    nsample = max(len(probability), min_size)
    rng = np.random.default_rng(seed=seed)
    # For each uniform draw u in [0,1), pick the first sample whose
    # CDF value is strictly greater than u.  Since cdf[-1] == 1.0,
    # the indices are always within bounds.
    indices = np.searchsorted(cdf, rng.random(nsample), side='right')
    equal_posterior = weighted_posterior[indices]

    if get_weighted:
        return equal_posterior, weighted_posterior
    return equal_posterior
[docs]
def posterior_snapshot(retrieval_file, pnames):
    """
    Take a snapshot of a retrieval run, plot the histogram and traces
    of the parameters.

    Parameters
    ----------
    retrieval_file: String
        Basename (may contain a path, no extension) of the MultiNest
        output files.  The '{retrieval_file}.txt' and
        '{retrieval_file}resume.dat' files are read.
    pnames: 1D iterable of strings
        Names of the parameters in the posterior (used as plot labels).

    Notes
    -----
    This is a best-effort diagnostic: the function returns without
    plotting if the posterior or resume files do not exist yet, or
    if all samples have the same value for the first parameter.
    The plots are saved into the same folder as the MultiNest outputs,
    as 'tmp_trace_*.png' and 'tmp_histograms_*.png' files.
    """
    root, pfile = os.path.split(retrieval_file)
    # os.path.join() handles an empty root (file in the current folder),
    # whereas a hard-coded '/' would point to the file-system root.
    posterior_file = os.path.join(root, f'{pfile}.txt')
    resume_file = os.path.join(root, f'{pfile}resume.dat')
    if not os.path.exists(posterior_file) or not os.path.exists(resume_file):
        #print('No posterior file yet')
        return

    equal_posterior, weighted_posterior = weighted_to_equal(
        posterior_file, get_weighted=True,
    )

    # Second value in the second line of the resume file, in millions
    # (labeled as the number of samples so far)
    with open(resume_file, 'r') as f:
        lines = f.readlines()
    nsamples = int(lines[1].split()[1]) / 1e6
    today = str(datetime.date.today()).replace('-', '_')
    label = f'{nsamples:06.2f}M__{today}'

    if len(np.unique(equal_posterior[:,0])) == 1:
        #print(f'Not enough samples to generate plots ({label}).')
        return

    post = mc3.plots.Posterior(equal_posterior, pnames)
    mc3.plots.trace(
        weighted_posterior,
        pnames=pnames,
        savefile=os.path.join(root, f'tmp_trace_{label}.png'),
    )
    post.plot_histogram(
        savefile=os.path.join(root, f'tmp_histograms_{label}.png'),
    )
[docs]
def get_multinest_map(stats_file):
    """
    Get maximum-a-posteriori (MAP) parameters from a MultiNest output file.

    Parameters
    ----------
    stats_file: String
        Path to a Multinest *stats.dat output file.

    Returns
    -------
    params: 1D float array
        The MAP parameter values.

    Raises
    ------
    ValueError
        If the file does not contain a 'MAP Parameters' section.
    """
    with open(stats_file, 'r') as f:
        lines = f.readlines()

    # Locate the section header (tolerating surrounding whitespace)
    header = 'MAP Parameters'
    for i, line in enumerate(lines):
        if line.strip() == header:
            break
    else:
        raise ValueError(
            f"Could not find the '{header}' section in file: '{stats_file}'"
        )

    # The values start two lines below (after the column-titles line),
    # and run until a blank line or the end of the file
    params = []
    for line in lines[i+2:]:
        if line.strip() == '':
            break
        # Each line contains the parameter index and its value
        _, value = line.split()
        params.append(value)
    return np.array(params, np.double)
[docs]
def multinest_run(pyrat, basename):
    """
    A Wrapper of a MultiNest posterior sampling.

    Parameters
    ----------
    pyrat: Pyrat() object
    basename: String
        Basename for output files.  May contain path.
        Should not contain a file extension.

    Returns
    -------
    output: Dict
        Posterior sample and summary statistics of the run (best-fit
        values, chi-square statistics, credible intervals, etc.).
        Only the process with MPI rank zero returns this dictionary,
        all other processes return None.

    Note
    ----
    For OS X users, it is recommended to set the TMPDIR environment
    variable to "/tmp", e.g., from the command line:
        export TMPDIR=/tmp
    to avoid an MPI error when terminating the execution
    (the call will run to completion in any case)
    https://github.com/open-mpi/ompi/issues/7393#issuecomment-882018321
    """
    from pymultinest.run import run
    os.environ["OMP_NUM_THREADS"] = "1"

    log = pyrat.log
    is_root = get_mpi_rank() == 0
    if is_root:
        log.msg('Starting Multinest atmospheric retrieval')
        # Shut up for a moment:
        tmp_verb = log.verb
        log.verb = -1

    try:
        n_free = np.sum(pyrat.ret.pstep>0)
        prior_transform = mc3.stats.Prior_transform(
            pyrat.ret.prior,
            pyrat.ret.priorlow,
            pyrat.ret.priorup,
            pyrat.ret.pmin,
            pyrat.ret.pmax,
            pyrat.ret.pstep,
        )

        # The callbacks below terminate the process on any error
        # instead of letting the exception propagate to the sampler.
        def safe_prior(cube, ndim, nparams):
            # Transform (in place) the unit-cube sample into parameters
            try:
                a = np.array([cube[i] for i in range(n_free)])
                b = prior_transform(a)
                for i in range(n_free):
                    cube[i] = b[i]
            except Exception as e:
                sys.stderr.write(f'ERROR in prior: {e}\n')
                sys.exit(1)

        loglike = Loglike(pyrat)
        def safe_loglikelihood(cube, ndim, nparams, lnew):
            try:
                a = np.array([cube[i] for i in range(n_free)])
                return float(loglike(a))
            except Exception as e:
                sys.stderr.write(f'ERROR in loglikelihood: {e}\n')
                sys.exit(1)

        # The pymultinest call:
        run(
            LogLikelihood=safe_loglikelihood,
            Prior=safe_prior,
            n_dims=n_free,
            importance_nested_sampling=False,
            outputfiles_basename=basename,
            n_live_points=pyrat.ret.nlive,
            resume=pyrat.ret.resume,
            verbose=True,
        )

        # Only the root process does the post-processing
        if not is_root:
            return

        return _multinest_summary(pyrat, basename, loglike, n_free)
    finally:
        # Restore verbosity (also if the run or post-processing failed)
        if is_root:
            log.verb = tmp_verb


def _multinest_summary(pyrat, basename, loglike, n_free):
    """
    Post-process a finished MultiNest run (some plots and stats).

    Parameters
    ----------
    pyrat: Pyrat() object
    basename: String
        Basename of the MultiNest output files.
    loglike: Loglike() object
        The log-likelihood wrapper used during the sampling.
    n_free: Integer
        Number of free parameters.

    Returns
    -------
    output: Dict
        Posterior sample and summary statistics of the run.
    """
    log = pyrat.log
    output = {}
    output['pstep'] = pstep = pyrat.ret.pstep
    # Note: bestp references pyrat.ret.params, which is updated in place
    output['bestp'] = bestp = pyrat.ret.params
    output['texnames'] = texnames = np.array(pyrat.ret.texnames)
    output['pnames'] = pyrat.ret.pnames

    ifree = np.where(pstep>0)[0]
    ishare = np.where(pstep<0)[0]
    bestp[ifree] = get_multinest_map(f'{basename}stats.dat')
    for s in ishare:
        bestp[s] = bestp[-int(pstep[s])-1]

    posterior, weighted_posterior = weighted_to_equal(
        f'{basename}.txt',
        get_weighted=True,
    )
    output['posterior'] = posterior

    theme = pyrat.ret.theme
    post = mc3.plots.Posterior(
        posterior, pnames=texnames[ifree], theme=theme,
        bestp=bestp[ifree], statistics=pyrat.ret.statistics,
        show_estimates=True,  # TBD: get from cfg?
    )

    # Trace plot:
    savefile = f'{basename}_posterior_trace.png'
    mc3.plots.trace(
        weighted_posterior,
        pnames=texnames[ifree],
        color=theme.color,
        savefile=savefile,
    )
    log.msg(savefile, indent=2)

    # Statistics:
    best_model = pyrat.eval(bestp, retmodel=False)
    data, uncert = loglike.get_data()
    ndata = len(data)
    best_chisq = np.sum((best_model-data)**2 / uncert**2)
    # The reduced chi-square is undefined without degrees of freedom
    # (check before dividing to avoid a division by zero)
    if ndata <= n_free:
        red_chisq = np.nan
    else:
        red_chisq = best_chisq / (ndata-n_free)

    # TBD: need to add log(prior)
    output['best_log_post'] = loglike(bestp[ifree])
    output['best_chisq'] = best_chisq
    output['red_chisq'] = red_chisq
    output['BIC'] = best_chisq + n_free*np.log(ndata)
    output['stddev_residuals'] = np.std(best_model-data)

    sample_stats = mc3.stats.calc_sample_statistics(
        post.posterior, bestp, pstep, calc_hpd=True,
    )
    output['medianp'] = sample_stats[0]
    output['meanp'] = sample_stats[1]
    output['stdp'] = sample_stats[2]
    output['median_low_bounds'] = sample_stats[3]
    output['median_high_bounds'] = sample_stats[4]
    output['mode'] = sample_stats[5]
    output['hpd_low_bounds'] = sample_stats[6]
    output['hpd_high_bounds'] = sample_stats[7]

    stats_file = f'{basename}_statistics.txt'
    mc3.stats.summary_stats(post, output, filename=stats_file)
    return output
[docs]
def posterior_post_processing(cfg_file=None, pyrat=None, suffix=''):
    """
    Compute quantities of interest from a retrieval posterior distribution.
    The produced data is stored into a pickle file with root name based
    on the retrieval output file name (pyrat.ret.retrieval_file).

    Parameters
    ----------
    cfg_file: String
        A pyratbay config file of a retrieval run (already executed,
        so the parameter posterior files must already exist).
    pyrat: a Pyrat instance
        A pyrat object of an already executed retrieval.
        Used if cfg_file is None.
    suffix: String
        Text inserted into the output file name:
        '{retrieval_file}{suffix}_posteriors_info.pickle'

    Returns
    -------
    outputs: Dict
        The posterior quantiles (median, -1sigma, +1sigma, -2sigma,
        +2sigma) of the spectra, band-integrated models, temperature,
        VMRs, and free parameters; along with the supporting data
        (wavelengths, pressure, species, observations, units, etc.).
        This is the same content stored in the pickle file.

    Raises
    ------
    ValueError
        If neither cfg_file nor pyrat were provided, if the MultiNest
        posterior file does not exist, or if the retrieval sampler is
        not one of 'multinest' or 'snooker'.
    """
    if pyrat is None and cfg_file is None:
        raise ValueError(
            "At least one of the input arguments ('cfg_file' or 'pyrat') "
            "must be provided"
        )
    if cfg_file is not None:
        pyrat = Pyrat(cfg_file, log=False, mute=True)

    # Basename of the output files (no extension):
    basename = pyrat.ret.retrieval_file
    sampler = pyrat.ret.sampler
    if sampler == 'multinest':
        if isfile(basename + '.txt') == 0:
            raise ValueError('MultiNest posterior outputs do not exist')
        posterior = weighted_to_equal(basename + '.txt')
    elif sampler == 'snooker':
        mcmc = np.load(basename + '.npz')
        posterior = mc3.utils.burn(mcmc)[0]
    else:
        raise ValueError(
            f"Unsupported retrieval sampler '{sampler}', "
            "select from: 'multinest' or 'snooker'"
        )

    texnames = np.array(pyrat.ret.texnames)
    theme = pyrat.ret.theme
    post = mc3.plots.Posterior(
        posterior, texnames, theme=theme, statistics=pyrat.ret.statistics,
    )

    # Quantiles for all posterior stats: median -1sigma +1sigma -2sigma +2sigma
    quantiles = np.array([0.5, 0.15865, 0.84135, 0.02275, 0.97725])
    nquantiles = len(quantiles)

    # Parameter statistics
    stats_1sigma = mc3.stats.calc_sample_statistics(
        post.posterior, pyrat.ret.params, pyrat.ret.pstep, quantile=0.683,
    )
    stats_2sigma = mc3.stats.calc_sample_statistics(
        post.posterior, pyrat.ret.params, pyrat.ret.pstep, quantile=0.9545,
    )
    ifree = pyrat.ret.pstep > 0
    nfree = np.sum(ifree)
    # Same ordering as the quantiles array:
    params_posterior = np.zeros((nquantiles,nfree))
    params_posterior[0] = stats_1sigma[0][ifree]
    params_posterior[1] = stats_1sigma[3][ifree]
    params_posterior[2] = stats_1sigma[4][ifree]
    params_posterior[3] = stats_2sigma[3][ifree]
    params_posterior[4] = stats_2sigma[4][ifree]

    # Unique posterior samples:
    # (identified by the values of the first parameter; uinv maps each
    # posterior sample to its unique-sample index)
    u, uind, uinv = np.unique(
        post.posterior[:,0], return_index=True, return_inverse=True,
    )
    n_unique = len(u)
    print(f'Computing {n_unique:d} models for posteriors post-processing')

    # Array of all model parameters (with unique samples)
    u_posterior = np.repeat([pyrat.ret.params], n_unique, axis=0)
    u_posterior[:,ifree] = post.posterior[uind]

    is_eclipse = pyrat.od.rt_path in pc.eclipse_rt
    is_emission = pyrat.od.rt_path in pc.emission_rt
    is_transmission = pyrat.od.rt_path in pc.transmission_rt

    nbands = pyrat.obs.ndata
    band_wl = 1.0 / pyrat.obs.bandwn / pc.um
    # The high-resolution filters (if any) take precedence
    ndata_hires = pyrat.obs.nfilters_hires
    if ndata_hires != 0:
        nbands = ndata_hires
        band_wl = np.array([band.wl0 for band in pyrat.obs.filters_hires])

    # Evaluate models / spectra:
    pyrat.spec.specfile = None
    nwave = pyrat.spec.nwave
    models = np.zeros((n_unique, nwave))
    band_models = np.zeros((n_unique, nbands))
    temp = np.zeros((n_unique, pyrat.atm.nlayers))
    vmr = np.zeros((n_unique, pyrat.atm.nlayers, pyrat.atm.nmol))
    cf = np.zeros((n_unique, pyrat.atm.nlayers, nbands))

    t0 = time.time()
    for i in range(n_unique):
        models[i], band_models[i] = pyrat.eval(u_posterior[i])
        temp[i] = pyrat.atm.temp
        vmr[i] = pyrat.atm.vmr
        cf[i] = pyrat.band_contribution()
        # Report the progress every third model:
        if i%3 == 0:
            timeleft = eta(time.time()-t0, i+1, n_unique, fmt='.2f')
            eta_text = (
                f'{i+1}/{n_unique} samples, '
                f'{100*(i+1)/n_unique:.2f} % done, '
                f'ETA: {timeleft}'
            )
            print(f'{eta_text:80s}', end='\r', flush=True)
    endline = f'{100.0:6.2f} % done'
    print(f'{endline:80s}', flush=True)

    # Posterior quantiles, expanding the unique models back into the
    # full posterior sample (via uinv).
    # The spectra are processed one wavelength at a time.
    spectrum_posterior = np.zeros((nquantiles,nwave))
    for i in range(nwave):
        msample = models[uinv,i]
        spectrum_posterior[:,i] = np.percentile(msample, 100.0*quantiles)

    band_models_posterior = np.percentile(
        band_models[uinv,:], 100.0*quantiles, axis=0,
    )
    temperature_posterior = np.percentile(temp[uinv], 100.0*quantiles, axis=0)
    vmr_posterior = np.percentile(vmr[uinv], 100.0*quantiles, axis=0)

    cf_posterior = cf[uinv]
    cf_median = np.median(cf_posterior, axis=0)

    # Collect spectroscopically active species
    active_species = []
    for model in pyrat.opacity.models:
        if not hasattr(model, 'species'):
            continue
        if isinstance(model.species, str):
            model_species = [model.species]
        else:
            model_species = list(model.species)
        if model.name == 'H- bound-free/free-free' and 'H-' in pyrat.atm.species:
            model_species.append('H-')
        for spec in model_species:
            if spec not in active_species:
                active_species.append(spec)

    if pyrat.od.rt_path == 'f_lambda':
        flux_units = 'W m-2 um-1'
    else:
        flux_units = 'erg s-1 cm-2 cm'

    units = {
        'depth': pyrat.obs.units,
        'flux': flux_units,
        'pressure': 'bar',
        'temperature': 'K',
        'wavelength': 'um',
    }

    outputs = {}
    if is_transmission:
        outputs['depth_posterior'] = spectrum_posterior
    elif is_emission:
        outputs['flux_posterior'] = spectrum_posterior
    elif is_eclipse:
        # Store both the eclipse depth and the corresponding planet flux
        rprs = pyrat.atm.rplanet / pyrat.atm.rstar
        fplanet = spectrum_posterior * pyrat.spec.starflux / rprs**2.0
        outputs['depth_posterior'] = spectrum_posterior
        outputs['flux_posterior'] = fplanet
        outputs['rprs'] = rprs

    outputs |= {
        'temperature_posterior': temperature_posterior,
        'vmr_posterior': vmr_posterior,
        'band_models_posterior': band_models_posterior,
        'cf_posterior_median': cf_median,
        'params_posterior' : params_posterior,
        'params_names' : np.array(pyrat.ret.pnames)[ifree],
        'params_texnames' : texnames[ifree],
        'pressure': pyrat.atm.press,
        'wl': pyrat.spec.wl,
        'band_wl': band_wl,
        'bands_wl': [band.wl for band in pyrat.obs.filters],
        'bands_response': [band.response for band in pyrat.obs.filters],
        'species': pyrat.atm.species,
        'active_species': active_species,
        'starflux': pyrat.spec.starflux,
        'quantiles': quantiles,
        'units': units,
        'path': pyrat.od.rt_path,
    }
    if pyrat.obs.data is not None:
        outputs['data'] = pyrat.obs.data
        outputs['uncert'] = pyrat.obs.uncert
    if pyrat.obs.data_hires is not None:
        outputs['data_hires'] = pyrat.obs.data_hires
        outputs['uncert_hires'] = pyrat.obs.uncert_hires

    post_file = f'{basename}{suffix}_posteriors_info.pickle'
    with open(post_file, 'wb') as handle:
        pickle.dump(outputs, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Now make some plots:
    pp.posteriors(
        post_file,
        theme=pyrat.ret.theme,
        data_color=pyrat.inputs.data_color,
        logxticks=pyrat.inputs.logxticks,
    )
    return outputs