# Copyright (c) 2021-2026 Cubillos & Blecic
# Pyrat Bay is open-source software under the GPL-2.0 license (see LICENSE)
# Public names exported by a star-import of this module.
# NOTE(review): 'Formatted_Write' is not defined in the visible part of this
# file; confirm it exists elsewhere, otherwise `import *` raises AttributeError.
__all__ = [
    # Context managers:
    'log_error', 'cd', 'tmp_reset',
    # General utilities:
    'eta', 'resolve_theme', 'binsearch', 'divisors', 'unpack',
    'u', 'get_param', 'is_number', 'ifirst', 'ilast',
    # File-system helpers:
    'mkdir', 'isfile', 'file_exists', 'path',
    'Formatted_Write',
    'Timer',
    # ExoMol / CIA / opacity helpers:
    'get_exomol_mol', 'cia_hitran', 'cia_borysow', 'interpolate_opacity',
    # Arithmetic helpers:
    'none_div', 'radius_to_depth', 'depth_to_radius',
]
import os
import re
import struct
import time
import numbers
import string
import textwrap
import itertools
import functools
from collections.abc import Iterable
from contextlib import contextmanager
from matplotlib.colors import is_color_like
import mc3.utils as mu
import mc3.plots as mp
import numpy as np
import scipy.interpolate as sip
from .mpi_tools import (
get_mpi_rank,
mpi_barrier,
)
from .. import constants as pc
from .. import io as io
from ..lib import _indices
[docs]
@contextmanager
def log_error(log=None, error=None):
    """
    Context manager that re-reports any exception through log.error().

    Parameters
    ----------
    log: mc3.utils.Log instance
        Logger whose error() method is called as error(message, ValueError).
        If None, a default mu.Log(logname=None, verb=1, width=80) is used.
    error: String
        Message to report. If None, use the text of the caught exception.

    Notes
    -----
    If log.error() returns instead of raising, the caught exception
    is suppressed (standard contextlib.contextmanager semantics).
    """
    try:
        yield
    except Exception as exc:
        logger = log
        if logger is None:
            logger = mu.Log(logname=None, verb=1, width=80)
        message = str(exc) if error is None else error
        logger.error(message, ValueError)
[docs]
@contextmanager
def cd(newdir):
    """
    Temporarily change the current working directory.

    The previous working directory is restored on exit, also when the
    body of the with statement raises.
    Adapted from https://stackoverflow.com/questions/431684/

    Parameters
    ----------
    newdir: String
        Target directory (a leading '~' is expanded to the user's home).
    """
    previous_dir = os.getcwd()
    target_dir = os.path.expanduser(newdir)
    os.chdir(target_dir)
    try:
        yield
    finally:
        os.chdir(previous_dir)
def recursive_setattr(obj, attr, val):
    """
    Set a possibly dotted attribute (e.g., 'a.b.c') on obj.
    See https://stackoverflow.com/questions/31174295
    """
    parent_path, _, leaf = attr.rpartition('.')
    # Walk down to the object that owns the final attribute:
    target = recursive_getattr(obj, parent_path) if parent_path else obj
    return setattr(target, leaf, val)
def recursive_getattr(obj, attr):
    """
    Get a possibly dotted attribute (e.g., 'a.b.c') from obj.
    See https://stackoverflow.com/questions/31174295
    """
    value = obj
    for name in attr.split('.'):
        value = getattr(value, name)
    return value
[docs]
@contextmanager
def tmp_reset(obj, *attrs, **tmp_attrs):
    """
    Temporarily override attributes of an object.

    Positional attribute names (which may be dotted, e.g., 'o.x') are
    set to None; keyword arguments are set to the given values (keyword
    names cannot contain dots, so they cannot be recursive).
    On exit, every modified attribute is restored to its original value,
    also when the body of the with statement raises.

    Parameters
    ----------
    obj: Object
        Object whose attributes are modified.
    attrs: Strings
        Names of attributes to temporarily set to None.
    tmp_attrs: Keyword arguments
        Attribute names and the temporary values to set.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> o = type('obj', (object,), {'x':1.0, 'y':2.0})
    >>> obj = type('obj', (object,), {'z':3.0, 'w':4.0, 'o':o})
    >>> # All listed arguments are set to None:
    >>> with pt.tmp_reset(obj, 'o.x', 'z'):
    >>>     print(obj.o.x, obj.o.y, obj.z, obj.w)
    None 2.0 None 4.0
    >>> # Keyword arguments can be set to a value, but cannot be recursive:
    >>> with pt.tmp_reset(obj, 'o.x', z=10):
    >>>     print(obj.o.x, obj.o.y, obj.z, obj.w)
    None 2.0 10 4.0
    """
    orig_attrs = {}

    def override(attr, tmp_val):
        # Record only the first original value, so an attribute listed
        # more than once is still restored to its pre-context value:
        if attr not in orig_attrs:
            orig_attrs[attr] = recursive_getattr(obj, attr)
        recursive_setattr(obj, attr, tmp_val)

    try:
        for attr in attrs:
            override(attr, None)
        for attr, tmp_val in tmp_attrs.items():
            override(attr, tmp_val)
        yield
    finally:
        # Restore in reverse order so overlapping paths (e.g., 'o.x'
        # then 'o') unwind correctly:
        for attr, orig_val in reversed(list(orig_attrs.items())):
            recursive_setattr(obj, attr, orig_val)
