class Comm:
    """Wrapper around MPI COMM_WORLD. Does nothing if MPI is not initialized.

    When ``mpi4py`` cannot be imported the wrapper stays disabled and every
    operation behaves as if there were exactly one process with rank 0.
    When it is available, MPI itself is only initialized on the first call
    that actually needs the communicator.
    """

    def __init__(self) -> None:
        """Initialize MPI Communicator if available."""
        self._enabled = False
        self._initialized = False
        self._rank = 0
        self._size = 1
        try:
            import mpi4py

            # Must be set before ``mpi4py.MPI`` is imported: MPI start-up is
            # deferred to ``_mpi_init`` instead of happening at import time,
            # while finalization is left to mpi4py.
            mpi4py.rc.initialize = False
            mpi4py.rc.finalize = True

            from mpi4py import MPI

            self._MPI = MPI
            self._enabled = True
        except ImportError as e:
            log.info("Missing mpi4py, multi-node monitoring disabled")
            log.info(e)

    def _mpi_init(self) -> None:
        """Initialize MPI (if nobody else did) and cache the world communicator."""
        log.debug("Initializing MPI")
        if not self._MPI.Is_initialized():
            self._MPI.Init()

        self._comm = self._MPI.COMM_WORLD
        self._rank = self._comm.Get_rank()
        self._size = self._comm.Get_size()
        self._initialized = True
        log.info(f"MPI initialized: rank={self._rank}, size={self._size}")
        if self._size == 1:
            log.info("Single node monitoring mode")

    def _ensure_ready(self) -> bool:
        """Return True when MPI can be used, initializing it on first use."""
        if not self._enabled:
            return False
        if not self._initialized:
            self._mpi_init()
        return True

    def Get_rank(self) -> int:
        """Get local MPI rank.

        Returns
        -------
        int
            MPI Rank, or 0 when MPI is disabled.
        """
        if self._ensure_ready():
            return self._comm.Get_rank()
        return self._rank

    def Get_size(self) -> int:
        """MPI World size.

        Returns
        -------
        int
            World Size, or 1 when MPI is disabled.
        """
        if self._ensure_ready():
            return self._comm.Get_size()
        return self._size

    def gather(self, obj: Any, root: int = 0) -> Optional[List[Any]]:
        """MPI gather operation.

        Parameters
        ----------
        obj : Any
            Object to be gathered.
        root : int, optional
            Receiver rank, by default 0

        Returns
        -------
        Optional[List[Any]]
            List with the gathered objects. With MPI disabled this is
            ``[obj]``; otherwise it is whatever the communicator's ``gather``
            returns for the calling rank.
        """
        if self._ensure_ready():
            return self._comm.gather(obj, root=root)
        return [obj]

    def allgather(self, obj: Any) -> List[Any]:
        """MPI allgather operation.

        Parameters
        ----------
        obj : Any
            Object to be gathered.

        Returns
        -------
        List[Any]
            List with the gathered objects (``[obj]`` when MPI is disabled).
        """
        if self._ensure_ready():
            return self._comm.allgather(obj)
        return [obj]

    def bcast(self, obj: Any, root: int = 0) -> Any:
        """MPI broadcast operation.

        Parameters
        ----------
        obj : Any
            Object to be broadcasted.
        root : int, optional
            Sender rank, by default 0

        Returns
        -------
        Any
            Broadcasted object (``obj`` itself when MPI is disabled).
        """
        if self._ensure_ready():
            return self._comm.bcast(obj, root)
        return obj

    def gather_from_ranks(
        self, obj: Any, ranks: List[int], root: int = 0
    ) -> Optional[List[Any]]:
        """Collect python objects from specific ranks at the determined root.

        Parameters
        ----------
        obj : Any
            Object to be collected.
        ranks : List[int]
            List of ranks that need to send the object.
        root : int, optional
            Receiver rank, by default 0

        Returns
        -------
        Optional[List[Any]]
            On ``root``: the gathered objects, in the order given by
            ``ranks``. On every other rank: ``None``. With MPI disabled:
            ``[obj]``.
        """
        if not self._ensure_ready():
            return [obj]

        my_rank = self._comm.Get_rank()
        if my_rank != root:
            # Only listed ranks send: root posts a receive for exactly the
            # entries of ``ranks``, so a message from any other rank would
            # never be matched.
            if my_rank in ranks:
                self._comm.send(obj, dest=root)
            return None

        # Root contributes its own object directly instead of messaging itself.
        return [
            obj if sender == my_rank else self._comm.recv(source=sender)
            for sender in ranks
        ]

    def check_available_ranks(self, timeout: float = 5.0) -> List[int]:
        """Return an array with all the ranks that are capable of responding to a single send/recv.

        Each rank sends its own rank number to every other rank without
        blocking, then polls the matching receives until all peers have
        answered or ``timeout`` has elapsed.

        Parameters
        ----------
        timeout : float, optional
            Maximum number of seconds to wait for the other ranks,
            by default 5.0. With a non-positive value no polling is done
            and only the local rank is reported.

        Returns
        -------
        List[int]
            Sorted list with responsive MPI ranks, always including the
            local rank (``[0]`` when MPI is disabled).
        """
        if not self._ensure_ready():
            return [0]

        rank = self._comm.Get_rank()
        size = self._comm.Get_size()
        peers = [peer for peer in range(size) if peer != rank]

        # Monotonic clock: the deadline is unaffected by wall-clock changes.
        deadline = time.monotonic() + timeout

        # The send requests are kept referenced until the probe is over so
        # the outgoing messages are not released while still in flight.
        # They are deliberately not waited on: an unresponsive peer must
        # not be able to block this rank.
        send_requests = [self._comm.isend(rank, dest=peer, tag=0) for peer in peers]

        # One non-blocking receive per peer.
        recv_requests = [(peer, self._comm.irecv(source=peer, tag=0)) for peer in peers]

        responsive = set()
        while time.monotonic() < deadline:
            for peer, request in recv_requests:
                if peer not in responsive and request.Test():
                    responsive.add(peer)

            if len(responsive) == len(peers):  # All ranks are available
                break

            # Sleep for a short duration before checking again
            time.sleep(0.1)

        # Cancel all remaining requests to prevent potential deadlocks
        for peer, request in recv_requests:
            if peer not in responsive:
                request.Cancel()

        del send_requests
        responsive.add(rank)
        return sorted(responsive)