pyfem.parallel.mpi_setup 源代码

# -*- coding: utf-8 -*-
"""

"""
import sys
from typing import Dict, Any, Optional, Union

try:
    from mpi4py import MPI

    MPIContext = Union[MPI.Comm, int, bool]
except ImportError:
    MPIContext = None

# MPI context singleton
_MPI_CONTEXT: Optional[Dict[str, MPIContext]] = None


[文档] def setup_mpi() -> Dict[str, MPIContext]: """Initialize MPI environment and return context dictionary.""" try: from mpi4py import MPI from petsc4py import PETSc comm: MPI.Comm = MPI.COMM_WORLD rank: int = comm.Get_rank() size: int = comm.Get_size() return { 'comm': comm, 'rank': rank, 'size': size, 'is_parallel': size > 1, 'is_master': rank == 0, 'is_worker': rank > 0 } except ImportError as e: raise RuntimeError( "Parallel version requires mpi4py and petsc4py.\n" "Please install them or use the serial version." ) from e
[文档] def get_mpi_context() -> Dict[str, MPIContext]: """Get MPI context (singleton pattern).""" global _MPI_CONTEXT if _MPI_CONTEXT is None: try: _MPI_CONTEXT = setup_mpi() except RuntimeError: # 在串行模式下提供模拟的MPI上下文 class DummyComm: """模拟的MPI通信器""" _is_dummy_comm = True @staticmethod def Get_rank(): return 0 @staticmethod def Get_size(): return 1 @staticmethod def barrier(): pass @staticmethod def bcast(obj, root=0): return obj @staticmethod def Abort(code=1): sys.exit(code) _MPI_CONTEXT = { 'comm': DummyComm(), 'rank': 0, 'size': 1, 'is_parallel': False, 'is_master': True, 'is_worker': False } return _MPI_CONTEXT
[文档] def is_dummy_comm(comm): """判断通信器是否为DummyComm""" return getattr(comm, '_is_dummy_comm', False)
if __name__ == "__main__": mpi_info = get_mpi_context() print(f"MPI Rank: {mpi_info['rank']}, Size: {mpi_info['size']}, Is Parallel: {mpi_info['is_parallel']}")