[docs]
def eta(time_seconds, n_completed, n_total, fmt='.2f'):
    """
    Estimate the remaining time, formatted in the most readable units
    (sec, min, hours, or days).

    Parameters
    ----------
    time_seconds: Float
        Time in seconds spent to complete n_completed steps.
    n_completed: Integer
        Number of completed steps (must be nonzero).
    n_total: Integer
        Total number of steps to complete.
    fmt: String
        Format specification applied to the numeric value.

    Returns
    -------
    eta: String
        Remaining time followed by its units, e.g., '1.50 min'.
    """
    remaining = time_seconds * (n_total-n_completed) / n_completed
    # Each unit is used while the value is below its rollover threshold:
    for units, rollover in (('sec', 60.0), ('min', 60.0), ('hours', 24.0)):
        if remaining < rollover:
            return f'{remaining:{fmt}} {units}'
        remaining /= rollover
    return f'{remaining:{fmt}} days'
[docs]
def resolve_theme(theme):
    """
    Resolve an input Theme, theme name, or color into a mc3.plots.Theme.

    Accepted inputs are None, a mc3.plots.Theme instance, the name of a
    theme in mc3.plots.THEMES, or any value that matplotlib can
    interpret as a color.

    Parameters
    ----------
    theme: Any
        A mc3.plots.Theme, a key of mc3.plots.THEMES, or a matplotlib color.

    Returns
    -------
    theme: mc3.plots.Theme instance or None
        The input Theme unchanged, the named Theme, a Theme built from
        the input color, or None if the input is None.

    Raises
    ------
    ValueError
        If theme is none of the accepted inputs.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> import mc3
    >>> # A Theme instance is returned unmodified
    >>> theme = pt.resolve_theme(mc3.plots.THEMES['indigo'])
    >>> # Anything that can be interpreted as matplotlib color:
    >>> theme1 = pt.resolve_theme('red')
    >>> theme2 = pt.resolve_theme('xkcd:green')
    >>> theme3 = pt.resolve_theme((0,0,1))
    >>> # If input is None, return None
    >>> theme = pt.resolve_theme(None)
    >>> # Anything else will throw an error:
    >>> theme = pt.resolve_theme('not_a_plt_color')
    ValueError: Invalid color theme: 'not_a_plt_color'
    """
    if theme is None or isinstance(theme, mp.Theme):
        return theme
    if isinstance(theme, str) and theme in mp.THEMES:
        return mp.THEMES[theme]
    if is_color_like(theme):
        return mp.Theme(theme)
    raise ValueError(f"Invalid color theme: '{theme}'")
