#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
import multiprocessing
import pickle
import queue
import time
from abc import ABC, abstractmethod
from typing import Any, Iterable
import cloudpickle
__all__ = ["CloudPickleBase", "PooledQueueTask", "PooledQueueProcessor"]
class CloudPickleBase(object):
    """Serialisation wrapper that routes pickling through cloudpickle.

    Plain ``pickle`` cannot serialise lambdas or closures, which
    ``multiprocessing`` would otherwise attempt to use; wrapping the payload
    in this class makes cloudpickle handle the dump instead.

    :param x: (Any) the variable you wish to wrap for pickling with cloudpickle
    """

    def __init__(self, x: Any):
        self._x = x

    def __getstate__(self):
        # Dump with cloudpickle so arbitrary callables survive transport.
        return cloudpickle.dumps(self._x)

    def __setstate__(self, state):
        # cloudpickle output is readable by the stock pickle module.
        self._x = pickle.loads(state)

    def __call__(self, *args, **kwargs):
        # Delegate calls straight to the wrapped payload.
        return self._x(*args, **kwargs)
class PooledQueueTask(ABC):
    """Interface for tasks executed by a pooled queue processor.

    Subclasses implement :meth:`call`; instances are callable and delegate
    directly to it.
    """

    def __call__(self, *args, **kwargs):
        return self.call(*args, **kwargs)

    @abstractmethod
    def call(self, *args, **kwargs) -> Any:
        """Produce one result (e.g. a batch) for the queue.

        :param args: positional arguments forwarded by the processor
        :type args: Any
        :param kwargs: keyword arguments forwarded by the processor
        :type kwargs: Any
        :return: the task result
        """
        # Fixed: `raise NotImplemented` raised a TypeError at runtime, because
        # NotImplemented is a sentinel value, not an exception class.
        raise NotImplementedError
class PooledQueueProcessor(object):
    """
    Workaround for Python's slow inter-process communication pipes.

    The ideal solution would be a single ``multiprocessing.Queue``, but its
    pipe bandwidth is limited. Instead, pool worker processes complete tasks
    (batches) and success callbacks deposit the results into a local
    ``queue.Queue``.
    """

    def __init__(
        self,
        func,
        args: Iterable = (),
        kwargs=None,
        max_queue_size=100,
        n_proc=None,
        max_tasks_per_child=None,
        fill_at_construction=True,
        blocking=True,
    ):
        """
        :param func: task callable, or a class to instantiate as the task
        :param args: positional arguments passed to every task invocation
        :param kwargs: keyword arguments passed to every task invocation
        :param max_queue_size: maximum number of buffered results
        :param n_proc: worker process count; defaults to ``cpu_count()``
        :param max_tasks_per_child: tasks per worker before respawn;
            defaults to ``max_queue_size // 4``
        :param fill_at_construction: pre-fill the queue immediately
        :param blocking: whether :meth:`get` blocks on an empty queue
        """
        if kwargs is None:  # avoid a shared mutable default argument
            kwargs = {}
        self._max_queue_size = max_queue_size
        if isinstance(func, type):  # allow passing the task class itself
            func = func()
        # cloudpickle wrapper lets lambdas/closures cross process boundaries.
        self._func = CloudPickleBase(func)
        self.args = args
        self.kwargs = kwargs
        self.blocking = blocking
        if max_tasks_per_child is None:
            max_tasks_per_child = max_queue_size // 4
        if n_proc is None:
            n_proc = multiprocessing.cpu_count()
        self._queue = queue.Queue(maxsize=max_queue_size)
        self._pool = multiprocessing.Pool(n_proc, maxtasksperchild=max_tasks_per_child)
        if fill_at_construction:
            self.fill()

    def fill(self) -> None:
        """Submit tasks until the result queue reaches its capacity."""
        for _ in range(self._max_queue_size):  # loop index is unused
            self.maybe_fill()

    def close(self) -> None:
        """Stop accepting new tasks and wait for workers to finish."""
        self._pool.close()
        self._pool.join()

    def terminate(self) -> None:
        """Abort outstanding work immediately and reap the workers."""
        self._pool.terminate()
        self._pool.join()

    def maybe_fill(self) -> None:
        """Submit one task if the result queue is not yet full."""
        if self.queue_size < self._max_queue_size:  # and not self._queue.full():
            self._pool.apply_async(
                self._func,
                self.args,
                self.kwargs,
                self.put,  # success callback
                self.raise_error,  # error callback
            )

    @property
    def queue_size(self) -> int:
        """
        :return: approximate number of buffered results
        :rtype: int"""
        return self._queue.qsize()

    def put(self, res) -> None:
        """Success callback: buffer a finished task result.

        :param res: value returned by the task
        :type res: Any"""
        self._queue.put(res)

    def raise_error(self, excptn) -> None:
        """Error callback: tear the pool down and re-raise the task error.

        :param excptn: exception raised inside a worker
        :type excptn: BaseException"""
        self._pool.terminate()
        # Fixed: terminate() was previously followed by close(), which is a
        # no-op on a terminated pool; join() actually reaps the workers.
        self._pool.join()
        raise excptn

    def get(self) -> Any:
        """Fetch the next result, replenishing the pool as results drain.

        :return: the next buffered task result
        :raises StopIteration: non-blocking mode with no pending work"""
        if self.queue_size < 1:  # self._queue.empty():
            if len(multiprocessing.active_children()) == 0:
                if self.blocking:
                    self.maybe_fill()  # a callback will eventually feed the queue
                else:
                    raise StopIteration
        res = self._queue.get(self.blocking)
        self.maybe_fill()  # keep the pipeline topped up
        return res

    def __len__(self) -> int:
        return self.queue_size

    def __iter__(self):
        return self

    def __next__(self) -> Any:
        return self.get()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._pool.terminate()
        # Fixed: terminate() followed by close() is a no-op; join() reaps.
        self._pool.join()
        # Fixed: re-raising exc_type(exc_val) discarded the traceback and
        # broke on exception types with non-trivial constructors; returning
        # a falsy value lets the original exception propagate unchanged.
        return False
if __name__ == "__main__":

    class Square(PooledQueueTask):
        def call(self, i, *args, **kwargs):
            """Return the input doubled (note: doubles, despite the name).

            :param i: number to double
            :type i: int
            :return: ``i * 2``
            :rtype: int"""
            return i * 2

    class Exc(PooledQueueTask):
        def call(self, *args, **kwargs):
            """Always fail, to exercise the error-callback path.

            :param args: ignored
            :type args: Any
            :param kwargs: ignored
            :type kwargs: Any"""
            raise NotImplementedError

    task = Square()
    processor = PooledQueueProcessor(
        task, [2], fill_at_construction=True, max_queue_size=100
    )

    # Drain thirty pre-filled results.
    for value, _ in zip(processor, range(30)):
        print(value)

    # Switch the task argument mid-stream and keep consuming.
    processor.blocking = True
    processor.args = [4]
    time.sleep(3)
    for value in processor:
        print(value)
        if value == 8:
            break