Skip to content

Commit 2571731

Browse files
committed
Merge pull request celery#1899 from celery/more-reliable-sqlalchemy
Make the database backend more reliable
2 parents 5254f9a + df32c47 commit 2571731

File tree

3 files changed

+95
-80
lines changed

3 files changed

+95
-80
lines changed

Changelog

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ new in Celery 3.1.
1919

2020
- Now depends on :ref:`Kombu 3.0.14 <kombu:version-3.0.14>`.
2121

22+
- **Results**:

    Reliability improvements to the SQLAlchemy database backend.  Previously the
    connection from the MainProcess was improperly shared with the workers.
    (Issue #1786)
27+
2228
- **Redis:** Important note about events (Issue #1882).
2329

2430
There is a new transport option for Redis that enables monitors

celery/backends/database/__init__.py

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,21 @@
88
"""
99
from __future__ import absolute_import
1010

11+
import logging
12+
from contextlib import contextmanager
1113
from functools import wraps
1214

1315
from celery import states
16+
from celery.backends.base import BaseBackend
1417
from celery.exceptions import ImproperlyConfigured
1518
from celery.five import range
1619
from celery.utils.timeutils import maybe_timedelta
1720

18-
from celery.backends.base import BaseBackend
21+
from .models import Task
22+
from .models import TaskSet
23+
from .session import SessionManager
1924

20-
from .models import Task, TaskSet
21-
from .session import ResultSession
25+
logger = logging.getLogger(__name__)
2226

2327
__all__ = ['DatabaseBackend']
2428

@@ -33,7 +37,19 @@ def _sqlalchemy_installed():
3337
return sqlalchemy
3438
_sqlalchemy_installed()
3539

36-
from sqlalchemy.exc import DatabaseError, OperationalError
40+
from sqlalchemy.exc import DatabaseError, OperationalError, ResourceClosedError, InvalidRequestError, IntegrityError
41+
from sqlalchemy.orm.exc import StaleDataError
42+
43+
44+
@contextmanager
def session_cleanup(session):
    """Context manager guarding a result-store session.

    Rolls the session back if the wrapped block raises (the exception is
    re-raised), and closes the session unconditionally on exit.
    """
    try:
        yield
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
3753

3854

3955
def retry(fun):
@@ -45,7 +61,15 @@ def _inner(*args, **kwargs):
4561
for retries in range(max_retries):
4662
try:
4763
return fun(*args, **kwargs)
48-
except (DatabaseError, OperationalError):
64+
except (
65+
DatabaseError, OperationalError, ResourceClosedError, StaleDataError, InvalidRequestError,
66+
IntegrityError
67+
):
68+
logger.warning(
69+
"Failed operation %s. Retrying %s more times.",
70+
fun.__name__, max_retries - retries - 1,
71+
exc_info=True,
72+
)
4973
if retries + 1 >= max_retries:
5074
raise
5175

@@ -83,8 +107,8 @@ def __init__(self, dburi=None, expires=None,
83107
'Missing connection string! Do you have '
84108
'CELERY_RESULT_DBURI set to a real value?')
85109

86-
def ResultSession(self):
87-
return ResultSession(
110+
def ResultSession(self, session_manager=SessionManager()):
    """Open a new database session for the result store.

    NOTE(review): the default ``session_manager`` is evaluated once, when
    the class body is executed, so every call (on every backend instance)
    shares one manager and its per-process engine/session caches — this
    looks deliberate, but confirm before changing.
    """
    return session_manager.session_factory(
        dburi=self.dburi,
        short_lived_sessions=self.short_lived_sessions,
        **self.engine_options
    )
@@ -95,8 +119,9 @@ def _store_result(self, task_id, result, status,
95119
traceback=None, max_retries=3, **kwargs):
96120
"""Store return value and status of an executed task."""
97121
session = self.ResultSession()
98-
try:
99-
task = session.query(Task).filter(Task.task_id == task_id).first()
122+
with session_cleanup(session):
123+
task = list(session.query(Task).filter(Task.task_id == task_id))
124+
task = task and task[0]
100125
if not task:
101126
task = Task(task_id)
102127
session.add(task)
@@ -106,83 +131,70 @@ def _store_result(self, task_id, result, status,
106131
task.traceback = traceback
107132
session.commit()
108133
return result
109-
finally:
110-
session.close()
111134

112135
@retry
def _get_task_meta_for(self, task_id):
    """Get task metadata for a task by id."""
    session = self.ResultSession()
    with session_cleanup(session):
        # Materialize the whole result set; the first row (if any)
        # is the matching task.
        rows = list(session.query(Task).filter(Task.task_id == task_id))
        task = rows and rows[0]
        if not task:
            # Unknown task id: report a synthetic PENDING record.
            task = Task(task_id)
            task.status = states.PENDING
            task.result = None
        return task.to_dict()
125147

126148
@retry
def _save_group(self, group_id, result):
    """Store the result of an executed group."""
    session = self.ResultSession()
    with session_cleanup(session):
        session.add(TaskSet(group_id, result))
        session.flush()
        session.commit()
        return result
138158

139159
@retry
def _restore_group(self, group_id):
    """Get metadata for group by id (``None`` if not found)."""
    session = self.ResultSession()
    with session_cleanup(session):
        query = session.query(TaskSet).filter(
            TaskSet.taskset_id == group_id)
        group = query.first()
        return group.to_dict() if group else None
150168

151169
@retry
def _delete_group(self, group_id):
    """Delete metadata for group by id."""
    session = self.ResultSession()
    with session_cleanup(session):
        matching = session.query(TaskSet).filter(
            TaskSet.taskset_id == group_id)
        matching.delete()
        session.flush()
        session.commit()
162178

163179
@retry
def _forget(self, task_id):
    """Forget about result."""
    session = self.ResultSession()
    with session_cleanup(session):
        matching = session.query(Task).filter(Task.task_id == task_id)
        matching.delete()
        session.commit()
172186

173187
def cleanup(self):
    """Delete expired task and group metadata."""
    session = self.ResultSession()
    expires = self.expires
    now = self.app.now()
    with session_cleanup(session):
        cutoff = now - expires
        # Anything finished before the cutoff has expired.
        session.query(Task).filter(Task.date_done < cutoff).delete()
        session.query(TaskSet).filter(TaskSet.date_done < cutoff).delete()
        session.commit()
186198

187199
def __reduce__(self, args=(), kwargs={}):
188200
kwargs.update(

celery/backends/database/session.py

Lines changed: 46 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,58 +8,55 @@
88
"""
99
from __future__ import absolute_import
1010

11-
from collections import defaultdict
12-
from multiprocessing.util import register_after_fork
11+
from billiard.util import register_after_fork
1312

1413
from sqlalchemy import create_engine
15-
from sqlalchemy.orm import sessionmaker
1614
from sqlalchemy.ext.declarative import declarative_base
15+
from sqlalchemy.orm import sessionmaker
16+
from sqlalchemy.pool import NullPool
1717

1818
ResultModelBase = declarative_base()
1919

20-
_SETUP = defaultdict(lambda: False)
21-
_ENGINES = {}
22-
_SESSIONS = {}
23-
24-
__all__ = ['ResultSession', 'get_engine', 'create_session']
25-
26-
27-
class _after_fork(object):
28-
registered = False
29-
30-
def __call__(self):
31-
self.registered = False # child must reregister
32-
for engine in list(_ENGINES.values()):
33-
engine.dispose()
34-
_ENGINES.clear()
35-
_SESSIONS.clear()
36-
after_fork = _after_fork()
37-
38-
39-
def get_engine(dburi, **kwargs):
40-
try:
41-
return _ENGINES[dburi]
42-
except KeyError:
43-
engine = _ENGINES[dburi] = create_engine(dburi, **kwargs)
44-
after_fork.registered = True
45-
register_after_fork(after_fork, after_fork)
46-
return engine
47-
48-
49-
def create_session(dburi, short_lived_sessions=False, **kwargs):
50-
engine = get_engine(dburi, **kwargs)
51-
if short_lived_sessions or dburi not in _SESSIONS:
52-
_SESSIONS[dburi] = sessionmaker(bind=engine)
53-
return engine, _SESSIONS[dburi]
54-
55-
56-
def setup_results(engine):
57-
if not _SETUP['results']:
58-
ResultModelBase.metadata.create_all(engine)
59-
_SETUP['results'] = True
60-
61-
62-
def ResultSession(dburi, **kwargs):
63-
engine, session = create_session(dburi, **kwargs)
64-
setup_results(engine)
65-
return session()
20+
__all__ = ['SessionManager']
21+
22+
23+
class SessionManager(object):
    """Manage SQLAlchemy engines and session factories per process.

    Before the worker forks (i.e. while still in the MainProcess) every
    engine is a throw-away one backed by :class:`~sqlalchemy.pool.NullPool`,
    so no pooled database connection can be inherited by — and improperly
    shared with — child processes (Issue #1786).  After the fork, engines
    and session factories are cached per ``dburi``.
    """

    def __init__(self):
        self._engines = {}     # dburi -> Engine, used only after fork
        self._sessions = {}    # dburi -> sessionmaker, used only after fork
        self.forked = False
        self.prepared = False
        register_after_fork(self, self._after_fork)

    def _after_fork(self, *args):
        # Fix: was ``def _after_fork(self,):``.  ``register_after_fork``
        # may call back with the registered object as a positional
        # argument (a bound method then receives it in addition to
        # ``self``), so accept and ignore any extra arguments.
        self.forked = True

    def get_engine(self, dburi, **kwargs):
        """Return an engine for *dburi*; cached only after forking."""
        if self.forked:
            try:
                return self._engines[dburi]
            except KeyError:
                engine = self._engines[dburi] = create_engine(dburi, **kwargs)
                return engine
        else:
            # Parent process: never pool connections that could leak
            # across the fork boundary.
            kwargs['poolclass'] = NullPool
            return create_engine(dburi, **kwargs)

    def create_session(self, dburi, short_lived_sessions=False, **kwargs):
        """Return an ``(engine, sessionmaker)`` pair for *dburi*."""
        engine = self.get_engine(dburi, **kwargs)
        if self.forked:
            if short_lived_sessions or dburi not in self._sessions:
                self._sessions[dburi] = sessionmaker(bind=engine)
            return engine, self._sessions[dburi]
        else:
            # Pre-fork sessions are always short-lived: fresh factory.
            return engine, sessionmaker(bind=engine)

    def prepare_models(self, engine):
        """Create the result tables, at most once per manager."""
        if not self.prepared:
            ResultModelBase.metadata.create_all(engine)
            self.prepared = True

    def session_factory(self, dburi, **kwargs):
        """Create a new session, ensuring the tables exist first."""
        engine, session = self.create_session(dburi, **kwargs)
        self.prepare_models(engine)
        return session()

0 commit comments

Comments
 (0)