88"""
99from __future__ import absolute_import
1010
11+ import logging
12+ from contextlib import contextmanager
1113from functools import wraps
1214
1315from celery import states
16+ from celery .backends .base import BaseBackend
1417from celery .exceptions import ImproperlyConfigured
1518from celery .five import range
1619from celery .utils .timeutils import maybe_timedelta
1720
18- from celery .backends .base import BaseBackend
21+ from .models import Task
22+ from .models import TaskSet
23+ from .session import SessionManager
1924
20- from .models import Task , TaskSet
21- from .session import ResultSession
25+ logger = logging .getLogger (__name__ )
2226
2327__all__ = ['DatabaseBackend' ]
2428
@@ -33,7 +37,19 @@ def _sqlalchemy_installed():
3337 return sqlalchemy
3438_sqlalchemy_installed ()
3539
36- from sqlalchemy .exc import DatabaseError , OperationalError
40+ from sqlalchemy .exc import DatabaseError , OperationalError , ResourceClosedError , InvalidRequestError , IntegrityError
41+ from sqlalchemy .orm .exc import StaleDataError
42+
43+
44+ @contextmanager
45+ def session_cleanup (session ):
46+ try :
47+ yield
48+ except Exception :
49+ session .rollback ()
50+ raise
51+ finally :
52+ session .close ()
3753
3854
3955def retry (fun ):
@@ -45,7 +61,15 @@ def _inner(*args, **kwargs):
4561 for retries in range (max_retries ):
4662 try :
4763 return fun (* args , ** kwargs )
48- except (DatabaseError , OperationalError ):
64+ except (
65+ DatabaseError , OperationalError , ResourceClosedError , StaleDataError , InvalidRequestError ,
66+ IntegrityError
67+ ):
68+ logger .warning (
69+ "Failed operation %s. Retrying %s more times." ,
70+ fun .__name__ , max_retries - retries - 1 ,
71+ exc_info = True ,
72+ )
4973 if retries + 1 >= max_retries :
5074 raise
5175
@@ -83,8 +107,8 @@ def __init__(self, dburi=None, expires=None,
83107 'Missing connection string! Do you have '
84108 'CELERY_RESULT_DBURI set to a real value?' )
85109
86- def ResultSession (self ):
87- return ResultSession (
110+ def ResultSession (self , session_manager = SessionManager () ):
111+ return session_manager . session_factory (
88112 dburi = self .dburi ,
89113 short_lived_sessions = self .short_lived_sessions ,
90114 ** self .engine_options
@@ -95,8 +119,9 @@ def _store_result(self, task_id, result, status,
95119 traceback = None , max_retries = 3 , ** kwargs ):
96120 """Store return value and status of an executed task."""
97121 session = self .ResultSession ()
98- try :
99- task = session .query (Task ).filter (Task .task_id == task_id ).first ()
122+ with session_cleanup (session ):
123+ task = list (session .query (Task ).filter (Task .task_id == task_id ))
124+ task = task and task [0 ]
100125 if not task :
101126 task = Task (task_id )
102127 session .add (task )
@@ -106,83 +131,70 @@ def _store_result(self, task_id, result, status,
106131 task .traceback = traceback
107132 session .commit ()
108133 return result
109- finally :
110- session .close ()
111134
112135 @retry
113136 def _get_task_meta_for (self , task_id ):
114137 """Get task metadata for a task by id."""
115138 session = self .ResultSession ()
116- try :
117- task = session .query (Task ).filter (Task .task_id == task_id ).first ()
118- if task is None :
139+ with session_cleanup (session ):
140+ task = list (session .query (Task ).filter (Task .task_id == task_id ))
141+ task = task and task [0 ]
142+ if not task :
119143 task = Task (task_id )
120144 task .status = states .PENDING
121145 task .result = None
122146 return task .to_dict ()
123- finally :
124- session .close ()
125147
126148 @retry
127149 def _save_group (self , group_id , result ):
128150 """Store the result of an executed group."""
129151 session = self .ResultSession ()
130- try :
152+ with session_cleanup ( session ) :
131153 group = TaskSet (group_id , result )
132154 session .add (group )
133155 session .flush ()
134156 session .commit ()
135157 return result
136- finally :
137- session .close ()
138158
139159 @retry
140160 def _restore_group (self , group_id ):
141161 """Get metadata for group by id."""
142162 session = self .ResultSession ()
143- try :
163+ with session_cleanup ( session ) :
144164 group = session .query (TaskSet ).filter (
145165 TaskSet .taskset_id == group_id ).first ()
146166 if group :
147167 return group .to_dict ()
148- finally :
149- session .close ()
150168
151169 @retry
152170 def _delete_group (self , group_id ):
153171 """Delete metadata for group by id."""
154172 session = self .ResultSession ()
155- try :
173+ with session_cleanup ( session ) :
156174 session .query (TaskSet ).filter (
157175 TaskSet .taskset_id == group_id ).delete ()
158176 session .flush ()
159177 session .commit ()
160- finally :
161- session .close ()
162178
163179 @retry
164180 def _forget (self , task_id ):
165181 """Forget about result."""
166182 session = self .ResultSession ()
167- try :
183+ with session_cleanup ( session ) :
168184 session .query (Task ).filter (Task .task_id == task_id ).delete ()
169185 session .commit ()
170- finally :
171- session .close ()
172186
173187 def cleanup (self ):
174188 """Delete expired metadata."""
175189 session = self .ResultSession ()
176190 expires = self .expires
177191 now = self .app .now ()
178- try :
192+ with session_cleanup ( session ) :
179193 session .query (Task ).filter (
180194 Task .date_done < (now - expires )).delete ()
181195 session .query (TaskSet ).filter (
182196 TaskSet .date_done < (now - expires )).delete ()
183197 session .commit ()
184- finally :
185- session .close ()
186198
187199 def __reduce__ (self , args = (), kwargs = {}):
188200 kwargs .update (
0 commit comments