# Source code for pyratbay.pyrat.pyrat_obj

# Copyright (c) 2021-2026 Cubillos & Blecic
# Pyrat Bay is open-source software under the GPL-2.0 license (see LICENSE)

from collections import OrderedDict
import os
import subprocess

import numpy as np
import scipy.interpolate as si
import mc3

from .. import atmosphere as pa
from .. import constants as pc
from .. import io as io
from .. import opacity as op
from .. import plots as pp
from .. import spectrum as ps
from .. import tools as pt

from .atmosphere import Atmosphere
from .observation import Observation
from .opacity import Opacity
from .retrieval import Retrieval
from .voigt import Voigt
from . import spectrum as sp
from . import extinction as ex
from . import objects as ob
from . import argum as ar


class Pyrat():
    """
    Main Pyrat object.

    Bundles the parsed inputs, spectral sampling, atmosphere, observations,
    opacity models, and retrieval setup, and provides the methods to
    evaluate forward models, run retrievals, and plot results.
    """
    def __init__(self, cfg_file, log=True, mute=False):
        """
        Parse the command-line arguments into the pyrat object.

        Parameters
        ----------
        cfg_file: String
            A Pyrat Bay configuration file.  Alternatively, an
            already-parsed inputs object (as passed by pb.run), in which
            case the log argument is used as the logging object.
        log: Bool
            Flag to save screen outputs to file (True) or not (False)
            (e.g., to prevent overwriting log of a previous run).
            When cfg_file is an already-parsed inputs object, this must
            be the logging object instead.
        mute: Bool
            If True, enforce verb to take a value of -1.

        Examples
        --------
        >>> import pyratbay as pb
        >>> pyrat = pb.run('spectrum_transmission.cfg')
        """
        # Setup time tracker:
        timer = pt.Timer()
        self.timestamps = OrderedDict()
        self.timestamps['init'] = timer.clock()

        # Parse config file inputs:
        if isinstance(cfg_file, str):
            self.inputs, self.log = pt.parse(cfg_file, log, mute)
        else:
            # If cfg file was already parsed (pb.run), sneakily use log
            # argument as the logging object instead of a boolean flag
            self.inputs = cfg_file
            self.log = log

        self.ncpu = self.inputs.ncpu
        self.runmode = self.inputs.runmode

        # TBD: Remove self.ex entirely?
        self.ex = ob.Extinction(self.inputs, self.log)
        self.od = ob.Optdepth(self.inputs, self.log)

        # Initialize wavenumber sampling:
        self.spec = sp.Spectrum(self.inputs, self.log)
        self.timestamps['spectrum'] = timer.clock()

        # Initialize Atmosphere:
        self.atm = Atmosphere(self.inputs, self.spec.wn, self.log)
        self.timestamps['atmosphere'] = timer.clock()

        self.obs = Observation(self.inputs, self.spec.wn, self.log)
        ar.check_spectrum(self)

        # Setup opacity models:
        self.opacity = Opacity(
            self.inputs,
            self.spec.wn,
            self.atm.species,
            self.atm.press,
            self.log,
            self,
        )
        self.timestamps['read opacities'] = timer.clock()

        # A line-by-line opacity model additionally needs a Voigt setup:
        if 'lbl' in self.opacity.models_type:
            i_lbl = self.opacity.models_type.index('lbl')
            lbl = self.opacity.models[i_lbl]
            self.voigt = Voigt(
                self.inputs, lbl, self.ex, self.spec, self.atm, self.log,
            )
            self.timestamps['voigt'] = timer.clock()

        # Setup more retrieval parameters:
        ar.setup(self)
        self.ret = Retrieval(
            self.inputs, self.atm, self.obs, self.opacity, self.log,
        )

    def compute_opacity(self):
        """
        Calculate opacity (cm2 molecule-1) tabulated over
        temperature, pressure, and wavenumber arrays.

        Delegates the work to extinction.compute_opacity().
        """
        ex.compute_opacity(self)

    def optical_depth(self):
        """
        Calculate the optical depth.

        Results are stored in self.od (raypath, depth, ideep, and, for
        patchy clouds, depth_clear and ideep_clear).
        """
        self.log.head('\nBegin optical-depth calculation.')

        # An opaque 'deck' model truncates the integration just below its
        # top layer; otherwise integrate down to the bottom layer.
        ibottom = next(
            (model.itop + 1
             for model in self.opacity.models if model.name == 'deck'),
            self.atm.nlayers,
        )

        extinction_cloudy = (
            self.opacity.ec_cloud if self.opacity.is_patchy else None
        )
        raypath, depth, ideep, depth_clear, ideep_clear = op.optical_depth(
            self.od.rt_path,
            self.opacity.ec,
            self.atm.radius,
            self.atm.rtop,
            ibottom,
            self.od.maxdepth,
            extinction_cloudy,
        )
        self.od.raypath = raypath
        self.od.depth = depth
        self.od.ideep = ideep
        if self.opacity.is_patchy:
            self.od.ideep_clear = ideep_clear
            self.od.depth_clear = depth_clear

        self.log.head('Optical depth done.')

    def run(self, temp=None, vmr=None, radius=None, skip=None):
        """
        Evaluate a Pyrat spectroscopic model

        Parameters
        ----------
        temp: 1D float ndarray
            Updated atmospheric temperature profile in Kelvin, of size nlayers.
        vmr: 2D float ndarray
            Updated atmospheric abundance (volume-mixing-ratio) profiles,
            of shape [nlayers, nmol].
        radius: 1D float ndarray
            Updated atmospheric altitude profile in cm, of size nlayers.
        skip: List of strings
            If set, the opacity from the model names or line-sample species
            listed here will be neglected.

        Notes
        -----
        If the temperature profile falls outside the opacity boundaries,
        or the atmosphere flags its VMRs as out of bounds, the spectrum is
        set to zero and the method returns early.
        """
        # Avoid a shared mutable default argument:
        skip = [] if skip is None else skip
        timer = pt.Timer()

        # Re-calculate atmospheric properties if required:
        self.atm.calc_profiles(temp, vmr, radius)

        out_of_bounds = self.opacity.check_temp_bounds(self.atm.temp)
        if len(out_of_bounds) != 0:
            self.log.warning(
                "Temperature values lie out of the cross-section "
                f"boundaries for: {out_of_bounds}"
            )
            self.spec.spectrum[:] = 0.0
            return

        if self.atm._out_of_bounds_vmr:
            self.spec.spectrum[:] = 0.0
            return

        # Calculate extinction coefficient:
        self.opacity.calc_extinction_coefficient(
            self.atm.temp, self.atm.radius, self.atm.d, skip=skip,
        )
        self.timestamps['extinction'] = timer.clock()

        # Calculate the optical depth:
        self.optical_depth()
        self.timestamps['odepth'] = timer.clock()

        # Calculate the spectrum:
        sp.spectrum(self)
        self.timestamps['spectrum'] = timer.clock()

        self.log.msg(
            "\nTimestamps (s):\n"
            + "\n".join(
                f"{key:10s}: {val:10.6f}"
                for key,val in self.timestamps.items()
            )
        )

    def eval(self, params, retmodel=True, skip=None):
        """
        Fitting routine for atmospheric retrieval

        Parameters
        ----------
        params: 1D float iterable
            Array of fitting parameters that define the atmosphere.
        retmodel: Bool
            Flag to include the model spectra in the return.
        skip: List of strings
            If set, the opacity from the model names or line-sample species
            listed here will be neglected.

        Returns
        -------
        spectrum: 1D float ndarray
            The output model spectra.  Returned only if retmodel=True.
        bandflux: 1D float ndarray
            The waveband-integrated spectrum values.  Set to inf when the
            model is rejected (out-of-bounds temperatures or abundances).

        If len(params) does not match the number of retrieval parameters,
        return (None, None) when retmodel=True, or None otherwise.
        """
        atm = self.atm
        ret = self.ret
        obs = self.obs

        params = np.asarray(params)
        if len(params) != ret.nparams:
            self.log.warning(
                f'The number of input fitting parameters ({len(params)}) does '
                f'not match\nthe number of required parameters ({ret.nparams})'
            )
            # Keep the return arity consistent with the retmodel flag:
            if retmodel:
                return None, None
            return None

        # Update models parameters:
        ret.params = np.copy(params)
        if ret.itemp is not None:
            ifree = ret.map_pars['temp']
            atm.tpars[ifree] = params[ret.itemp]
        if ret.imol is not None:
            for j,imol in enumerate(ret.imol):
                imodel,idx = ret.map_pars['mol'][j]
                atm.vmr_pars[imodel][idx] = params[imol]
        if ret.irad is not None:
            self.atm.rplanet = params[ret.irad][0] * pt.u(atm.runits)
        elif ret.ipress is not None:
            self.atm.refpressure = 10.0**params[ret.ipress][0]
        if ret.imass is not None:
            self.atm.mplanet = params[ret.imass][0] * pt.u(self.atm.mass_units)

        ifree = ret.map_pars['opacity']
        for j,model in enumerate(self.opacity.models):
            if ifree[j] == []:
                continue
            idx = ifree[j]
            ipar = ret.iopacity[j]
            model.pars[idx] = params[ipar]

        if ret.ipatchy is not None:
            self.opacity.fpatchy = params[ret.ipatchy][0]

        if ret.itstar is not None:
            self.atm.tstar = params[ret.itstar][0]
            self.spec.starflux = self.spec.flux_interp(self.atm.tstar)
            self.obs.bandflux_star = np.array([
                band(self.spec.starflux)
                for band in self.obs.filters
            ])

        if ret.idilut is not None:
            self.spec.f_dilution = params[ret.idilut][0]

        # Calculate atmosphere and spectrum:
        self.run(skip=skip)

        reject_flag = False
        # Turn-on reject flag if temperature is out-of-bounds:
        temp = atm.temp
        if np.any(temp < ret.tlow) or np.any(temp > ret.thigh):
            temp[:] = 0.0
            reject_flag = True
            self.log.warning(
                "Input temperature profile runs out of "
                f"boundaries ({ret.tlow:.1f}--{ret.thigh:.1f} K)"
            )
        # Check abundaces stay within bounds:
        if pa.qcapcheck(atm.vmr, ret.qcap, atm.ibulk):
            reject_flag = True
            self.log.warning(
                "The sum of trace abundances' VMRs exceeds "
                f"the cap of {ret.qcap:.3f}"
            )
        if self.atm._out_of_bounds_vmr:
            reject_flag = True

        if self.od.rt_path == 'f_lambda':
            # Convert flux from (erg s-1 cm-2 cm) to (W m-2 um-1)
            # TBD: check rplanet and distance exist
            self.spec.spectrum = (
                10.0 * self.spec.spectrum *
                (atm.rplanet/self.atm.distance * self.spec.wn*pc.um)**2
            )

        # High-resolution data
        if self.obs.ndata_hires > 0:
            self.spec.spectrum_convolved = conv_flux = ps.inst_convolution(
                self.spec.wn,
                self.spec.spectrum,
                obs.inst_resolution,
                sampling_res=self.spec.resolution,
            )
            # Radial-velocity shift
            if ret.irv is not None:
                vel_km = params[ret.irv][0]
                wn = ps.rv_shift(vel_km, wn=self.spec.wn)
            else:
                wn = self.spec.wn

            # Interpolate at data
            if obs.data_hires is not None:
                obs.bandflux_hires = si.interp1d(wn, conv_flux)(obs.wn_hires)
            if reject_flag:
                obs.bandflux_hires[:] = np.inf

            # TBD: At the moment either return hires or lowres, but should be
            # able to combine in the future
            if retmodel:
                return self.spec.spectrum, obs.bandflux_hires
            return obs.bandflux_hires

        # Band-integrate spectrum:
        obs.bandflux = self.band_integrate()

        # Instrumental offset:
        if ret.ioffset is not None:
            ifree = ret.map_pars['offset']
            obs.offset_pars[ifree] = params[ret.ioffset]
            obs.data = obs.depth.offset_data(obs.offset_pars, obs.units)

        # Uncertainty scaling:
        if ret.ierror is not None:
            ifree = ret.map_pars['error']
            obs.uncert_pars[ifree] = params[ret.ierror]
            obs.uncert = obs.depth.scale_errors(obs.uncert_pars, obs.units)

        # Invalid model:
        if not np.any(obs.bandflux):
            reject_flag = True

        # Reject this iteration if there are invalid temperatures or radii:
        if obs.bandflux is not None and reject_flag:
            obs.bandflux[:] = np.inf

        if retmodel:
            return self.spec.spectrum, obs.bandflux
        return obs.bandflux

    def retrieval(self):
        """
        Run an MCMC or nested-sampling atmospheric retrieval.

        Validates the retrieval setup, runs the sampler (MultiNest or mc3
        'snooker'), and then writes the posterior plots, the best-fit
        spectrum and atmosphere, the temperature posterior, and the
        contribution functions.  Optionally it launches post-processing
        in a background 'pbay --post' process.
        """
        ret = self.ret
        obs = self.obs
        log = self.log

        if ret.sampler is None:
            log.error(
                'Undefined retrieval algorithm (sampler). '
                f'Select from {pc.samplers}'
            )
        if ret.params is None:
            log.error(
                'Undefined retrieval fitting parameters (retrieval_params)'
            )
        if ret.pstep is None:
            log.error('Missing pstep argument, required for retrieval runs')

        if obs.data is None and obs.data_hires is None:
            log.error("Undefined transit/emission/eclipse data for retrieval")
        if obs.data is not None:
            if obs.uncert is None:
                log.error("Undefined data uncertainties")
            if obs.nfilters == 0:
                log.error("Undefined transmission filters (filters)")
        if obs.data_hires is not None:
            if obs.uncert_hires is None:
                log.error("Undefined high-resolution data uncertainties")
            if obs.nfilters_hires == 0:
                log.error("Undefined transmission filters (filters)")

        # Basename of the output files:
        basename = ret.retrieval_file
        # Create output folder if needed:
        pt.mkdir(basename)

        # Mute logging in pyrat object, but not in mc3:
        self.log = mc3.utils.Log(verb=-1, width=80)
        self.spec.specfile = None  # Avoid writing spectra during retrieval

        ifree = ret.pstep > 0
        texnames = np.array(ret.texnames)[ifree]

        # MultiNest wrapper call:
        if ret.sampler == 'multinest':
            output = pt.multinest_run(self, basename)
            # Only the root MPI process does the post processing:
            if pt.get_mpi_rank() != 0:
                return
            posterior = output['posterior']

        # mc3 MCMC wrapper call:
        elif ret.sampler == 'snooker':
            if ret.nsamples is None:
                log.error('Undefined number of retrieval samples (nsamples)')
            if ret.burnin is None:
                log.error('Undefined number of retrieval burn-in samples (burnin)')
            if ret.nchains is None:
                log.error('Undefined number of retrieval parallel chains (nchains)')

            # TBD: Fix resuming
            ret.resume = False
            retmodel = False  # Return only the band-integrated spectrum
            # Run MCMC:
            output = mc3.sample(
                data=self.obs.data,
                uncert=self.obs.uncert,
                func=self.eval,
                indparams=[retmodel],
                params=ret.params,
                pmin=ret.pmin,
                pmax=ret.pmax,
                pstep=ret.pstep,
                prior=ret.prior,
                priorlow=ret.priorlow,
                priorup=ret.priorup,
                sampler=ret.sampler,
                nsamples=ret.nsamples,
                nchains=ret.nchains,
                burnin=ret.burnin,
                thinning=ret.thinning,
                grtest=True,
                grbreak=ret.grbreak,
                grnmin=ret.grnmin,
                log=log,
                ncpu=self.ncpu,
                plots=False,
                showbp=True,
                theme=ret.theme,
                pnames=ret.pnames,
                texnames=ret.texnames,
                resume=ret.resume,
                savefile=f'{basename}.npz',
            )
            if output is None:
                log.error("Error in mc3")

            posterior, zchain, zmask = mc3.utils.burn(output)

            # Trace plot:
            savefile = f'{basename}_posterior_trace.png'
            mc3.plots.trace(
                posterior, zchain=zchain, burnin=ret.burnin, pnames=texnames,
                color=ret.theme.color, savefile=savefile,
            )
            log.msg(savefile, indent=2)

        post = mc3.plots.Posterior(
            posterior,
            pnames=texnames,
            theme=ret.theme,
            bestp=output['bestp'][ifree],
            statistics=ret.statistics,
            show_estimates=True,  # TBD: get from cfg?
        )

        # Pairwise posteriors plots:
        savefile = f'{basename}_posterior_pairwise.png'
        post.plot(savefile=savefile)
        log.msg(savefile, indent=2)

        # Histogram plots:
        savefile = f'{basename}_posterior_marginal.png'
        post.plot_histogram(savefile=savefile)
        log.msg(savefile, indent=2)

        # Post processing (can be done directly from posterior outputs)
        ret.bestp = bestp = output['bestp']
        ret.posterior = posterior

        # Best-fitting model:
        self.spec.specfile = f"{basename}_bestfit_spectrum.dat"
        ret.spec_best, ret.bestbandflux = self.eval(bestp)
        filename = f'{basename}_bestfit_spectrum.png'
        self.plot_spectrum(spec='best', filename=filename)

        atm = self.atm
        header = "# Retrieval best-fitting atmospheric model.\n\n"
        bestatm = f"{basename}_bestfit_atmosphere.atm"
        io.write_atm(
            bestatm, atm.press, atm.temp, atm.species,
            atm.vmr, radius=atm.radius,
            punits=atm.punits, runits=atm.runits, header=header,
        )

        # Temperature profiles.  A temperature posterior only exists when
        # the temperature is a fitting parameter (ret.itemp is not None);
        # without this guard zip(ret.itemp, ...) raises a TypeError.
        if atm.temp_model is not None and ret.itemp is not None:
            # Work on a copy so atm.tpars is not modified through an alias:
            tparams = np.copy(atm.tpars)
            tparams[ret.map_pars['temp']] = bestp[ret.itemp]
            ret.temp_best = atm.temp_model(tparams)

            nsamples, nfree = np.shape(posterior)
            t_posterior = np.tile(tparams, (nsamples,1))
            # Map temperature free parameters from posterior to tparams
            # (posterior columns hold only the free parameters, in order):
            free_idx = np.where(self.ret.pstep>0)[0]
            post_column = {ipar: k for k, ipar in enumerate(free_idx)}
            for j, imap in zip(ret.itemp, ret.map_pars['temp']):
                if j in post_column:
                    t_posterior[:,imap] = posterior[:,post_column[j]]
            tpost = pa.temperature_posterior(t_posterior, atm.temp_model)
            ret.temp_median = tpost[0]
            ret.temp_post_boundaries = tpost[1:]
            self.plot_temperature(
                filename=f'{basename}_bestfit_temperature.png',
            )

        # Contribution or transmittance
        is_transmission = self.od.rt_path in pc.transmission_rt
        path = 'transit' if is_transmission else 'emission'
        # Prefer the high-resolution bands when present, consistent with
        # band_contribution() so wavelengths and contributions match:
        if self.obs.nfilters_hires > 0:
            band_wl = 1.0/(self.obs.wn_hires*pc.um)
        elif self.obs.nfilters > 0:
            band_wl = 1.0/(self.obs.bandwn*pc.um)
        band_cf = self.band_contribution()

        filename = f'{basename}_bestfit_contributions.png'
        pp.contribution(band_cf, band_wl, path, atm.press, filename)
        self.log = log  # Un-mute

        root_output = os.path.split(basename)[0]
        log.msg(f"\nOutput retrieval files located at {root_output}")

        if self.inputs.post_processing:
            os.environ['PBAY_NO_MPI'] = "1"
            # Launch in the background; an argument list (no shell) keeps
            # config paths with spaces/metacharacters intact.
            subprocess.Popen(
                ['pbay', '--post', str(self.inputs.config_file)],
            )

    def radiative_equilibrium(
            self, nsamples=None, continue_run=False, convection=False,
        ):
        """
        Compute radiative-thermochemical equilibrium atmosphere.
        Currently there is no convergence criteria implemented,
        some 100--300 iterations are typically sufficient to converge
        to a stable temperature-profile solution.

        Parameters
        ----------
        nsamples: Integer
            Number of radiative-equilibrium iterations to run.
            Defaults to self.inputs.nsamples.
        continue_run: Bool
            If True, continue from a previous radiative-equilibrium run.
        convection: Bool
            Convection flag, forwarded to spectrum.radiative_equilibrium().
            Previous docs stated: "If True, skip convective flux
            calculation".  NOTE(review): confirm the sense of this flag.

        Returns
        -------
        There are no returned values, but this method updates the
        temperature profile (self.atm.temp) and abundances (self.atm.vmr)
        with the values from the last radiative-equilibrium iteration.

        This method also defines self.atm.radeq_temps, a 2D array
        containing all temperature-profile iterations.  Outputs are
        written to '{basename}.atm', '{basename}.npz', and
        '{basename}.dat', where basename is self.spec.specfile without
        its extension.
        """
        atm = self.atm

        if nsamples is None:
            nsamples = self.inputs.nsamples

        if self.spec.specfile is None:
            self.log.error(
                'Undefined output spectrum file (specfile), required for '
                'radiative-equilibrium runs'
            )
        basename = os.path.splitext(self.spec.specfile)[0]

        # Save state that is temporarily overridden while iterating:
        tmp_verb = self.log.verb
        rt_path = self.od.rt_path
        specfile = self.spec.specfile

        # No outputs while iterating
        self.log.verb = 0
        self.spec.specfile = None
        # Enforce two-stream RT:
        self.od.rt_path = 'emission_two_stream'
        try:
            tmin = np.amax(list(self.opacity.tmin.values()))
            tmax = np.amin(list(self.opacity.tmax.values()))

            # Initial temperature scale factor
            if not hasattr(atm, '_dt_scale') or not continue_run:
                atm._dt_scale = np.tile(1.0e5, atm.nlayers)

            if hasattr(atm, 'radeq_temps') and continue_run:
                radeq_temps = atm.radeq_temps
            else:
                radeq_temps = np.atleast_2d(atm.temp)

            print("\nRadiative-thermochemical equilibrium calculation:")
            radeq_temps = ps.radiative_equilibrium(
                atm.press, radeq_temps, nsamples,
                atm.chem_model,
                self.run,
                self.spec.wn,
                self.spec,
                atm,
                convection,
                tmin, tmax,
            )
        except BaseException:
            # Leave the object as it was found if the solver fails:
            self.spec.specfile = specfile
            raise
        finally:
            self.od.rt_path = rt_path
            self.log.verb = tmp_verb

        # Update last temperature iteration and save to file:
        self.atm.radeq_temps = radeq_temps
        atm.temp = radeq_temps[-1]
        io.write_atm(
            f'{basename}.atm',
            atm.press, atm.temp, atm.species, atm.vmr,
            punits="bar",
            header="# Radiative-thermochemical equilibrium profile.\n\n",
        )
        np.savez(f'{basename}.npz', pressure=atm.press, temps=radeq_temps)

        self.spec.specfile = f'{basename}.dat'
        spec_type = 'emission'
        io.write_spectrum(
            self.spec.wl, self.spec.spectrum, self.spec.specfile, spec_type,
        )

    def band_integrate(self):
        """
        Band-integrate transmission spectrum (transit) or planet-to-star
        flux ratio (eclipse) over transmission band passes.

        Returns
        -------
        bandflux: 1D float ndarray
            The band-integrated values (also stored in self.obs.bandflux),
            or None if there are no filters.
        """
        if self.obs.filters is None:
            return None

        is_transmission = self.od.rt_path in pc.transmission_rt
        spectrum = self.spec.spectrum if is_transmission else self.spec.fplanet
        bandflux = np.array([band(spectrum) for band in self.obs.filters])

        # Eclipse: planet flux relative to the band-integrated stellar flux
        if self.od.rt_path in pc.eclipse_rt:
            rprs = self.atm.rplanet/self.atm.rstar
            bandflux *= rprs**2.0 / self.obs.bandflux_star

        self.obs.bandflux = bandflux
        return self.obs.bandflux

    def band_contribution(self):
        """
        Compute contribution functions or transmittance at each band.

        Uses the high-resolution filters when available, otherwise the
        low-resolution filters.  For transmission geometry the result is
        band-integrated transmittance (area-weighted between cloudy and
        clear columns for patchy clouds); otherwise it is the
        band-integrated emission contribution function.
        """
        if self.obs.nfilters_hires != 0:
            bands = self.obs.filters_hires
        else:
            bands = self.obs.filters

        bands_idx = [band.idx for band in bands]
        responses = [band.response for band in bands]

        if self.od.rt_path in pc.transmission_rt:
            contrib = ps.transmittance(self.od.depth, self.od.ideep)
            if self.opacity.is_patchy:
                patchy = self.opacity.fpatchy
                depth = self.od.depth_clear
                contrib_clear = ps.transmittance(depth, self.od.ideep_clear)
                contrib = patchy*contrib + (1.0-patchy)*contrib_clear
        else:
            contrib = ps.contribution_function(
                self.od.depth, self.atm.press, self.od.B,
            )

        return ps.band_cf(contrib, responses, self.spec.wn, bands_idx)

    def get_ec(self, layer):
        """
        Extract extinction-coefficient contribution (in cm-1) from each
        component of the atmosphere at the requested layer.

        Parameters
        ----------
        layer: Integer
            The index of the atmospheric layer where to extract the EC.

        Returns
        -------
        ec: 2D float ndarray
            An array of shape [ncomponents, nwave] with the EC spectra
            (in cm-1) from each component of the atmosphere.
            None if there are no opacity models.
        label: List of strings
            The names of each atmospheric component that contributed to EC.
            An empty list if there are no opacity models.
        """
        if len(self.opacity.models) == 0:
            return None, []
        return self.opacity.get_ec(self.atm.temp, self.atm.d, layer)

    def plot_spectrum(self, spec='model', **kwargs):
        """
        Plot spectrum.

        Parameters
        ----------
        spec: String
            Flag indicating which model to plot.  By default plot the
            latest evaluated model (spec='model').  Another option is
            'best', to plot the posterior best-fit (after a retrieval
            posterior run).  Any other value plots nothing and returns None.
        kwargs: dict
            Dictionary of arguments to pass into plots.spectrum().
            See help(pyratbay.plots.spectrum).  These override the
            defaults set here.

        Returns
        -------
        ax: AxesSubplot instance
            The matplotlib Axes of the figure.
        """
        obs = self.obs
        args = {
            'logxticks': self.inputs.logxticks,
            'yran': self.inputs.yran,
            'theme': self.ret._default_theme,
            'data_color': self.inputs.data_color,
        }

        is_hires = obs.nfilters_hires > 0
        if is_hires:
            band_wl = np.array([band.wl0 for band in obs.filters_hires])
            args['wavelength'] = band_wl
            args['data'] = obs.data_hires
            args['uncert'] = obs.uncert_hires
            args['bands_wl0'] = band_wl
            args['resolution'] = None
            args['marker'] = '.'
            args['data_front'] = False
        else:
            args['wavelength'] = self.spec.wl
            args['data'] = obs.data
            args['uncert'] = obs.uncert
            args['bands_wl0'] = [band.wl0 for band in obs.filters]
            args['bands_wl'] = [band.wl for band in obs.filters]
            args['bands_response'] = [band.response for band in obs.filters]
            args['bands_flux'] = obs.bandflux
            if self.obs.inst_resolution is not None:
                args['resolution'] = self.obs.inst_resolution
            args['marker'] = 'o'
            args['data_front'] = True

        if spec == 'model':
            args['label'] = 'model'
            if is_hires:
                args['spectrum'] = obs.bandflux_hires
            else:
                args['spectrum'] = self.spec.spectrum
                args['bands_flux'] = obs.bandflux
        elif spec == 'best':
            args['label'] = 'best-fit model'
            if is_hires:
                args['spectrum'] = self.ret.bestbandflux
            else:
                args['spectrum'] = self.ret.spec_best
                args['bands_flux'] = self.ret.bestbandflux
        else:
            return

        if self.od.rt_path == 'f_lambda':
            args['rt_path'] = 'f_lambda'
        elif self.od.rt_path in pc.transmission_rt:
            args['rt_path'] = 'transit'
        elif self.od.rt_path in pc.eclipse_rt:
            args['rt_path'] = 'eclipse'
        else:
            args['rt_path'] = 'emission'

        # kwargs can overwrite any of the previous values:
        args.update(kwargs)
        ax = pp.spectrum(**args)
        return ax

    def plot_temperature(self, **kwargs):
        """
        Plot temperature profile.
        If self.ret.posterior exists, plot the best fit, median, and
        the '1sigma/2sigma' boundaries of the temperature posterior
        distribution.

        Parameters
        ----------
        kwargs: dict
            Dictionary of arguments to pass into plots.temperature().
            See help(pyratbay.plots.temperature).  Note that 'pressure'
            and 'theme' (and 'profiles', 'labels', 'bounds' when a
            posterior exists) are always set by this method.

        Returns
        -------
        ax: AxesSubplot instance
            The matplotlib Axes of the figure.
        """
        kwargs['pressure'] = self.atm.press
        kwargs['theme'] = self.ret.theme
        if self.ret.posterior is None:
            kwargs['profiles'] = [self.atm.temp]
        else:
            kwargs['profiles'] = [self.ret.temp_median, self.ret.temp_best]
            kwargs['labels'] = ['median', 'best-fit']
            kwargs['bounds'] = self.ret.temp_post_boundaries
        ax = pp.temperature(**kwargs)
        return ax

    def __str__(self):
        # Describe the spectral sampling by whichever step was configured:
        if self.spec.resolution is not None:
            wave = f"R={self.spec.resolution:.1f}"
        elif self.spec.wlstep is not None:
            wave = f'delta-wl={self.spec.wlstep:.2f}'
        else:
            wave = f"delta-wn={self.spec.wnstep:.3f} cm-1"

        opacities = []
        for model_type, model in zip(
                self.opacity.models_type, self.opacity.models):
            if model_type in ['line_sample', 'lbl']:
                opacities += model.species.tolist()
            elif model_type in ['alkali']:
                opacities.append(model.species)
            else:
                opacities.append(model.name)

        pmin = self.atm.press[ 0]
        pmax = self.atm.press[-1]
        wlmin = 1.0/(self.spec.wn[-1]*pc.um)
        wlmax = 1.0/(self.spec.wn[ 0]*pc.um)
        return (
            "Pyrat atmospheric model\n"
            f"configuration file: '{self.inputs.config_file}'\n"
            f"Pressure profile: {pmin:.2e} -- {pmax:.2e} bar "
            f"({self.atm.nlayers:d} layers)\n"
            f"Wavelength range: {wlmin:.2f} -- {wlmax:.2f} um "
            f"({self.spec.nwave:d} samples, {wave})\n"
            f"Composition:\n  {self.atm.species}\n"
            f"Opacity sources:\n  {opacities}"
        )