An mpi4py-based Python job queue with caching

This post introduces a technique to spread a number of similar jobs across multiple workers using MPI, with the following characteristics:

  • There is a single scheduler and multiple workers.
  • The whole set of jobs can be interrupted and restarted, thanks to using a file-based cache.
  • The task to be spread among workers can be a Python closure or lambda function (which cannot be pickled!).

The caching and memoization are based on the pickle library, wrapped in the tiny utility module below, which I save as `tools.py` so that the scheduler can import it.

import pickle
import os.path

def dump(filename, *data):
    # Store all positional arguments as a single tuple.
    with open(filename, "wb") as f:
        pickle.dump(data, f)

def load(filename):
    # Return the stored data, or None if the file does not exist.
    # A single stored value is unwrapped from its tuple.
    if os.path.exists(filename):
        with open(filename, "rb") as f:
            data = pickle.load(f)
        return data if len(data) > 1 else data[0]
    else:
        return None
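
As a quick illustration of the convention these helpers follow (the file names here are made up): `dump` stores all of its positional arguments as a single tuple, and `load` unwraps that tuple when only one value was stored.

from tools import dump, load

dump('example.pkl', 1, 2, 3)
print(load('example.pkl'))    # (1, 2, 3)
dump('single.pkl', 'just one value')
print(load('single.pkl'))     # just one value  (unwrapped from its tuple)
print(load('missing.pkl'))    # None, because the file does not exist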

The task scheduler is built on top of Python's mpi4py library and, more precisely, on its MPICommExecutor class. I chose this route because I first tried to roll my own scheduler and ran into the busy-wait problem in MPICH: all processes, even those that had finished their work, sat at 100% CPU while waiting for other jobs to complete or for data to be communicated.
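
Before diving into the scheduler, here is a minimal sketch of the MPICommExecutor pattern it relies on (the `square` function is just an illustration): the root rank obtains an executor and submits the work, while every other rank gets `None` and quietly serves tasks until the root leaves the `with` block.

from mpi4py import MPI
from mpi4py.futures import MPICommExecutor

def square(x):
    return x * x

with MPICommExecutor(MPI.COMM_WORLD, root=0) as executor:
    if executor is not None:
        # Only the root rank enters this branch; the remaining ranks
        # act as workers until the block is exited on the root.
        print(list(executor.map(square, range(8))))

With that pattern in mind, this is the scheduler module, which I save as `mp.py` so that it matches the import in the usage example further below.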

from tools import *
import os
import traceback

def default_task_runner(args):
    # Module-level trampoline: this picklable function is what travels
    # through MPI, while the actual task lives in the per-process Tasker
    # stored in _task_object, so the task itself may be a closure or lambda.
    global _task_object
    return _task_object.execute(args)

class Tasker(object):

    def __init__(self, method='MPI'):
        try:
            self.comm = None
            if method == 'MPI':
                # MPI jobs, for most platforms, but only
                # if the mpi4py module is available
                from mpi4py import MPI
                self.comm = MPI.COMM_WORLD
                self.rank = self.comm.Get_rank()
                self.size = self.comm.Get_size()
                self.root = 0
        except Exception as e:
            print('Defaulting to serial jobs because mpi4py is not available.')
            print(e)
            self.comm = None
        if self.comm is None:
            # Backup system: serial jobs
            self.size = 1
            self.rank = self.root = 0
        self._prefix = f'Task {self.rank:3d}'
        self._function = None

    def isroot(self):
        return self.rank == self.root

    def __enter__(self):
        print(f'Entering tasker {self}')
        return self

    def __exit__(self, *args):
        # Do not suppress exceptions raised inside the with-block.
        return False

    def make_joblist(self, filename, args):
        dirname = filename + '.mpi'
        print(f'{self._prefix} Creating joblist for {filename} under {dirname}')
        if not os.path.exists(dirname):
            os.mkdir(dirname)
        joblist = [(i, f'{dirname}/job{i:04d}.pkl', arg) for i, arg in enumerate(args)]
        return joblist, dirname

    def execute(self, arg):
        if self._function is None:
            raise Exception(f'{self._prefix} Function not installed')
        index, tmpfilename, *arguments = arg
        if os.path.exists(tmpfilename):
            # Memoization: reuse the result cached by a previous
            # (possibly interrupted) run instead of recomputing it.
            print(f'{self._prefix} loads {tmpfilename}')
            data = load(tmpfilename)
        else:
            print(f'{self._prefix} executes {tmpfilename}')
            data = (index, arguments, self._function(*arguments))
            print(f'{self._prefix} saves {tmpfilename}')
            dump(tmpfilename, *data)
        print(f'{self._prefix} made {data}')
        return data

    def cleanup_files(self, dirname, joblist):
        try:
            for i, tmp, *_ in joblist:
                print(f'{self._prefix} removing file {tmp}')
                os.remove(tmp)
            os.rmdir(dirname)
        except Exception as e:
            print(f'{self._prefix} unable to remove directory {dirname}')
            print(e)

    def map(self, filename, fn, args):
        global _task_object
        self._function = fn
        if os.path.exists(filename):
            # The complete result was already computed by a previous run.
            print(f'{self._prefix} File {filename} already exists. Returning it')
            return load(filename)
        elif self.size == 1:
            # Serial fallback: no MPI available or a single process.
            print(f'{self._prefix} Running jobs in serial mode')
            joblist, dirname = self.make_joblist(filename, args)
            output = [self.execute(arg) for arg in joblist]
            dump(filename, output)
            print(f'{self._prefix} Data saved into {filename}')
            self.cleanup_files(dirname, joblist)
            return output
        else:
            # MPI mode: every rank enters here, but only the root rank
            # obtains an executor; the rest serve tasks and return None.
            output = None
            try:
                _task_object = self
                import mpi4py.futures
                with mpi4py.futures.MPICommExecutor(self.comm, self.root) as executor:
                    print(f'{self._prefix} Executor is {executor}')
                    if executor is not None:
                        # Root process
                        print(f'{self._prefix} Missing {filename}')
                        joblist, dirname = self.make_joblist(filename, args)
                        print(f'{self._prefix} got joblist with {len(joblist)} items')
                        output = list(executor.map(default_task_runner, joblist))
                        print(f'{self._prefix} collected {len(output)} items')
                        dump(filename, output)
                        print(f'{self._prefix} Data saved into {filename}')
                        self.cleanup_files(dirname, joblist)
            except Exception as e:
                print(f'{self._prefix} aborts due to exception {e}')
                traceback.print_exc()
                quit()
            _task_object = None
            return output

The following code shows how to use it. The same program is run on every process; it creates a `Tasker` object and uses the `map()` method to distribute the work among them.

from mp import Tasker
with Tasker() as tsk:
    print(f'task={tsk}, rank={tsk.rank}')
    data = tsk.map('foo.pkl', lambda x: str(x), range(100))

The `Tasker` object will distribute the 100 jobs among the `N-1` worker processes, handing a new task to each worker as soon as it finishes the previous one; on the worker ranks `map()` simply returns `None`, while the root rank collects the full list of results. Note also that the task can be a closure or a lambda and may reference variables defined in the script: only the picklable `default_task_runner` and the job arguments travel over MPI, because the actual function is installed locally on every rank when `map()` is called (see the sketch after the command below). To run this code I recommend an unbuffered Python interpreter, so that progress messages appear immediately. Assuming the code above was saved into the `job.py` file, we can spread the 100 jobs over 3 workers plus 1 scheduler as follows:

$ mpirun -n 4 python -u ./job.py
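
For instance, the mapped function can capture variables defined earlier in the script. This is a minimal sketch with made-up names (`scale`, `bar.pkl`); the lambda itself is never pickled, so the capture is perfectly fine.

from mp import Tasker

scale = 2.5   # captured by the lambda below; never sent over MPI

with Tasker() as tsk:
    data = tsk.map('bar.pkl', lambda x: scale * x ** 2, range(100))
    if tsk.isroot():
        # Only the root rank receives the list of (index, args, result) tuples.
        print(data[:3])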

The `Tasker` class will create a temporary directory whose name is the output file name plus the `.mpi` extension. This directory contains the memoized results, one file per job. If your run crashes for some reason, such as exhausting the time allocated by your cluster's queue manager, you can simply restart it and only the tasks that did not complete will be run again. Once all tasks have finished, the `.mpi` directory is deleted and a pickle file with the requested name, in this case `foo.pkl`, is saved.
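
While the jobs are running, the cache directory for this example looks roughly like this, with one file per completed job, following the naming used in `make_joblist`:

foo.pkl.mpi/
    job0000.pkl
    job0001.pkl
    ...
    job0099.pkl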