# Copyright (c) 2021-2022 Patricio Cubillos
# Pyrat Bay is open-source software under the GPL-2.0 license (see LICENSE)
# Public names exported by a star-import of this module:
__all__ = [
    "alphatize",
    "spectrum",
    "contribution",
    "temperature",
    "abundance",
    "default_colors",
]
from itertools import cycle
from cycler import cycler, Cycler
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import is_color_like, to_rgb
import numpy as np
import scipy.interpolate as si
from scipy.ndimage import gaussian_filter1d as gaussf
from .. import constants as pc
from .. import tools as pt
# Default matplotlib color for each species name.  The insertion order
# matters: temperature() cycles through the values in this order when
# no colors are requested.
default_colors = {
    "H2O": "navy",
    "CO2": "red",
    "CO": "limegreen",
    "CH4": "orange",
    "H2": "deepskyblue",
    "He": "seagreen",
    "HCN": "0.7",
    "NH3": "magenta",
    "C2H2": "brown",
    "C2H4": "pink",
    "N2": "gold",
    "H": "olive",
    "TiO": "black",
    "VO": "peru",
    "Na": "darkviolet",
    "K": "cornflowerblue",
}
def alphatize(colors, alpha, bg='w'):
    """
    Blend one or more colors with a background color, returning the
    opaque RGB values that look like the input drawn with the given alpha.

    Parameters
    ----------
    colors: color or iterable of colors
        A single matplotlib color, or an iterable of them.
    alpha: Float
        Weight of the foreground color in the blend.
    bg: color
        Background color that receives the remaining (1-alpha) weight.

    Returns
    -------
    rgb: RGB or list of RGB color arrays
        A single RGB array if colors was a single color, otherwise
        a list with one RGB array per input color.

    Examples
    --------
    >>> import pyratbay.plots as pp
    >>> pp.alphatize('r', 0.5)
    array([1. , 0.5, 0.5])
    >>> pp.alphatize(['r', 'b'], 0.8)
    [array([1. , 0.2, 0.2]), array([0.2, 0.2, 1. ])]
    """
    # A lone color is wrapped for processing and unwrapped on return:
    single = is_color_like(colors)
    foregrounds = [
        np.array(to_rgb(color))
        for color in ([colors] if single else colors)
    ]
    background = np.array(to_rgb(bg))

    # https://matplotlib.org/tutorials/colors/colors.html
    blended = [
        (1.0-alpha) * background + alpha*foreground
        for foreground in foregrounds
    ]
    return blended[0] if single else blended