[docs]
def binsearch(tli, wnumber, rec0, nrec, upper=True):
    r"""
    Do a binary+linear search in a TLI file for a wavenumber record.

    With upper=True, find the last record whose wavenumber is less than
    or equal to wnumber; with upper=False, find the first record whose
    wavenumber is greater than or equal to wnumber.  The final linear
    step accounts for runs of duplicated wavenumbers in the file.

    Parameters
    ----------
    tli: File object
        TLI file where to search (opened in binary mode).
    wnumber: Scalar
        Target wavenumber in cm-1.
    rec0: Integer
        File position (in bytes) of the first wavenumber record.
        Records are pc.dreclen bytes apart, each decoded as a double.
    nrec: Integer
        Number of wavenumber records.
    upper: Boolean
        If True, consider wnumber as an upper boundary. If False,
        consider wnumber as a lower boundary.

    Returns
    -------
    irec: Integer
        Index of the record nearest to target. Return -1 if out of bounds.

    Raises
    ------
    ValueError
        If nrec <= 0.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> import struct
    >>> # Mock a TLI file:
    >>> wn = [0.0, 1.0, 1.0, 1.0, 2.0, 2.0]
    >>> with open('tli_demo.dat', 'wb') as tli:
    >>>     tli.write(struct.pack(str(len(wn))+"d", *wn))
    >>> # Now do bin searches for upper and lower boundaries:
    >>> with open('tli_demo.dat', 'rb') as tli:
    >>>     bs_lower = [pt.binsearch(tli, target, 0, len(wn), upper=False)
    >>>                 for target in [-1.0, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5]]
    >>>     bs_upper = [pt.binsearch(tli, target, 0, len(wn), upper=True)
    >>>                 for target in [-1.0, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5]]
    >>> print(bs_lower, bs_upper, sep='\n')
    [0, 0, 1, 1, 4, 4, -1]
    [-1, 0, 0, 3, 3, 5, 5]
    """
    if nrec <= 0:
        raise ValueError('Requested binsearch over a zero-sized array.')

    def read_record(index):
        # Seek to the index-th record and decode it as a double:
        tli.seek(rec0 + index*pc.dreclen, 0)
        return struct.unpack('d', tli.read(8))[0]

    # Initialize indices and current record:
    irec = ilo = 0
    ihi = nrec - 1
    current = first = read_record(0)
    last = read_record(ihi)

    # Out of bounds:
    if upper and wnumber < first:
        return -1
    if not upper and last < wnumber:
        return -1

    # Binary search:
    while ihi - ilo > 1:
        irec = (ihi + ilo) // 2
        current = read_record(irec)
        if current > wnumber:
            ihi = irec
        else:
            ilo = irec

    # Linear search (walk over duplicated values):
    if upper and current > wnumber:
        return irec - 1
    if not upper and current < wnumber:
        return irec + 1
    if upper:
        while current <= wnumber:
            irec += 1
            if irec > nrec-1:
                return nrec-1
            current = read_record(irec)
        return irec - 1
    while current >= wnumber:
        irec -= 1
        if irec < 0:
            return 0
        current = read_record(irec)
    return irec + 1
