# Copyright (c) 2012 Denis Bilenko. See LICENSE for details. from __future__ import with_statement, absolute_import import sys import os from gevent.hub import get_hub, getcurrent, sleep, integer_types from gevent.event import AsyncResult from gevent.greenlet import Greenlet from gevent.pool import IMap, IMapUnordered from gevent.lock import Semaphore from gevent._threading import Lock, Queue, start_new_thread # XXX apply_e is ugly and must not be needed # XXX apply() should re-raise everything __all__ = ['ThreadPool', 'ThreadResult'] class ThreadPool(object): def __init__(self, maxsize, hub=None): if hub is None: hub = get_hub() self.hub = hub self._maxsize = 0 self.manager = None self.pid = os.getpid() self.fork_watcher = hub.loop.fork(ref=False) self._init(maxsize) def _set_maxsize(self, maxsize): if not isinstance(maxsize, integer_types): raise TypeError('maxsize must be integer: %r' % (maxsize, )) if maxsize < 0: raise ValueError('maxsize must not be negative: %r' % (maxsize, )) difference = maxsize - self._maxsize self._semaphore.counter += difference self._maxsize = maxsize self.adjust() # make sure all currently blocking spawn() start unlocking if maxsize increased self._semaphore._start_notify() def _get_maxsize(self): return self._maxsize maxsize = property(_get_maxsize, _set_maxsize) def __repr__(self): return '<%s at 0x%x %s/%s/%s>' % (self.__class__.__name__, id(self), len(self), self.size, self.maxsize) def __len__(self): # XXX just do unfinished_tasks property return self.task_queue.unfinished_tasks def _get_size(self): return self._size def _set_size(self, size): if size < 0: raise ValueError('Size of the pool cannot be negative: %r' % (size, )) if size > self._maxsize: raise ValueError('Size of the pool cannot be bigger than maxsize: %r > %r' % (size, self._maxsize)) if self.manager: self.manager.kill() while self._size < size: self._add_thread() delay = 0.0001 while self._size > size: while self._size - size > self.task_queue.unfinished_tasks: self.task_queue.put(None) if getcurrent() is self.hub: break sleep(delay) delay = min(delay * 2, .05) if self._size: self.fork_watcher.start(self._on_fork) else: self.fork_watcher.stop() size = property(_get_size, _set_size) def _init(self, maxsize): self._size = 0 self._semaphore = Semaphore(1) self._lock = Lock() self.task_queue = Queue() self._set_maxsize(maxsize) def _on_fork(self): # fork() only leaves one thread; also screws up locks; # let's re-create locks and threads pid = os.getpid() if pid != self.pid: self.pid = pid # Do not mix fork() and threads; since fork() only copies one thread # all objects referenced by other threads has refcount that will never # go down to 0. self._init(self._maxsize) def join(self): delay = 0.0005 while self.task_queue.unfinished_tasks > 0: sleep(delay) delay = min(delay * 2, .05) def kill(self): self.size = 0 def _adjust_step(self): # if there is a possibility & necessity for adding a thread, do it while self._size < self._maxsize and self.task_queue.unfinished_tasks > self._size: self._add_thread() # while the number of threads is more than maxsize, kill one # we do not check what's already in task_queue - it could be all Nones while self._size - self._maxsize > self.task_queue.unfinished_tasks: self.task_queue.put(None) if self._size: self.fork_watcher.start(self._on_fork) else: self.fork_watcher.stop() def _adjust_wait(self): delay = 0.0001 while True: self._adjust_step() if self._size <= self._maxsize: return sleep(delay) delay = min(delay * 2, .05) def adjust(self): self._adjust_step() if not self.manager and self._size > self._maxsize: # might need to feed more Nones into the pool self.manager = Greenlet.spawn(self._adjust_wait) def _add_thread(self): with self._lock: self._size += 1 try: start_new_thread(self._worker, ()) except: with self._lock: self._size -= 1 raise def spawn(self, func, *args, **kwargs): while True: semaphore = self._semaphore semaphore.acquire() if semaphore is self._semaphore: break try: task_queue = self.task_queue result = AsyncResult() thread_result = ThreadResult(result, hub=self.hub) task_queue.put((func, args, kwargs, thread_result)) self.adjust() # rawlink() must be the last call result.rawlink(lambda *args: self._semaphore.release()) # XXX this _semaphore.release() is competing for order with get() # XXX this is not good, just make ThreadResult release the semaphore before doing anything else except: semaphore.release() raise return result def _decrease_size(self): if sys is None: return _lock = getattr(self, '_lock', None) if _lock is not None: with _lock: self._size -= 1 def _worker(self): need_decrease = True try: while True: task_queue = self.task_queue task = task_queue.get() try: if task is None: need_decrease = False self._decrease_size() # we want first to decrease size, then decrease unfinished_tasks # otherwise, _adjust might think there's one more idle thread that # needs to be killed return func, args, kwargs, result = task try: value = func(*args, **kwargs) except: exc_info = getattr(sys, 'exc_info', None) if exc_info is None: return result.handle_error((self, func), exc_info()) else: if sys is None: return result.set(value) del value finally: del func, args, kwargs, result, task finally: if sys is None: return task_queue.task_done() finally: if need_decrease: self._decrease_size() # XXX apply() should re-raise error by default # XXX because that's what builtin apply does # XXX check gevent.pool.Pool.apply and multiprocessing.Pool.apply def apply_e(self, expected_errors, function, args=None, kwargs=None): if args is None: args = () if kwargs is None: kwargs = {} success, result = self.spawn(wrap_errors, expected_errors, function, args, kwargs).get() if success: return result raise result def apply(self, func, args=None, kwds=None): """Equivalent of the apply() builtin function. It blocks till the result is ready.""" if args is None: args = () if kwds is None: kwds = {} return self.spawn(func, *args, **kwds).get() def apply_cb(self, func, args=None, kwds=None, callback=None): result = self.apply(func, args, kwds) if callback is not None: callback(result) return result def apply_async(self, func, args=None, kwds=None, callback=None): """A variant of the apply() method which returns a Greenlet object. If callback is specified then it should be a callable which accepts a single argument. When the result becomes ready callback is applied to it (unless the call failed).""" if args is None: args = () if kwds is None: kwds = {} return Greenlet.spawn(self.apply_cb, func, args, kwds, callback) def map(self, func, iterable): return list(self.imap(func, iterable)) def map_cb(self, func, iterable, callback=None): result = self.map(func, iterable) if callback is not None: callback(result) return result def map_async(self, func, iterable, callback=None): """ A variant of the map() method which returns a Greenlet object. If callback is specified then it should be a callable which accepts a single argument. """ return Greenlet.spawn(self.map_cb, func, iterable, callback) def imap(self, func, iterable): """An equivalent of itertools.imap()""" return IMap.spawn(func, iterable, spawn=self.spawn) def imap_unordered(self, func, iterable): """The same as imap() except that the ordering of the results from the returned iterator should be considered in arbitrary order.""" return IMapUnordered.spawn(func, iterable, spawn=self.spawn) class ThreadResult(object): def __init__(self, receiver, hub=None): if hub is None: hub = get_hub() self.receiver = receiver self.hub = hub self.value = None self.context = None self.exc_info = None self.async = hub.loop.async() self.async.start(self._on_async) def _on_async(self): self.async.stop() try: if self.exc_info is not None: try: self.hub.handle_error(self.context, *self.exc_info) finally: self.exc_info = None self.context = None self.async = None self.hub = None if self.receiver is not None: # XXX exception!!!? self.receiver(self) finally: self.receiver = None self.value = None def set(self, value): self.value = value self.async.send() def handle_error(self, context, exc_info): self.context = context self.exc_info = exc_info self.async.send() # link protocol: def successful(self): return True def wrap_errors(errors, function, args, kwargs): try: return True, function(*args, **kwargs) except errors: return False, sys.exc_info()[1]