
Commit 1edfaeb

Add workhorse TaskPool.
1 parent c353c80 commit 1edfaeb

2 files changed: 255 additions, 0 deletions

celery/concurrency/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
     'threads': 'celery.concurrency.threads:TaskPool',
     'solo': 'celery.concurrency.solo:TaskPool',
     'processes': 'celery.concurrency.prefork:TaskPool',  # XXX compat alias
+    'workhorse': 'celery.concurrency.workhorse:TaskPool',
 }
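Note: the dictionary patched above is the worker's pool-alias table, used to resolve the --pool/-P option, so with this entry in place the new pool should be selectable by name. A minimal invocation sketch (the application name proj is a placeholder, not part of this commit):

    celery -A proj worker -P workhorse --loglevel=INFO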

celery/concurrency/workhorse.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
import errno
import logging
import os
import signal
import sys
import threading
import time
from itertools import count
from logging import getLogger

import signalfd
from billiard import util
from billiard.exceptions import WorkerLostError
from billiard.process import _maybe_flush
from billiard.process import _set_current_process
from celery.concurrency.base import BasePool
from celery.concurrency.prefork import _set_task_join_will_block
from celery.concurrency.prefork import platforms, process_destructor
from celery.concurrency.prefork import WORKER_SIGIGNORE
from celery.concurrency.prefork import WORKER_SIGRESET
from celery.five import monotonic

from ctypes import c_char
from ctypes import c_int32
from ctypes import c_uint32
from ctypes import c_uint64
from ctypes import Structure

logger = getLogger(__name__)


class signalfd_siginfo(Structure):
    _fields_ = (
        ('ssi_signo', c_uint32),    # Signal number
        ('ssi_errno', c_int32),     # Error number (unused)
        ('ssi_code', c_int32),      # Signal code
        ('ssi_pid', c_uint32),      # PID of sender
        ('ssi_uid', c_uint32),      # Real UID of sender
        ('ssi_fd', c_int32),        # File descriptor (SIGIO)
        ('ssi_tid', c_uint32),      # Kernel timer ID (POSIX timers)
        ('ssi_band', c_uint32),     # Band event (SIGIO)
        ('ssi_overrun', c_uint32),  # POSIX timer overrun count
        ('ssi_trapno', c_uint32),   # Trap number that caused signal
        ('ssi_status', c_int32),    # Exit status or signal (SIGCHLD)
        ('ssi_int', c_int32),       # Integer sent by sigqueue(2)
        ('ssi_ptr', c_uint64),      # Pointer sent by sigqueue(2)
        ('ssi_utime', c_uint64),    # User CPU time consumed (SIGCHLD)
        ('ssi_stime', c_uint64),    # System CPU time consumed (SIGCHLD)
        ('ssi_addr', c_uint64),     # Address that generated signal
                                    # (for hardware-generated signals)
        ('_padding', c_char * 46),  # Pad size to 128 bytes (allow for
                                    # additional fields in the future)
    )


class ExcInfo(object):
    internal = False
    tb = traceback = None

    def __init__(self, exc):
        self.type = type(exc)
        self.exception = exc

    @property
    def exc_info(self):
        return self.type, self.exception, None


class Workhorse(object):
    _counter = count(1)
    _children = ()

    def __init__(self, target, args, kwargs):
        self._name = self.name = 'Workhorse-%s' % next(Workhorse._counter)
        sys.stdout.flush()
        sys.stderr.flush()
        self.pid = os.fork()
        if self.pid == 0:
            # Child process: run the task inline and exit; never return
            # to the caller.
            try:
                platforms.signals.reset(*WORKER_SIGRESET)
                platforms.signals.ignore(*WORKER_SIGIGNORE)
                _set_task_join_will_block(True)

                if 'random' in sys.modules:
                    import random
                    random.seed()
                if sys.stdin is not None:
                    try:
                        sys.stdin.close()
                        sys.stdin = open(os.devnull)
                    except (OSError, ValueError):
                        pass
                _set_current_process(self)

                # Re-init logging system.
                # Workaround for http://bugs.python.org/issue6721/#msg140215
                # Python logging module uses RLock() objects which are broken
                # after fork. This can result in a deadlock (Celery Issue #496).
                loggerDict = logging.Logger.manager.loggerDict
                logger_names = list(loggerDict.keys())
                logger_names.append(None)  # for root logger
                for name in logger_names:
                    if not name or not isinstance(loggerDict[name],
                                                  logging.PlaceHolder):
                        for handler in logging.getLogger(name).handlers:
                            handler.createLock()
                logging._lock = threading.RLock()

                util._finalizer_registry.clear()
                util._run_after_forkers()
                try:
                    target(*args, **kwargs)
                    exitcode = 0
                finally:
                    util._exit_function()
            except SystemExit as exc:
                if not exc.args:
                    exitcode = 1
                elif isinstance(exc.args[0], int):
                    exitcode = exc.args[0]
                else:
                    sys.stderr.write(str(exc.args[0]) + '\n')
                    exitcode = 0 if isinstance(exc.args[0], str) else 1
            except:
                exitcode = 1
                if not util.error('Process %s', self.name, exc_info=True):
                    import traceback
                    sys.stderr.write('Process %s:\n' % self.name)
                    traceback.print_exc()
            finally:
                util.info('Process %s exiting with exitcode %d',
                          self.pid, exitcode)
                _maybe_flush(sys.stdout)
                _maybe_flush(sys.stderr)
                _maybe_flush(sys.__stdout__)
                _maybe_flush(sys.__stderr__)
                os._exit(exitcode)


class TaskPool(BasePool):
    sigfd = None
    sigfh = None
    workers = ()
    uses_semaphore = True
    sem = None

    def on_start(self):
        self.sem = self.options['semaphore']
        self.workers = {}
        self.sigfd = signalfd.signalfd(
            0, [signal.SIGCHLD],
            signalfd.SFD_NONBLOCK | signalfd.SFD_CLOEXEC)
        self.sigfh = os.fdopen(self.sigfd, 'rb')
        signalfd.sigprocmask(signalfd.SIG_BLOCK, [signal.SIGCHLD])

    def register_with_event_loop(self, hub):
        hub.add_reader(self.sigfd, self.on_sigchld)

    def on_sigchld(self):
        pending = {}

        # Drain every queued siginfo record from the signalfd.
        si = signalfd_siginfo()
        while True:
            try:
                self.sigfh.readinto(si)
            except IOError as exc:
                if exc.errno != errno.EAGAIN:
                    raise
                break
            else:
                assert si.ssi_signo == signal.SIGCHLD
                pending[si.ssi_pid] = si.ssi_status
        # Reap any exited children the siginfo records did not cover.
        while True:
            try:
                pid, exit_code = os.waitpid(0, os.WNOHANG)
            except OSError as exc:
                if exc.errno != errno.ECHILD:
                    raise
                break
            else:
                if not pid:
                    break
                if pid not in pending:
                    if os.WIFEXITED(exit_code):
                        pending[pid] = os.WEXITSTATUS(exit_code)
                    elif os.WIFSIGNALED(exit_code):
                        pending[pid] = os.WTERMSIG(exit_code)
                    elif os.WIFSTOPPED(exit_code):
                        pending[pid] = os.WSTOPSIG(exit_code)

        for pid, exit_code in pending.items():
            self.on_worker_exit(pid, exit_code)

    def on_worker_exit(self, pid, exit_code):
        if pid in self.workers:
            options = self.workers.pop(pid)
            if exit_code == 0:
                options['callback'](None)
            else:
                logger.warn(
                    'Got SIGCHLD with exit_code:%r for pid:%r and task_id:%r',
                    exit_code, pid, options['correlation_id'])
                options['error_callback'](ExcInfo(WorkerLostError(exit_code)))
            if self.active:
                self.sem.release()
        process_destructor(pid, exit_code)

    @staticmethod
    def terminate_job(pid, signum):
        logger.warn("Killing pid:%s with signum:%s", pid, signum)
        try:
            os.kill(pid, signum)
        except OSError as exc:
            if exc.errno != errno.ESRCH:
                raise

    def on_apply(self, target, args, kwargs, **options):
        accept_callback = options['accept_callback']

        process = Workhorse(target, args, kwargs)
        self.workers[process.pid] = options
        if accept_callback:
            accept_callback(process.pid, monotonic())

    def on_stop(self):
        try:
            for pid in list(self.workers):
                try:
                    pid, exit_code = os.waitpid(pid, 0)
                except OSError as exc:
                    if exc.errno != errno.ECHILD:
                        logger.warn(
                            "Failed to wait for child process %s: %s",
                            pid, exc)
                    continue
                else:
                    self.on_worker_exit(pid, exit_code)
        except:
            self.terminate()
            raise
    on_close = on_stop

    def terminate(self, timeout=5):
        for pid in self.workers:
            self.terminate_job(pid, signal.SIGTERM)

        while self.workers and timeout > 0:
            self.on_sigchld()
            time.sleep(1)
            timeout -= 1

        for pid in self.workers:
            self.terminate_job(pid, signal.SIGKILL)

    def _get_info(self):
        return {
            'max-concurrency': self.limit,
            'processes': [pid for pid in self.workers],
            'max-tasks-per-child': 1,
            'put-guarded-by-semaphore': self.putlocks,
            'timeouts': 0,  # TODO: support timeouts
        }

    # TODO: missing restart method !?
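For context on how the pool detects finished tasks: on_start blocks SIGCHLD and routes it into a signalfd, register_with_event_loop asks the hub to call on_sigchld whenever that descriptor becomes readable, and on_sigchld drains the queued siginfo records and reaps exited children with a non-blocking waitpid loop. The following is a minimal, standalone sketch of that reaping pattern using only the standard library (POSIX-only; the signalfd wiring and Celery callbacks are left out, and spawn / reap_exited_children are illustrative names, not part of the commit):

    import errno
    import os
    import time


    def spawn(exitcode):
        # Stand-in for Workhorse: fork a child that "runs a task" and exits.
        pid = os.fork()
        if pid == 0:
            os._exit(exitcode)
        return pid


    def reap_exited_children():
        # Collect {pid: exit code or signal number} for every child that has
        # already exited, mirroring the second loop in TaskPool.on_sigchld.
        pending = {}
        while True:
            try:
                # 0 = any child in our process group, WNOHANG = never block.
                pid, status = os.waitpid(0, os.WNOHANG)
            except OSError as exc:
                if exc.errno != errno.ECHILD:
                    raise
                break  # no children left at all
            if not pid:
                break  # children exist, but none have exited yet
            if os.WIFEXITED(status):
                pending[pid] = os.WEXITSTATUS(status)
            elif os.WIFSIGNALED(status):
                pending[pid] = os.WTERMSIG(status)
        return pending


    if __name__ == '__main__':
        spawn(0)         # child that succeeds
        spawn(3)         # child that fails with exit code 3
        time.sleep(0.1)  # give the children a moment to exit
        for pid, code in sorted(reap_exited_children().items()):
            print('child %d finished with %d' % (pid, code))

Because every task runs in a freshly forked child that exits as soon as the task returns, each worker process handles exactly one task, which is why _get_info() reports 'max-tasks-per-child': 1.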
