# Source code for pyratbay.tools.mpi_tools

# Copyright (c) 2021-2026 Cubillos & Blecic
# Pyrat Bay is open-source software under the GPL-2.0 license (see LICENSE)

# Public API of this module: the helpers exported by ``from ... import *``.
__all__ = [
    'check_mpi4py', 'check_mpi_is_needed',
    'get_mpi_rank', 'get_mpi_size',
    'mpi_barrier',
]

import importlib
import importlib.util
import os
import sys
import warnings


def check_mpi4py():
    """
    Detect when the code was called with MPI and the mpi4py module is missing.

    Only raise an error when needed (more than one processor required),
    otherwise you might be running multiple runs in parallel but not
    talking to each other.

    Raises
    ------
    ModuleNotFoundError
        If the launcher environment reports more than one MPI process
        but the mpi4py package is not installed.
    """
    # Environment variables that MPI launchers set to the number of
    # processes, checked in this order of precedence (might not be
    # exhaustive): Open MPI, PMI-based launchers, and MVAPICH2.
    size_variables = (
        'OMPI_COMM_WORLD_SIZE',
        'PMI_SIZE',
        'MV2_COMM_WORLD_SIZE',
    )
    size = 1
    for variable in size_variables:
        if variable in os.environ:
            size = int(os.environ[variable])
            break

    # Complain only if necessary: a single process runs fine without mpi4py.
    if size > 1 and importlib.util.find_spec('mpi4py') is None:
        raise ModuleNotFoundError(
            "Attempted to run pyratbay with MPI, but module 'mpi4py' is not "
            "installed. Run 'pip install mpi4py' and try again"
        )
def check_mpi_is_needed(inputs):
    """
    Prevent running parallel MPI processes when they are not needed.

    MPI is only required for retrieval runs that use the MultiNest
    sampler. For any other run launched with more than one MPI process,
    the rank-zero process emits a warning and carries on, while every
    other process exits with status zero.

    Parameters
    ----------
    inputs: object
        Run configuration; its ``runmode`` and ``sampler`` attributes
        are read.
    """
    size = get_mpi_size()
    rank = get_mpi_rank()
    needs_mpi = (
        inputs.runmode == 'retrieval'
        and inputs.sampler == 'multinest'
    )
    if needs_mpi or size <= 1:
        return

    # Keep only the rank-zero process running to completion
    if rank != 0:
        sys.exit(0)
    warnings.warn(
        'Attempting to use MPI, but this is only needed for MultiNest '
        'runs. Subprocesses will be terminated',
        category=Warning,
    )
def get_mpi_rank():
    """
    Get the MPI rank of the current process (intended for MPI runs).

    Return zero if the PBAY_NO_MPI environment variable is defined
    (whatever its value) or if mpi4py is not installed.

    Returns
    -------
    rank: int
        The MPI process rank.
    """
    # PBAY_NO_MPI opts out of MPI: behave as a single, rank-zero process.
    if 'PBAY_NO_MPI' in os.environ:
        return 0
    if importlib.util.find_spec('mpi4py') is None:
        return 0
    # Imported lazily so that mpi4py stays an optional dependency.
    from mpi4py import MPI
    return MPI.COMM_WORLD.Get_rank()
def get_mpi_size():
    """
    Get the size of the current group of processes (intended for MPI runs).

    Return one if the PBAY_NO_MPI environment variable is defined
    (whatever its value) or if mpi4py is not installed.

    Returns
    -------
    size: int
        The size of the MPI group of processes.
    """
    # PBAY_NO_MPI opts out of MPI: behave as a single process.
    if 'PBAY_NO_MPI' in os.environ:
        return 1
    if importlib.util.find_spec('mpi4py') is None:
        return 1
    # Imported lazily so that mpi4py stays an optional dependency.
    from mpi4py import MPI
    return MPI.COMM_WORLD.Get_size()
def mpi_barrier():
    """
    Block until every process in MPI.COMM_WORLD reaches this barrier() call.

    This is a no-op when the PBAY_NO_MPI environment variable is defined
    or when mpi4py is not installed.
    """
    mpi_disabled = 'PBAY_NO_MPI' in os.environ
    if not mpi_disabled and importlib.util.find_spec('mpi4py') is not None:
        # Imported lazily so that mpi4py stays an optional dependency.
        from mpi4py import MPI
        MPI.COMM_WORLD.barrier()