88"""
99from __future__ import absolute_import
1010
11+ import logging
12+ from contextlib import contextmanager
1113from functools import wraps
1214
1315from celery import states
16+ from celery .backends .base import BaseBackend
1417from celery .exceptions import ImproperlyConfigured
1518from celery .five import range
1619from celery .utils .timeutils import maybe_timedelta
1720
18- from celery .backends .base import BaseBackend
19-
20- from .models import Task , TaskSet
21+ from .models import Task
22+ from .models import TaskSet
2123from .session import ResultSession
2224
25+ logger = logging .getLogger (__name__ )
26+
2327__all__ = ['DatabaseBackend' ]
2428
2529
@@ -33,7 +37,21 @@ def _sqlalchemy_installed():
3337 return sqlalchemy
3438_sqlalchemy_installed ()
3539
36- from sqlalchemy .exc import DatabaseError , OperationalError
40+ from sqlalchemy .exc import DatabaseError , OperationalError , ResourceClosedError
41+ from sqlalchemy .orm .exc import StaleDataError
42+
43+
44+ @contextmanager
45+ def session_cleanup (session ):
46+ try :
47+ yield
48+ except (DatabaseError , OperationalError , ResourceClosedError , StaleDataError ):
49+ session .rollback ()
50+ session .connection ().invalidate ()
51+ session .close ()
52+ raise
53+ else :
54+ session .close ()
3755
3856
3957def retry (fun ):
@@ -45,7 +63,12 @@ def _inner(*args, **kwargs):
4563 for retries in range (max_retries ):
4664 try :
4765 return fun (* args , ** kwargs )
48- except (DatabaseError , OperationalError ):
66+ except (DatabaseError , OperationalError , ResourceClosedError , StaleDataError ):
67+ logger .critical (
68+ "Failed operation %s. Retrying %s more times." ,
69+ fun .__name__ , max_retries - retries - 1 ,
70+ exc_info = True ,
71+ )
4972 if retries + 1 >= max_retries :
5073 raise
5174
@@ -95,8 +118,9 @@ def _store_result(self, task_id, result, status,
95118 traceback = None , max_retries = 3 , ** kwargs ):
96119 """Store return value and status of an executed task."""
97120 session = self .ResultSession ()
98- try :
99- task = session .query (Task ).filter (Task .task_id == task_id ).first ()
121+ with session_cleanup (session ):
122+ task = list (session .query (Task ).filter (Task .task_id == task_id ))
123+ task = task and task [0 ]
100124 if not task :
101125 task = Task (task_id )
102126 session .add (task )
@@ -106,83 +130,70 @@ def _store_result(self, task_id, result, status,
106130 task .traceback = traceback
107131 session .commit ()
108132 return result
109- finally :
110- session .close ()
111133
112134 @retry
113135 def _get_task_meta_for (self , task_id ):
114136 """Get task metadata for a task by id."""
115137 session = self .ResultSession ()
116- try :
117- task = session .query (Task ).filter (Task .task_id == task_id ).first ()
118- if task is None :
138+ with session_cleanup (session ):
139+ task = list (session .query (Task ).filter (Task .task_id == task_id ))
140+ task = task and task [0 ]
141+ if not task :
119142 task = Task (task_id )
120143 task .status = states .PENDING
121144 task .result = None
122145 return task .to_dict ()
123- finally :
124- session .close ()
125146
126147 @retry
127148 def _save_group (self , group_id , result ):
128149 """Store the result of an executed group."""
129150 session = self .ResultSession ()
130- try :
151+ with session_cleanup ( session ) :
131152 group = TaskSet (group_id , result )
132153 session .add (group )
133154 session .flush ()
134155 session .commit ()
135156 return result
136- finally :
137- session .close ()
138157
139158 @retry
140159 def _restore_group (self , group_id ):
141160 """Get metadata for group by id."""
142161 session = self .ResultSession ()
143- try :
162+ with session_cleanup ( session ) :
144163 group = session .query (TaskSet ).filter (
145164 TaskSet .taskset_id == group_id ).first ()
146165 if group :
147166 return group .to_dict ()
148- finally :
149- session .close ()
150167
151168 @retry
152169 def _delete_group (self , group_id ):
153170 """Delete metadata for group by id."""
154171 session = self .ResultSession ()
155- try :
172+ with session_cleanup ( session ) :
156173 session .query (TaskSet ).filter (
157174 TaskSet .taskset_id == group_id ).delete ()
158175 session .flush ()
159176 session .commit ()
160- finally :
161- session .close ()
162177
163178 @retry
164179 def _forget (self , task_id ):
165180 """Forget about result."""
166181 session = self .ResultSession ()
167- try :
182+ with session_cleanup ( session ) :
168183 session .query (Task ).filter (Task .task_id == task_id ).delete ()
169184 session .commit ()
170- finally :
171- session .close ()
172185
173186 def cleanup (self ):
174187 """Delete expired metadata."""
175188 session = self .ResultSession ()
176189 expires = self .expires
177190 now = self .app .now ()
178- try :
191+ with session_cleanup ( session ) :
179192 session .query (Task ).filter (
180193 Task .date_done < (now - expires )).delete ()
181194 session .query (TaskSet ).filter (
182195 TaskSet .date_done < (now - expires )).delete ()
183196 session .commit ()
184- finally :
185- session .close ()
186197
187198 def __reduce__ (self , args = (), kwargs = {}):
188199 kwargs .update (
0 commit comments