[docs]
def divisors(number):
    """
    Find all the positive integer divisors of number.

    Parameters
    ----------
    number: Integer
        The number to factorize.

    Returns
    -------
    divs: 1D integer ndarray
        Divisors of number in increasing order, ending with number itself.
        For number < 1, only number itself is returned.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> print(pt.divisors(12))
    [ 1  2  3  4  6 12]
    >>> print(pt.divisors(1))
    [1]
    """
    # Proper divisors lie in 1..floor(number/2); an integer upper bound
    # guarantees number itself is never a candidate (e.g., number=1):
    candidates = np.arange(1, int(number // 2) + 1)
    divs = candidates[number % candidates == 0]
    return np.append(divs, number).astype(int)
[docs]
def unpack(file, n, dtype):
    r"""
    Read n values of a given struct type from a binary file.

    Parameters
    ----------
    file: File object
        Binary file object to read from (at its current position).
    n: Integer
        Number of elements to read.
    dtype: String
        struct format character of the elements (e.g., 'd', 'h', 's').

    Returns
    -------
    output: Scalar, tuple, or string
        For dtype 's', the bytes read decoded as a UTF-8 string.
        Otherwise, the scalar value when n == 1, or a tuple of values.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> import struct
    >>> import numpy as np
    >>> # Store a string and numbers in a binary file:
    >>> with open('delete_me.dat', 'wb') as bfile:
    >>>     bfile.write(struct.pack('3s', 'H2O'.encode('utf-8')))
    >>>     bfile.write(struct.pack('h', 3))
    >>>     bfile.write(struct.pack('3f', np.pi, np.e, np.inf))
    >>> # Unpack them:
    >>> with open('delete_me.dat', 'rb') as bfile:
    >>>     string = pt.unpack(bfile, 3, 's')
    >>>     number = pt.unpack(bfile, 1, 'h')
    >>>     values = pt.unpack(bfile, 3, 'f')
    >>> # See outputs:
    >>> print(string, number, values, sep='\n')
    H2O
    3
    (3.1415927410125732, 2.7182817459106445, inf)
    """
    fmt = f'{n}{dtype}'
    nbytes = struct.calcsize(fmt)
    values = struct.unpack(fmt, file.read(nbytes))
    if dtype == 's':
        return values[0].decode('utf-8')
    if n == 1:
        return values[0]
    return values
[docs]
def u(units):
    """
    Get the conversion factor to the CGS system for the given units.

    Parameters
    ----------
    units: String
        Name of the units; must be an attribute of pyratbay.constants.

    Returns
    -------
    value: Float
        Value of the input units in CGS units.

    Raises
    ------
    ValueError
        If units is not defined in pyratbay.constants.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> for units in ['cm', 'm', 'rearth', 'rjup', 'au']:
    >>>     print(f'{units} = {pt.u(units)} cm')
    cm = 1.0 cm
    m = 100.0 cm
    rearth = 637810000.0 cm
    rjup = 7149200000.0 cm
    au = 14959787069100.0 cm
    """
    if hasattr(pc, units):
        return getattr(pc, units)
    raise ValueError(
        f"Units '{units}' does not exist in pyratbay.constants.")
[docs]
def get_param(param, units='none', gt=None, ge=None):
    """
    Read a parameter that may or may not have units.
    If it doesn't, default to the 'units' input argument.

    Parameters
    ----------
    param: String, Float, integer, or ndarray
        The parameter value. Strings may carry units after a
        white space, e.g., '1.0 m'.
    units: String
        The default units for the parameter (used when param does
        not carry its own units).
    gt: Float
        If not None, check that the output is greater than gt.
    ge: Float
        If not None, check that the output is greater-equal than ge.

    Returns
    -------
    value: Float, integer, or ndarray
        The parameter value in CGS units (None if param is None).
        Input arrays are never modified in place.

    Raises
    ------
    ValueError
        If the value or units are invalid, or the gt/ge checks fail
        (for arrays, if any element fails).

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> # One meter in cm:
    >>> pt.get_param('1.0 m')
    100.0
    >>> # Alternatively, specify units in second argument:
    >>> pt.get_param(1.0, 'm')
    100.0
    >>> # Units in 'param' take precedence over 'unit':
    >>> pt.get_param('1.0 m', 'km')
    100.0
    >>> # Request returned value to be positive:
    >>> pt.get_param('-30.0 kelvin', gt=0.0)
    ValueError: Value -30.0 must be > 0.0
    """
    if param is None:
        return None

    # Split the parameter into a value and (optional) units:
    if isinstance(param, str):
        par = param.split()
        if len(par) not in (1, 2):
            raise ValueError(f"Invalid value '{param}'")
        if len(par) == 2:
            units = par[1]
            if not hasattr(pc, units):
                raise ValueError(f"Invalid units for value '{param}'")
        try:
            value = float(par[0])
        except ValueError:
            raise ValueError(f"Invalid value '{param}'") from None
    else:
        value = param

    # Without units in param, the default units must be valid:
    if isinstance(param, (numbers.Number, np.ndarray)) \
            or (isinstance(param, str) and len(par) == 1):
        if units is None or not hasattr(pc, units):
            raise ValueError(f"Invalid units '{units}'")

    # Apply the units out of place, so a caller's array is not modified
    # (an in-place *= would also fail for integer arrays):
    value = value * u(units)

    # np.any() makes the checks work for both scalars and arrays:
    if gt is not None and np.any(value <= gt):
        raise ValueError(f'Value {value} must be > {gt}')
    if ge is not None and np.any(value < ge):
        raise ValueError(f'Value {value} must be >= {ge}')
    return value
[docs]
def is_number(value):
    r"""
    Check whether a value can be parsed as a number (via float()).

    Parameters
    ----------
    value: Any
        Value to check, typically a string.

    Returns
    -------
    is_number: Bool
        True if float(value) succeeds, False otherwise (including values
        of types that float() rejects, such as None or lists).

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> # These return True
    >>> pt.is_number('1')
    True
    >>> pt.is_number('1.0')
    True
    >>> pt.is_number('-3.14')
    True
    >>> pt.is_number('+3.14')
    True
    >>> pt.is_number('1.0e+02')
    True
    >>> pt.is_number('inf')
    True
    >>> pt.is_number('nan')
    True
    >>> # These return False
    >>> pt.is_number('1.0-3.14')
    False
    >>> pt.is_number('10abcde')
    False
    >>> pt.is_number('1.0e')
    False
    >>> pt.is_number('true')
    False
    >>> pt.is_number('none')
    False
    >>> pt.is_number(None)
    False
    """
    try:
        float(value)
    except (ValueError, TypeError):
        return False
    return True
[docs]
def ifirst(data, default_ret=-1):
    """
    Get the index of the first element of data that is True (or 1).

    Parameters
    ----------
    data: 1D bool/integer iterable
        Values to scan; converted to an integer array first.
    default_ret: Integer
        Value returned when no element of data is True or 1.

    Returns
    -------
    first: integer
        First index where data == True or 1, else default_ret.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> import numpy as np
    >>> print(pt.ifirst([1,0,0]))
    0
    >>> print(pt.ifirst(np.arange(5)>2.5))
    3
    >>> print(pt.ifirst([False, True, True]))
    1
    >>> print(pt.ifirst([False, False, False]))
    -1
    >>> print(pt.ifirst([False, False, False], default_ret=0))
    0
    """
    flags = np.asarray(data, int)
    return _indices.ifirst(flags, default_ret)
[docs]
def ilast(data, default_ret=-1):
    """
    Get the index of the last element of data that is True (or 1).

    Parameters
    ----------
    data: 1D bool/integer iterable
        Values to scan; converted to an integer array first.
    default_ret: Integer
        Value returned when no element of data is True or 1.

    Returns
    -------
    last: integer
        Last index where data == True or 1, else default_ret.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> import numpy as np
    >>> print(pt.ilast([1,0,0]))
    0
    >>> print(pt.ilast(np.arange(5)<2.5))
    2
    >>> print(pt.ilast([False, True, True]))
    2
    >>> print(pt.ilast([False, False, False]))
    -1
    >>> print(pt.ilast([False, False, False], default_ret=0))
    0
    """
    flags = np.asarray(data, int)
    return _indices.ilast(flags, default_ret)
[docs]
def mkdir(file_path):
    """
    Create the parent directory of file_path if it does not exist.

    Only one directory level is created (os.mkdir, so nested folders
    are not allowed). Only the MPI rank-0 process creates it; all
    processes then synchronize at an MPI barrier.

    Parameters
    ----------
    file_path: String
        Path to a file.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> log_file = 'NS1/ns_emission_tutorial.log'
    >>> pt.mkdir(log_file)
    """
    dirname = os.path.split(file_path)[0]
    # Equivalent of dirname.removeprefix('./') (str method needs python>=3.9):
    if dirname.startswith('./'):
        dirname = dirname[2:]
    if get_mpi_rank() == 0 and dirname != '' and not os.path.exists(dirname):
        os.mkdir(dirname)
    # Wait until rank 0 has created the directory:
    mpi_barrier()
[docs]
def isfile(path):
    """
    Check whether a path (or every path in a list) is a regular file.

    Parameters
    ----------
    path: String or list
        Path(s) to check.

    Returns
    -------
    status: Integer
        -1 if path is None; 1 if every path is a regular file;
        0 otherwise.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> from pathlib import Path
    >>> # Mock couple files:
    >>> file1, file2 = './tmp_file1.deleteme', './tmp_file2.deleteme'
    >>> Path(file1).touch()
    >>> Path(file2).touch()
    >>> # Input is None:
    >>> print(pt.isfile(None))
    -1
    >>> # All input files exist:
    >>> print(pt.isfile(file1))
    1
    >>> print(pt.isfile([file1]))
    1
    >>> print(pt.isfile([file1, file2]))
    1
    >>> # At least one input does not exist:
    >>> print(pt.isfile('nofile'))
    0
    >>> print(pt.isfile(['nofile']))
    0
    >>> print(pt.isfile([file1, 'nofile']))
    0
    """
    if path is None:
        return -1
    candidates = [path] if isinstance(path, str) else path
    for candidate in candidates:
        if not os.path.isfile(candidate):
            return 0
    return 1
[docs]
def file_exists(pname, desc, value):
    """
    Raise a ValueError if a file (or any file in a list) does not exist.
    A None value is accepted without checks.

    Parameters
    ----------
    pname: String
        Parameter name (used in the error message).
    desc: String
        Parameter description (used in the error message).
    value: String or list of strings
        File path(s) to check.

    Raises
    ------
    ValueError
        For the first path that is not a regular file.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> from pathlib import Path
    >>> # None is OK:
    >>> pt.file_exists('none', 'None input', None)
    >>> # Create a file, check it exists:
    >>> Path('./new_tmp_file.dat').touch()
    >>> pt.file_exists('testfile', 'Test', 'new_tmp_file.dat')
    >>> # Non-existing file throws error:
    >>> pt.file_exists('testfile', 'Test', 'no_file.dat')
    ValueError: Test file (testfile) does not exist: 'no_file.dat'
    """
    if value is None:
        return
    for filename in ([value] if isinstance(value, str) else value):
        if not os.path.isfile(filename):
            raise ValueError(
                f"{desc} file ({pname}) does not exist: '{filename}'")
[docs]
def path(filename):
    """
    Ensure a file name has an explicit (non-empty) directory component.

    Parameters
    ----------
    filename: String
        A file name. If None, return None.

    Returns
    -------
    file_path: String
        filename prefixed with './' when it has no directory component;
        otherwise its directory and base name joined by a single '/'.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> print(pt.path('file.txt'))
    ./file.txt
    >>> print(pt.path('./file.txt'))
    ./file.txt
    >>> print(pt.path('/home/user/file.txt'))
    /home/user/file.txt
    >>> print(pt.path('/file.txt'))
    /file.txt
    """
    if filename is None:
        return None
    dirname, basename = os.path.split(filename)
    if dirname == '':
        dirname = '.'
    # os.path.split() keeps the trailing slash of a root directory
    # ('/file.txt' -> ('/', 'file.txt')), so don't add a second one:
    separator = '' if dirname.endswith('/') else '/'
    return f'{dirname}{separator}{basename}'
[docs]
class Timer:
    """
    Wall-clock stopwatch: clock() returns the seconds elapsed since the
    previous clock() call (or since the Timer was created).
    """
    def __init__(self):
        # Reference time (time.time()) of the last clock() call:
        self.t0 = time.time()

    def clock(self):
        previous, self.t0 = self.t0, time.time()
        return self.t0 - previous
[docs]
def get_exomol_mol(file):
    """
    Parse an ExoMol file name to extract the molecule and isotope names.

    Parameters
    ----------
    file: String
        An ExoMol line-list file (must follow ExoMol naming convention).

    Returns
    -------
    molecule: String
        Name of the molecule.
    isotope: String
        Name of the isotope (See Tennyson et al. 2016, jmosp, 327), built
        from the last digit of each atom's mass number.

    Raises
    ------
    ValueError
        If an atom tag in the file name does not look like
        '<mass-number><element>[<count>]' (e.g., '1H2').

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> filenames = [
    >>>     '1H2-16O__POKAZATEL__00400-00500.trans.bz2',
    >>>     '1H-2H-16O__VTT__00250-00500.trans.bz2',
    >>>     '12C-16O2__HITEMP.pf',
    >>>     '12C-16O-18O__Zak.par',
    >>>     '12C-1H4__YT10to10__01100-01200.trans.bz2',
    >>>     '12C-1H3-2H__MockName__01100-01200.trans.bz2'
    >>> ]
    >>> for db in filenames:
    >>>     print(pt.get_exomol_mol(db))
    ('H2O', '116')
    ('H2O', '126')
    ('CO2', '266')
    ('CO2', '268')
    ('CH4', '21111')
    ('CH4', '21112')
    """
    atoms = os.path.split(file)[1].split('_')[0].split('-')
    elements = []
    isotope = ''
    for atom in atoms:
        match = re.match(r"([0-9]+)([a-z]+)([0-9]*)", atom, re.I)
        if match is None:
            raise ValueError(
                f"Invalid ExoMol atom tag '{atom}' in file name '{file}'")
        mass, element, count = match.groups()
        n_atoms = 1 if count == '' else int(count)
        elements += n_atoms * [element]
        # The isotope tag uses the last digit of each mass number:
        isotope += mass[-1:] * n_atoms

    # Merge consecutive repeated elements into element+count:
    composition = [list(group) for _, group in itertools.groupby(elements)]
    molecule = ''.join(
        c[0] + str(len(c))*(len(c)>1)
        for c in composition
    )
    # Edge case:
    if molecule == 'OCO':
        molecule = 'CO2'
    return molecule, isotope
[docs]
def cia_hitran(ciafile, tstep=1, wstep=1):
    """
    Re-write a HITRAN CIA file into Pyrat Bay format.
    See Richard et al. (2012) and https://www.cfa.harvard.edu/HITRAN/

    The input file consists of blocks, each starting with a header line
    (species pair, wavenumber min/max, number of points, temperature)
    followed by (wavenumber, cross-section) rows.  Blocks that share the
    same wavenumber grid are grouped into one set, and one output file
    (written with io.write_cs) is produced per set.  Cross sections are
    multiplied by pc.amagat**2 before writing.

    Parameters
    ----------
    ciafile: String
        A HITRAN CIA file.
    tstep: Integer
        Slicing step size along temperature dimension.
    wstep: Integer
        Slicing step size along wavenumber dimension.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> # Before moving on, download a HITRAN CIA files from the link above.
    >>> ciafile = 'H2-H2_2011.cia'
    >>> pt.cia_hitran(ciafile, tstep=2, wstep=10)
    """
    # Extract CS data:
    with open(ciafile, 'r') as f:
        # The species pair (e.g., 'H2-H2') is the first token of line one:
        info = f.readline().strip().split()
        species = info[0].split('-')
        temps, data, wave = [], [], []
        wnmin, wnmax = -1, -1
        f.seek(0)
        for line in f:
            # Header lines start with the species pair:
            if line.strip().startswith('-'.join(species)):
                info = line.strip().split()
                # if wn ranges differ, trigger new set
                # NOTE(review): wnmin/wnmax are only tracked here; the set
                # grouping below compares the wavenumber arrays instead.
                if float(info[1]) != wnmin or float(info[2]) != wnmax:
                    wnmin = float(info[1])
                    wnmax = float(info[2])
                temp = float(info[4])
                nwave = int (info[3])
                wn = np.zeros(nwave, np.double)
                cs = np.zeros(nwave, np.double)
                i = 0
                continue
            # else, read in opacities
            wn[i], cs[i] = line.split()[0:2]
            i += 1
            # Store a block only once all its nwave rows were read
            # (an incomplete trailing block is dropped):
            if i == nwave:
                temps.append(temp)
                # Thin the arrays in wavenumber if requested:
                data.append(cs[::wstep])
                wave.append(wn[::wstep])

    # Identify sets (consecutive blocks with identical wavenumber arrays):
    temps = np.array(temps)
    ntemps = len(temps)
    i = 0
    while i < ntemps:
        wn = wave[i]
        j = i
        while j < ntemps and len(wave[j])==len(wn) and np.all(wave[j]-wn==0):
            j += 1
        temp = temps[i:j:tstep]
        # Set cm-1 amagat-2 units:
        cs = np.array(data[i:j])[::tstep] * pc.amagat**2

        pair = '-'.join(species)
        # Wavelength range (microns) for the output file name:
        wl_min = 1.0/(wn[-1]*pc.um)
        wl_max = 1.0/(wn[0]*pc.um)
        csfile = (
            f'CIA_HITRAN_{pair}_{wl_min:.1f}-{wl_max:.1f}um_'
            f'{temp[0]:04.0f}-{temp[-1]:04.0f}K.dat')
        header = (
            f'# This file contains the reformated {pair} CIA data from\n'
            f'# HITRAN file: {ciafile}\n\n')
        io.write_cs(csfile, cs, species, temp, wn, header)
        i = j
[docs]
def cia_borysow(ciafile, species1, species2):
    """
    Re-write a Borysow CIA file into Pyrat Bay format.
    See http://www.astro.ku.dk/~aborysow/programs/

    The input is read as a table with three header lines: the second
    header line lists the temperatures (its first token is skipped and
    'K' suffixes are stripped). The first data column is the wavenumber
    and each remaining column holds the cross sections at one
    temperature. The output file (written with io.write_cs) is named
    after the species pair, wavelength range (um), and temperature
    range (K).

    Parameters
    ----------
    ciafile: String
        A Borysow CIA file.
    species1: String
        First CIA species.
    species2: String
        Second CIA species.

    Examples
    --------
    >>> import pyratbay.tools as pt
    >>> # Before moving on, download a Borysow CIA file from the link above.
    >>> ciafile = 'ciah2he_dh_quantmech'
    >>> pt.cia_borysow(ciafile, 'H2', 'He')
    """
    table = np.loadtxt(ciafile, skiprows=3)
    wn = table[:,0]
    cs = table[:,1:].T

    with open(ciafile) as f:
        f.readline()
        temp_labels = f.readline().split()[1:]
    temp = [float(label.replace('K','')) for label in temp_labels]

    species = [species1, species2]
    pair = '-'.join(species)
    # Wavelength range (microns) for the output file name:
    wl_min = 1.0/(wn[-1]*pc.um)
    wl_max = 1.0/(wn[0]*pc.um)
    source_name = os.path.basename(ciafile)
    csfile = (
        f'CIA_Borysow_{pair}_{wl_min:.1f}-{wl_max:.1f}um_'
        f'{temp[0]:04.0f}-{temp[-1]:04.0f}K.dat')
    header = (
        f'# This file contains the reformated {pair} CIA data from:\n'
        f'# http://www.astro.ku.dk/~aborysow/programs/{source_name}\n\n')
    io.write_cs(csfile, cs, species, temp, wn, header)
[docs]
def interpolate_opacity(
cs_file, temperature=None, pressure=None, wn_mask=None, wl_thinning=1,
):
"""
Interpolate the cross-section data from an opacity file over a
desired temperature and pressure array.
Parameters
----------
cs_file: String
Path to a cross-section file.
temperature: 1D float array
The desired temperature array in K.
If this is the same as the tabulated temperatures, do not interpolate.
pressure: 1D float array
The desired pressure profile in bars.
If this is the same as the tabulated pressure, do not interpolate.
wn_mask: 1D bool array
A mask of wavelength points to take.
wl_thinning: Integer
Thinning factor to take every n-th value of the wavenumber array
Returns
-------
interp_cs: 4D float array
The interpolated cross-section array.
"""
_, temp, press, wn = io.read_opacity(cs_file, extract='arrays')
logp_table = np.log(press)
if wn_mask is None:
wn_mask = np.ones(len(wn), bool)
# If the pressure is the same as in the table, no need to interpolate:
resample_pressure = (
pressure is not None and
(
len(press) != len(pressure) or
np.any(np.abs(1.0-press/pressure) > 0.01)
)
)
resample_temperature = (
temperature is not None and
(
len(temp) != len(temperature) or
np.any(np.abs(1.0-temp/temperature) > 0.01)
)
)
cross_section = io.read_opacity(cs_file, extract='opacity')[:,:,wn_mask]
cross_section = cross_section[:,:,::wl_thinning]
if not resample_pressure and not resample_temperature:
return cross_section
# Work in log_opacity, avoid infinities by capping at 1e-100:
log_cs = np.log(cross_section)
log_cs[~np.isfinite(log_cs)] = -230.0
if resample_pressure:
logp = np.log(pressure)
cs_extrap = log_cs[:,0], log_cs[:,-1]
cs_interp = sip.interp1d(
logp_table, log_cs,
axis=1,
kind='slinear',
bounds_error=False,
fill_value=cs_extrap,
)
log_cs = cs_interp(logp)
if resample_temperature:
cs_extrap = log_cs[0], log_cs[-1]
cs_interp = sip.interp1d(
temp, log_cs,
axis=0,
kind='slinear',
bounds_error=False,
fill_value=cs_extrap,
)
log_cs = cs_interp(temperature)
return np.exp(log_cs)
[docs]
def none_div(a, b):
    """
    Return a/b, or None when either operand is None.
    """
    return None if (a is None or b is None) else a / b
[docs]
def radius_to_depth(rprs, rprs_err):
    r"""
    Compute transit depth (and uncertainties) from input
    planet-to-star radius-ratio, with error propagation.

    Parameters
    ----------
    rprs: Float or float iterable
        Planet-to-star radius ratio.
    rprs_err: Float or float iterable
        Uncertainties of the radius ratios.
        Scalars and iterables can be mixed; they are broadcast.

    Returns
    -------
    depth: Float or float ndarray
        Transit depth for given radius ratio.
    depth_err: Float or float ndarray
        Uncertainties of the transit depth.

    Examples
    --------
    >>> import numpy as np
    >>> import pyratbay.tools as pt
    >>> rprs = 1.2
    >>> rprs_err = 0.25
    >>> depth, depth_err = pt.radius_to_depth(rprs, rprs_err)
    >>> print(f'Depth = {depth} +/- {depth_err}')
    Depth = 1.44 +/- 0.6

    >>> rprs = [1.2, 1.5]
    >>> rprs_err = [0.25, 0.3]
    >>> depth, depth_err = pt.radius_to_depth(rprs, rprs_err)
    >>> print('Depth    Uncert\n' +
    >>>     '\n'.join([f'{d} +/- {de:.1f}' for d,de in zip(depth, depth_err)]))
    Depth    Uncert
    1.44 +/- 0.6
    2.25 +/- 0.9
    """
    # Convert non-array iterables (e.g., lists) to arrays; each input is
    # checked independently so a scalar can be mixed with an iterable:
    if isinstance(rprs, Iterable) and not isinstance(rprs, np.ndarray):
        rprs = np.array(rprs)
    if isinstance(rprs_err, Iterable) and not isinstance(rprs_err, np.ndarray):
        rprs_err = np.array(rprs_err)
    depth = rprs**2.0
    # d(depth) = 2 * rprs * d(rprs):
    depth_err = 2.0 * rprs * rprs_err
    return depth, depth_err
[docs]
def depth_to_radius(depth, depth_err):
    r"""
    Compute planet-to-star radius ratio (and uncertainties) from
    input transit depth, with error propagation.

    Parameters
    ----------
    depth: Float or float iterable
        Transit depth.
    depth_err: Float or float iterable
        Uncertainties of the transit depth.
        Scalars and iterables can be mixed; they are broadcast.

    Returns
    -------
    rprs: Float or float ndarray
        Planet-to-star radius ratio.
    rprs_err: Float or float ndarray
        Uncertainties of the radius ratio rprs.

    Examples
    --------
    >>> import numpy as np
    >>> import pyratbay.tools as pt
    >>> depth = 1.44
    >>> depth_err = 0.6
    >>> rprs, rprs_err = pt.depth_to_radius(depth, depth_err)
    >>> print(f'Rp/Rs = {rprs} +/- {rprs_err}')
    Rp/Rs = 1.2 +/- 0.25

    >>> depth = [1.44, 2.25]
    >>> depth_err = [0.6, 0.9]
    >>> rprs, rprs_err = pt.depth_to_radius(depth, depth_err)
    >>> print('Rp/Rs   Uncert\n'
    >>>     + '\n'.join([f'{r} +/- {re}' for r,re in zip(rprs, rprs_err)]))
    Rp/Rs   Uncert
    1.2 +/- 0.25
    1.5 +/- 0.3
    """
    # Convert non-array iterables (e.g., lists) to arrays; each input is
    # checked independently so a scalar can be mixed with an iterable:
    if isinstance(depth, Iterable) and not isinstance(depth, np.ndarray):
        depth = np.array(depth)
    if isinstance(depth_err, Iterable) and not isinstance(depth_err, np.ndarray):
        depth_err = np.array(depth_err)
    rprs = np.sqrt(depth)
    # d(rprs) = d(depth) / (2 * sqrt(depth)):
    rprs_err = 0.5 * depth_err / rprs
    return rprs, rprs_err