MPI4PY-based Python job queue with caching

This post introduces a technique to spread a number of similar jobs across multiple workers using MPI, with the following characteristics:

  • There is a single scheduler and multiple workers.
  • The whole set of jobs can be interrupted and restarted, thanks to a file-based cache.
  • The task to be spread among workers can be a Python closure or lambda function (which cannot be pickled!).

Caching and memoization are based on the pickle library, wrapped in the following tiny utility module, which we save as `tools.py` (the scheduler below imports it).

import pickle
import os.path

def dump(filename, *data):
	# Pickle one or more objects into a single file.
	with open(filename, "wb") as f:
		pickle.dump(data, f)

def load(filename):
	# Return the pickled content, or None if the file does not exist.
	# A single stored object is unwrapped; several come back as a tuple.
	if os.path.exists(filename):
		with open(filename, "rb") as f:
			data = pickle.load(f)
		return data if len(data) > 1 else data[0]
	else:
		return None
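
These helpers round-trip as follows (the file names here are only illustrative, and the module is assumed to be saved as `tools.py`):

from tools import dump, load

dump('example.pkl', {'alpha': 1}, [2, 3])   # store two objects in one file
print(load('example.pkl'))                  # -> ({'alpha': 1}, [2, 3])
dump('single.pkl', 'just one object')
print(load('single.pkl'))                   # -> 'just one object' (unwrapped)
print(load('missing.pkl'))                  # -> None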

The task scheduler is built on top of Python's mpi4py library and, more precisely, on its MPICommExecutor class. I use it because, when I tried to roll my own scheduler, I ran into the busy-wait problem in MPICH: all processes, even those that had finished their work, sat at 100% CPU while waiting for other jobs to complete or for data to be communicated.
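
For reference, the bare MPICommExecutor pattern looks roughly like this; only the root rank receives a non-None executor, while the other ranks stay inside the with block serving tasks (a minimal sketch, independent of the scheduler below):

from mpi4py import MPI
from mpi4py.futures import MPICommExecutor

with MPICommExecutor(MPI.COMM_WORLD, root=0) as executor:
	if executor is not None:
		# Only the root rank gets here; the other ranks serve tasks until it exits.
		results = list(executor.map(pow, [2, 3, 4], [10, 10, 10]))
		print(results)  # [1024, 59049, 1048576]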

from tools import *
import os
import traceback

# The actual callable (possibly a closure or lambda) is stored in this
# module-level variable, so that the picklable default_task_runner() can
# forward calls to it from the worker processes.
_task_object = None

def default_task_runner(args):
	return _task_object.execute(args)

class Tasker(object):

	def __init__(self, method='MPI'):
		try:
			self.comm = None
			if method == 'MPI':
				#
				# MPI jobs, for most platforms, but only
				# if module is available
				from mpi4py import MPI
				self.comm = MPI.COMM_WORLD
				self.rank = self.comm.Get_rank()
				self.size = self.comm.Get_size()
				self.root = 0
		except Exception as e:
			print('Defaulting to serial jobs because mpi4py is not available.')
			print(e)
			self.comm = None
		if self.comm is None:
			# Backup system: serial jobs
			self.size = 1
			self.rank = self.root = 0
		self._prefix = f'Task {self.rank:3d}'
		self._function = None

	def isroot(self):
		return self.rank == self.root

	def __enter__(self):
		print(f'Entering tasker {self}')
		return self

	def __exit__(self, *args):
		# Returning True suppresses any exception raised inside the with block.
		return True

	def make_joblist(self, filename, args):
		dirname = filename + '.mpi'
		print(f'{self._prefix} Creating joblist for {filename} under {dirname}')
		if not os.path.exists(dirname):
			os.mkdir(dirname)
		joblist = [(i, f'{dirname}/job{i:04d}.pkl', arg) for i, arg in enumerate(args)]
		return joblist, dirname

	def execute(self, arg):
		if self._function is None:
			raise Exception(f'{self._prefix} Function not installed')
		index, tmpfilename, *arguments = arg
		if os.path.exists(tmpfilename):
			# Memoized result from a previous (possibly interrupted) run
			print(f'{self._prefix} loads {tmpfilename}')
			data = load(tmpfilename)
		else:
			print(f'{self._prefix} executes {tmpfilename}')
			data = (index, arguments, self._function(*arguments))
			print(f'{self._prefix} saves {tmpfilename}')
			dump(tmpfilename, *data)
		print(f'{self._prefix} made {data}')
		return data

	def cleanup_files(self, dirname, joblist):
		try:
			for i, tmp, *_ in joblist:
				print(f'{self._prefix} removing file {tmp}')
				os.remove(tmp)
			os.rmdir(dirname)
		except Exception as e:
			print(f'{self._prefix} unable to remove directory {dirname}')
			print(e)

	def map(self, filename, fn, args):
		global _task_object
		self._function = fn
		if os.path.exists(filename):
			print(f'{self._prefix} File {filename} already exists. Returning it')
			return load(filename)
		elif self.size == 1:
			print(f'{self._prefix} Running jobs in serial mode')
			joblist, dirname = self.make_joblist(filename, args)
			output = [self.execute(arg) for arg in joblist]
			dump(filename, output)
			print(f'{self._prefix} Data saved into {filename}')
			self.cleanup_files(dirname, joblist)
			return output
		else:
			# Worker ranks leave the executor block with no results of their
			# own; they return None, while the root returns the full list.
			output = None
			try:
				_task_object = self
				import mpi4py.futures
				with mpi4py.futures.MPICommExecutor(self.comm, self.root) as executor:
					print(f'{self._prefix} Executor is {executor}')
					if executor is not None:
						# Root process; worker ranks get None here and serve tasks
						print(f'{self._prefix} Missing {filename}')
						joblist, dirname = self.make_joblist(filename, args)
						print(f'{self._prefix} got joblist with {len(joblist)} items')
						output = list(executor.map(default_task_runner, joblist))
						print(f'{self._prefix} collected {len(output)} items')
						dump(filename, output)
						print(f'{self._prefix} Data saved into {filename}')
						self.cleanup_files(dirname, joblist)
			except Exception as e:
				print(f'{self._prefix} aborts due to exception {e}')
				traceback.print_exc()
				quit()
			_task_object = None
			return output

The following code shows how to use it. The same program is to be run on all processes. Each process creates a `Tasker` object and calls its `map()` method, which distributes the work among all processes.

from mp import Tasker
with Tasker() as tsk:
	print(f'task={tsk}, rank={tsk.rank}')
	data = tsk.map('foo.pkl', lambda x: str(x), range(100))

The `Tasker` object will distribute the 100 jobs among the `N-1` worker processes, sending a new task to each worker as soon as it finishes the previous one. Note also that the task can be a closure and may reference variables from the parent process, as in the sketch after the command below. To run this code, I recommend a non-buffering Python instance, so that progress messages show up immediately. Assuming the code above was saved into the `job.py` file, we can spread the 100 jobs over 3 workers plus 1 scheduler as follows:

$ mpirun -n 4 python -u ./job.py
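
As an aside, this is what a closure-based task could look like; the variable `scale` is defined in the parent script and captured by the lambda (the names here are only illustrative):

from mp import Tasker

scale = 2.5  # captured by the closure below
with Tasker() as tsk:
	# The lambda references `scale` from the enclosing scope; it never needs
	# to be pickled, because every rank runs this same script and builds its
	# own copy of the function.
	data = tsk.map('scaled.pkl', lambda x: scale * x, range(100))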

The `Tasker` class will create a temporary directory with the `.mpi` extension, which contains the memoized per-job files. If your job crashes for some reason, such as exhausting the time allocated by your cluster's queue manager, you can restart it and only the tasks that had not completed will be run. Once all tasks have finished, the `*.mpi` directory is deleted and the pickle file with the requested name, in this case `foo.pkl`, is saved.
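
After a successful run, `foo.pkl` holds the list of `(index, arguments, result)` tuples produced by the workers, so post-processing can be as simple as this sketch:

from tools import load

results = load('foo.pkl')   # list of (index, arguments, result) tuples
for index, args, result in results:
	print(index, args, result)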