def spectrum(
        spectrum, wavelength, rt_path,
        data=None, uncert=None, bandwl=None, bandflux=None,
        bandtrans=None, bandidx=None,
        starflux=None, rprs=None, label='model', bounds=None,
        logxticks=None,
        gaussbin=2.0, yran=None, filename=None, fignum=501, axis=None):
    """
    Plot a transmission or emission model spectrum with (optional) data
    points with error bars and band-integrated model.

    Parameters
    ----------
    spectrum: 1D float ndarray
        Planetary spectrum evaluated at wavelength.
    wavelength: 1D float ndarray
        The wavelength of the model in microns.
    rt_path: String
        Radiative-transfer observing geometry.  Must be one of
        'transit', 'eclipse', or 'emission'.
    data: 1D float ndarray
        Observing data points at each bandwl.
    uncert: 1D float ndarray
        Uncertainties of the data points.
    bandwl: 1D float ndarray
        The mean wavelength for each band/data point.
    bandflux: 1D float ndarray
        Band-integrated model spectrum at each bandwl.
    bandtrans: List of 1D float ndarrays
        Transmission curve for each band.
    bandidx: List of 1D float ndarrays.
        The indices in wavelength for each bandtrans.
    starflux: 1D float ndarray
        Stellar spectrum evaluated at wavelength (used for the
        'eclipse' geometry, where it must be provided).
    rprs: Float
        Planet-to-star radius ratio (used for the 'eclipse' geometry,
        where it must be provided).
    label: String
        Label for spectrum curve.
    bounds: Tuple
        Tuple with -2, -1, +1, and, +2 sigma boundaries of spectrum.
        If not None, plot shaded area between +/-1sigma and +/-2sigma
        boundaries.
    logxticks: 1D float ndarray
        If not None, switch the X-axis scale from linear to log, and set
        the X-axis ticks at the locations given by logxticks.
    gaussbin: Float
        Standard deviation for Gaussian-kernel smoothing (in number of
        samples).
    yran: 1D float ndarray
        Figure's Y-axis boundaries.
    filename: String
        If not None, save figure to filename.
    fignum: Integer
        Figure number (ignored if axis is not None).
    axis: AxesSubplot instance
        If not None, draw everything into this matplotlib Axes instead
        of creating a new figure.

    Returns
    -------
    ax: AxesSubplot instance
        The matplotlib Axes of the figure.

    Raises
    ------
    ValueError
        If rt_path is not one of 'transit', 'eclipse', or 'emission'.
    """
    # Plotting setup:
    fs = 14.0
    ms = 6.0
    lw = 1.25

    if axis is None:
        plt.figure(fignum, (8, 5))
        plt.clf()
        ax = plt.subplot(111)
    else:
        ax = axis
    # All drawing below goes through ax/fig (not pyplot's current axes),
    # such that a user-supplied axis is honored even when it is not the
    # currently active one:
    fig = ax.figure

    spec_kw = {
        'label': label,
        'color': 'orange' if bounds is None else 'orangered',
    }

    # Setup according to geometry:
    if rt_path == 'emission':
        fscale = 1.0
        ylabel = r'$F_{\rm p}$ (erg s$^{-1}$ cm$^{-2}$ cm)'
    elif rt_path == 'eclipse':
        # Convert planetary flux into planet-to-star flux ratio
        # (requires both starflux and rprs):
        spectrum = spectrum/starflux * rprs**2.0
        if bounds is not None:
            bounds = [bound/starflux * rprs**2.0 for bound in bounds]
        fscale = 1.0 / pc.ppt
        ylabel = r'$F_{\rm p}/F_{\rm s}\ (ppt)$'
    elif rt_path == 'transit':
        fscale = 1.0 / pc.percent
        ylabel = r'$(R_{\rm p}/R_{\rm s})^2$ (%)'
    else:
        raise ValueError(
            f"Invalid radiative-transfer geometry '{rt_path}'. "
            "Select from: 'transit', 'eclipse', or 'emission'."
        )
    ax.set_ylabel(ylabel, fontsize=fs)

    gmodel = gaussf(spectrum, gaussbin)
    if bounds is not None:
        gbounds = [gaussf(bound, gaussbin) for bound in bounds]
        # Outer (+/-2 sigma) region first, inner (+/-1 sigma) on top:
        ax.fill_between(
            wavelength, fscale*gbounds[0], fscale*gbounds[3],
            facecolor='gold', edgecolor='none',
        )
        ax.fill_between(
            wavelength, fscale*gbounds[1], fscale*gbounds[2],
            facecolor='orange', edgecolor='none',
        )

    # Plot model:
    ax.plot(wavelength, gmodel*fscale, lw=lw, **spec_kw)
    # Plot band-integrated model:
    if bandflux is not None and bandwl is not None:
        ax.plot(
            bandwl, bandflux*fscale, 'o', ms=ms, color='tomato',
            mec='maroon', mew=lw,
        )
    # Plot data:
    if data is not None and uncert is not None and bandwl is not None:
        ax.errorbar(
            bandwl, data*fscale, uncert*fscale, fmt='o', label='data',
            color='blue', ms=ms, elinewidth=lw, capthick=lw, zorder=3,
        )

    if yran is not None:
        ax.set_ylim(np.array(yran))
    # Freeze the current Y range, the filter curves are scaled to it:
    yran = ax.get_ylim()

    # Transmission filters:
    if bandtrans is not None and bandidx is not None:
        bandh = 0.06*(yran[1] - yran[0])
        for btrans, bidx in zip(bandtrans, bandidx):
            btrans = bandh * btrans/np.amax(btrans)
            ax.plot(wavelength[bidx], yran[0]+btrans, color='0.4', zorder=-10)
        # Undo any autoscaling triggered by the filter curves:
        ax.set_ylim(yran)

    if logxticks is not None:
        ax.set_xscale('log')
        ax.xaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())
        ax.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
        ax.set_xticks(logxticks)

    ax.tick_params(
        which='both', right=True, top=True, direction='in', labelsize=fs-2,
    )
    ax.set_xlabel('Wavelength (um)', fontsize=fs)
    ax.legend(loc='best', numpoints=1, fontsize=fs-1)
    ax.set_xlim(np.amin(wavelength), np.amax(wavelength))
    fig.tight_layout()

    if filename is not None:
        fig.savefig(filename)
    return ax
def contribution(
        contrib_func, wl, rt_path, pressure, radius, rtop=0,
        filename=None, filters=None, fignum=-21,
    ):
    """
    Plot the band-integrated normalized contribution functions
    (emission) or transmittance (transmission).

    Parameters
    ----------
    contrib_func: 2D float ndarray
        Band-integrated contribution functions, indexed by the code
        as [nlayers, nfilters].
    wl: 1D float ndarray
        Mean wavelength of the bands in microns.
    rt_path: String
        Radiative-transfer observing geometry (emission or transit).
    pressure: 1D float ndarray
        Layer's pressure array (barye units).
    radius: 1D float ndarray
        Layer's impact parameter array (cm units).
    rtop: Integer
        Index of topmost valid layer.
    filename: String
        Filename of the output figure.
    filters: 1D string ndarray
        Name of the filter bands (optional).
    fignum: Integer
        Figure number.

    Returns
    -------
    ax: AxesSubplot instance
        The matplotlib Axes of the figure.  If rt_path is not a valid
        geometry, a message is printed and None is returned.

    Notes
    -----
    - The dashed lines denote the 0.16 and 0.84 percentiles of the
      cumulative contribution function or the transmittance (i.e.,
      the boundaries of the central 68% of the respective curves).
    - If there are more than 80 filters, this code will thin the
      displayed filter names.
    """
    nfilters = len(wl)
    nlayers = len(pressure)

    # Sort the bands by increasing wavelength:
    wlsort = np.argsort(wl)
    wl = wl[wlsort]
    contrib_func = contrib_func[:,wlsort]
    if filters is not None:
        filters = [filters[i] for i in wlsort]

    # NOTE(review): press/rad are trimmed by rtop, but nlayers and the
    # contrib_func rows are not; the boolean masks below would then have
    # mismatched lengths for rtop > 0 -- confirm the intended usage.
    press = pressure[rtop:]/pc.bar
    rad = radius[rtop:]/pc.km
    zz = contrib_func/np.amax(contrib_func)

    is_emission = rt_path in pc.emission_rt
    is_transit = rt_path in pc.transmission_rt

    if is_emission:
        yran = np.amax(np.log10(press)), np.amin(np.log10(press))
        xlabel = 'contribution function'
        ylabel = ''
        yright = 0.9
        cbtop = 0.5
    elif is_transit:
        yran = np.amin(rad), np.amax(rad)
        xlabel = 'transmittance'
        ylabel = 'Impact parameter (km)'
        yright = 0.84
        cbtop = 0.8
    else:
        rt_paths = pc.rt_paths
        print(f"Invalid radiative-transfer geometry. Select from: {rt_paths}.")
        return

    fs = 12
    colors = np.asarray(np.linspace(0, 255, nfilters), int)
    # 68% percentile boundaries of the central cumulative function:
    lo = 0.5*(1-0.683)
    hi = 1.0 - lo
    # Filter fontsize and thinning:
    ffs = 8.0 + (nfilters<50) + (nfilters<65)
    thin = (nfilters>80) + (nfilters>125) + (nfilters<100) + nfilters//100
    # Exponent applied to the alpha channel: 1.0 for transit, 0.5 otherwise
    alpha_power = 0.5 + 0.5*is_transit

    # Colormap and percentile limits:
    z = np.empty((nfilters, nlayers, 4), dtype=float)
    plo = np.zeros(nfilters+1)
    phi = np.zeros(nfilters+1)
    for i in range(nfilters):
        # One color per band, with the curve encoded in the alpha channel:
        z[i] = plt.cm.rainbow(colors[i])
        z[i,:,-1] = zz[:,i]**alpha_power
        if is_emission:
            cumul = np.cumsum(zz[:,i])/np.sum(zz[:,i])
            plo[i], phi[i] = press[cumul>lo][0], press[cumul>hi][0]
        elif is_transit:
            plo[i], phi[i] = press[zz[:,i]<lo][0], press[zz[:,i]<hi][0]
    # Repeat the last value so the 'steps-post' lines span the last band:
    plo[-1] = plo[-2]
    phi[-1] = phi[-2]

    fig = plt.figure(fignum, (8.5, 5))
    plt.clf()
    plt.subplots_adjust(0.105, 0.10, yright, 0.95)
    ax = plt.subplot(111)
    pax = ax.twinx()

    ax.imshow(
        z.swapaxes(0,1), aspect='auto',
        extent=[0, nfilters, yran[0], yran[1]],
        origin='upper', interpolation='nearest',
    )
    if is_emission:
        # Only the pressure axis is shown, placed on the left side:
        ax.yaxis.set_visible(False)
        pax.spines['left'].set_visible(True)
        pax.yaxis.set_label_position('left')
        pax.yaxis.set_ticks_position('left')
    elif is_transit:
        # Setting the right radius tick labels requires some sorcery:
        fig.canvas.draw()
        ylab = [l.get_text() for l in ax.get_yticklabels()]
        rint = si.interp1d(rad, press, bounds_error=False)
        pticks = rint(ax.get_yticks())
        bounds = np.isfinite(pticks)
        pint = si.interp1d(
            press, np.linspace(yran[1], yran[0], nlayers), bounds_error=False,
        )
        ax.set_yticks(pint(pticks[bounds]))
        ax.set_yticklabels(np.array(ylab)[bounds])

    pax.plot(plo, drawstyle='steps-post', color='0.25', lw=0.75, ls='--')
    pax.plot(phi, drawstyle='steps-post', color='0.25', lw=0.75, ls='--')
    pax.set_ylim(np.amax(press), np.amin(press))
    pax.set_yscale('log')
    pax.set_ylabel(r'Pressure (bar)', fontsize=fs)

    ax.set_xlim(0, nfilters)
    ax.set_ylim(yran)
    ax.set_xticklabels([])
    ax.set_ylabel(ylabel, fontsize=fs)
    ax.set_xlabel(f'Band-averaged {xlabel}', fontsize=fs)

    # Print filter names/wavelengths:
    for i in range(0, nfilters-thin//2, thin):
        fname = f' {wl[i]:5.2f} um '
        # Prepend the filter name when available:
        if filters is not None:
            fname = str(filters[i]) + ' @' + fname
        ax.text(
            i+0.1, yran[1], fname,
            rotation=90, ha='left', va='top', fontsize=ffs,
        )

    # Color bar (black with an alpha ramp from 0 to cbtop):
    cbar = plt.axes([0.925, 0.10, 0.015, 0.85])
    cz = np.zeros((100, 2, 4), dtype=float)
    alpha_ramp = np.linspace(0.0, cbtop, 100)**alpha_power
    cz[:,0,3] = alpha_ramp
    cz[:,1,3] = alpha_ramp
    cbar.imshow(
        cz, aspect='auto', extent=[0, 1, 0, 1],
        origin='lower', interpolation='nearest',
    )
    if is_transit:
        cbar.axhline(0.1585, color='k', lw=1.0, dashes=(2.5,1))
        cbar.axhline(0.8415, color='w', lw=1.0, dashes=(2.5,1))
    cbar.spines['right'].set_visible(True)
    cbar.yaxis.set_label_position('right')
    cbar.yaxis.set_ticks_position('right')
    cbar.set_ylabel(xlabel.capitalize(), fontsize=fs)
    cbar.xaxis.set_visible(False)
    fig.canvas.draw()

    if filename is not None:
        plt.savefig(filename)
    return ax
def temperature(
        pressure, profiles=None, labels=None, colors=None,
        bounds=None, punits='bar', ax=None, filename=None,
        theme='blue', alpha=(0.8,0.6), fs=13, lw=2.0, fignum=504,
    ):
    """
    Plot temperature profiles.

    Parameters
    ----------
    pressure: 1D float ndarray
        The atmospheric pressure profile in barye.
    profiles: iterable of 1D float ndarrays
        Temperature profiles to plot.  A single 1D profile with the
        same length as pressure is also accepted.
    labels: 1D string iterable
        Labels for temperature profiles.  A legend is drawn only when
        labels is not None.
    colors: 1D string iterable.
        Colors for temperature profiles.  If None, cycle through the
        values of default_colors.
    bounds: Tuple
        Tuple with -1sigma, +1sigma, -2sigma, and +2sigma temperature
        boundaries.
        If not None, plot shaded area between +/-1sigma and +/-2sigma
        boundaries.
    punits: String
        Pressure units for output plot (input units are always barye).
    ax: AxesSubplot instance
        If not None, plot into the given axis.
    filename: String
        If not None, save plot to given file name.
    theme: String
        The histograms' color theme for bounds regions.
        Only 'blue' and 'orange' themes are valid at the moment.
        Alternatively, provide a two-element iterable to provide the colors.
    alpha: 2-element float iterable
        Alpha transparency for bounds regions.
    fs: Float
        Labels font sizes.
    lw: Float
        Lines width.
    fignum: Integer
        Figure's number (ignored if axis is not None).

    Returns
    -------
    ax: AxesSubplot instance
        The matplotlib Axes of the figure.
    """
    press = pressure / pt.u(punits)

    if theme == 'blue':
        col1, col2 = 'royalblue', 'royalblue'
    elif theme == 'orange':
        col1, col2 = 'orange', 'gold'
    else:  # Custom pair of colors:
        col1, col2 = theme

    # alpha != 0 does not work for ps/eps figures, emulate the
    # transparency with opaque blended colors instead:
    alpha1, alpha2 = alpha
    if filename is not None and filename.endswith('ps'):
        fc2 = alphatize(col2, alpha2, 'white')
        fc1 = alphatize(col1, alpha1, fc2)
        alpha1 = alpha2 = 1.0
    else:
        fc1, fc2 = col1, col2

    if profiles is None:
        profiles = []
    # Promote a single profile into a list of profiles:
    if np.ndim(profiles) == 1 and len(profiles) == len(pressure):
        profiles = [profiles]

    if labels is None:
        _labels = [None for _ in profiles]
    else:
        _labels = labels

    if colors is None:
        c = cycle(default_colors.values())
        colors = [next(c) for _ in profiles]

    # Only adjust the layout of figures created here:
    tighten = ax is None
    if ax is None:
        plt.figure(fignum, (7,5))
        plt.clf()
        ax = plt.subplot(111)

    # All drawing below goes through ax (not pyplot's current axes), such
    # that a user-supplied ax is honored even when it is not the active one.
    if bounds is not None and len(bounds) == 4:
        low2, high2 = bounds[2:4]
        ax.fill_betweenx(
            press, low2, high2, facecolor=fc2, edgecolor='none', alpha=alpha2,
        )
    if bounds is not None and len(bounds) >= 2:
        low1, high1 = bounds[0:2]
        ax.fill_betweenx(
            press, low1, high1, facecolor=fc1, edgecolor='none', alpha=alpha1,
        )

    for profile, color, label in zip(profiles, colors, _labels):
        if isinstance(color, str):
            # Strings keep being passed as a plot format string:
            ax.plot(profile, press, color, lw=lw, label=label)
        else:
            # Non-string colors (e.g., RGB tuples) are not valid formats:
            ax.plot(profile, press, color=color, lw=lw, label=label)

    ax.set_ylim(np.amax(press), np.amin(press))
    ax.set_yscale('log')
    ax.set_xlabel('Temperature (K)', fontsize=fs)
    ax.set_ylabel(f'Pressure ({punits})', fontsize=fs)
    ax.tick_params(labelsize=fs-2)
    if labels is not None:
        ax.legend(loc='best', fontsize=fs-2)
    if tighten:
        ax.figure.tight_layout()

    if filename is not None:
        ax.figure.savefig(filename)
    return ax
def abundance(
        vol_mix_ratios, pressure, species,
        highlight=None, xlim=None, punits='bar',
        colors=None, dashes=None, filename=None,
        lw=2.0, fignum=505, fs=13, legend_fs=None, ax=None,
    ):
    """
    Plot atmospheric volume-mixing-ratio abundances.

    Parameters
    ----------
    vol_mix_ratios: 2D float ndarray
        Atmospheric volume mixing ratios to plot [nlayers,nspecies].
    pressure: 1D float ndarray
        Atmospheric pressure [nlayers].  The values are divided by
        pt.u(punits) before plotting (presumably the input is in
        barye -- TODO confirm).
    species: 1D string iterable
        Atmospheric species names [nspecies].
    highlight: 1D string iterable
        List of species names to highlight.  Non-highlighted species are
        plotted with alpha=0.4, below the highlighted species, and are
        not considered to set the default xlim (e.g., might not be shown
        if their abundances are too low).
        If None, all input species are highlighted.
    xlim: 2-element float iterable
        Volume mixing ratio plotting boundaries.
    punits: String
        Pressure units.
    colors: 1D string iterable
        List of colors to use.
        - If colors == 'default', use pyratbay.plots.default_colors
          dict to assign colors (species not in the dict take the
          remaining colors from the color/line-style cycle).
        - If len(colors) >= len(species), colors are assigned to each
          species irrespective of highlight.
        - If len(colors) < len(species), the display will cycle the
          color list using solid, long-dashed, short-dashed, and dotted
          line styles (all highlight species being displayed before the rest).
        - If colors is None, use matplotlib's default color cycler.
    dashes: 1D dash-sequence iterable
        List of line-styles for each species, irrespective of highlight.
        len(dashes) has to be equal to len(species).
        Alternatively, dashes can be a dash-sequence Cycler.
    filename: String
        If not None, save plot to given file name.
    lw: Float
        Lines width.
    fignum: Integer
        Figure's number (ignored if axis is not None).
    fs: Float
        Labels font sizes.
    legend_fs: Float
        Legend font size.  If legend_fs is None, default to fs-2.
        If legend_fs <= 0, do not plot a legend.
    ax: AxesSubplot instance
        If not None, plot into the given axis.

    Returns
    -------
    ax: AxesSubplot instance
        The matplotlib Axes of the figure.

    Examples
    --------
    >>> import pyratbay.atmosphere as pa
    >>> import pyratbay.plots as pp

    >>> nlayers = 51
    >>> pressure = pa.pressure('1e-6 bar', '1e2 bar', nlayers)
    >>> temperature = pa.temperature('isothermal', pressure, params=1000.0)
    >>> species = 'H2O CH4 CO CO2 NH3 C2H2 C2H4 HCN N2 TiO VO H2 H He Na K'.split()
    >>> vmr = pa.chemistry('tea', pressure, temperature, species).vmr
    >>> ax = pp.abundance(
    ...     vmr, pressure, species, colors='default',
    ...     highlight='H2O CH4 CO CO2 NH3 HCN H2 H He'.split())
    """
    if legend_fs is None:
        legend_fs = fs - 2

    species = list(species)
    nspecies = len(species)

    if highlight is None:
        highlight = species
    # Keep the input species order within each group:
    highlight = [spec for spec in species if spec in highlight]
    lowlight = [spec for spec in species if spec not in highlight]
    sorted_spec = highlight + lowlight

    used_cols = []
    if colors is None:
        colors = matplotlib.rcParams['axes.prop_cycle'].by_key()['color']

    # The 'default' keyword must be tested before comparing lengths,
    # otherwise the string itself is taken as the list of colors
    # whenever there are 7 or fewer species:
    if isinstance(colors, str) and colors == 'default':
        cols = [default_colors.get(mol) for mol in species]
        used_cols = [c for c in default_colors.values() if c in cols]
        remaining_cols = [c for c in default_colors.values() if c not in cols]
        colors = used_cols + remaining_cols
    elif len(colors) >= nspecies:
        # Copy, so that the caller's sequence is never modified:
        cols = list(colors)
    else:
        cols = [None for _ in species]

    if isinstance(dashes, Cycler):
        dash_cycler = dashes
        dashes = None
    else:
        dash_cycler = cycler(dashes=[(), (8,1.5), (4,1), (1,1)])

    # Go over all colors with a given line style before moving on to
    # the next line style:
    dkws = cycle(dash_cycler * cycler(color=colors))
    # Skip the colors already taken from default_colors:
    for _ in used_cols:
        next(dkws)

    # Assign color and line style to species still lacking one
    # (highlighted species first):
    _dashes = [() for _ in species]
    for spec in sorted_spec:
        ispec = species.index(spec)
        if cols[ispec] is None:
            dkw = next(dkws)
            cols[ispec] = dkw['color']
            _dashes[ispec] = dkw['dashes']

    if dashes is None or len(dashes) != nspecies:
        dashes = _dashes

    press = pressure / pt.u(punits)
    # Plot the results:
    if ax is None:
        plt.figure(fignum, (7,5))
        plt.clf()
        ax = plt.subplot(111)

    for spec in highlight:
        imol = species.index(spec)
        ax.loglog(
            vol_mix_ratios[:,imol], press, label=spec, lw=lw,
            color=cols[imol], dashes=dashes[imol],
        )
    # Set default xlim before plotting the non-highlighted species:
    if xlim is None:
        xlim = ax.get_xlim()

    for spec in lowlight:
        imol = species.index(spec)
        ax.loglog(
            vol_mix_ratios[:,imol], press, label=spec, lw=lw, zorder=-1,
            color=alphatize(cols[imol],alpha=0.4), dashes=dashes[imol],
        )

    ax.set_xlim(xlim)
    ax.set_ylim(np.amax(press), np.amin(press))
    ax.set_xlabel('Volume mixing ratio', fontsize=fs)
    ax.set_ylabel(f'Pressure ({punits})', fontsize=fs)
    ax.tick_params(
        which='both', right=True, top=True, direction='in', labelsize=fs-2,
    )
    if legend_fs > 0:
        ax.legend(loc='best', fontsize=legend_fs)

    if filename is not None:
        # Save the figure holding ax (not pyplot's current figure):
        ax.figure.savefig(filename)
    return ax