Skip to content

Commit e793166

Browse files
committed
don't register tracking events if config disables them
warn that tracking will be disabled by default in the future
1 parent e05ffe1 commit e793166

File tree

1 file changed

+63
-56
lines changed

1 file changed

+63
-56
lines changed

flask_sqlalchemy/__init__.py

Lines changed: 63 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import sys
1515
import time
1616
import functools
17+
import warnings
1718
import sqlalchemy
1819
from math import ceil
1920
from functools import partial
@@ -24,7 +25,6 @@
2425
from sqlalchemy import orm, event
2526
from sqlalchemy.orm.exc import UnmappedClassError
2627
from sqlalchemy.orm.session import Session as SessionBase
27-
from sqlalchemy.event import listen
2828
from sqlalchemy.engine.url import make_url
2929
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
3030
from flask.ext.sqlalchemy._compat import iteritems, itervalues, xrange, \
@@ -146,20 +146,20 @@ class SignallingSession(SessionBase):
146146

147147
def __init__(self, db, autocommit=False, autoflush=True, **options):
148148
#: The application that this session belongs to.
149-
self.app = db.get_app()
149+
self.app = app = db.get_app()
150+
track_modifications = app.config['SQLALCHEMY_TRACK_MODIFICATIONS']
150151
self._model_changes = {}
151-
#: A flag that controls whether this session should keep track of
152-
#: model modifications. The default value for this attribute
153-
#: is set from the ``SQLALCHEMY_TRACK_MODIFICATIONS`` config
154-
#: key.
155-
self.emit_modification_signals = \
156-
self.app.config['SQLALCHEMY_TRACK_MODIFICATIONS']
157152
bind = options.pop('bind', None) or db.engine
158-
SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush,
159-
bind=bind,
160-
binds=db.get_binds(self.app), **options)
161153

162-
def get_bind(self, mapper, clause=None):
154+
if track_modifications is None or track_modifications:
155+
_SessionSignalEvents.register(self)
156+
157+
SessionBase.__init__(
158+
self, autocommit=autocommit, autoflush=autoflush,
159+
bind=bind, binds=db.get_binds(self.app), **options
160+
)
161+
162+
def get_bind(self, mapper=None, clause=None):
163163
# mapper is None if someone tries to just get a connection
164164
if mapper is not None:
165165
info = getattr(mapper.mapped_table, 'info', {})
@@ -171,11 +171,17 @@ def get_bind(self, mapper, clause=None):
171171

172172

173173
class _SessionSignalEvents(object):
174-
175-
def register(self):
176-
listen(SessionBase, 'before_commit', self.session_signal_before_commit)
177-
listen(SessionBase, 'after_commit', self.session_signal_after_commit)
178-
listen(SessionBase, 'after_rollback', self.session_signal_after_rollback)
174+
@classmethod
175+
def register(cls, session):
176+
event.listen(session, 'before_commit', cls.session_signal_before_commit)
177+
event.listen(session, 'after_commit', cls.session_signal_after_commit)
178+
event.listen(session, 'after_rollback', cls.session_signal_after_rollback)
179+
180+
@classmethod
181+
def unregister(cls, session):
182+
event.remove(session, 'before_commit', cls.session_signal_before_commit)
183+
event.remove(session, 'after_commit', cls.session_signal_after_commit)
184+
event.remove(session, 'after_rollback', cls.session_signal_after_rollback)
179185

180186
@staticmethod
181187
def session_signal_before_commit(session):
@@ -202,33 +208,38 @@ def session_signal_after_rollback(session):
202208

203209

204210
class _MapperSignalEvents(object):
205-
206-
def __init__(self, mapper):
207-
self.mapper = mapper
208-
209-
def register(self):
210-
listen(self.mapper, 'after_delete', self.mapper_signal_after_delete)
211-
listen(self.mapper, 'after_insert', self.mapper_signal_after_insert)
212-
listen(self.mapper, 'after_update', self.mapper_signal_after_update)
213-
214-
def mapper_signal_after_delete(self, mapper, connection, target):
215-
self._record(mapper, target, 'delete')
216-
217-
def mapper_signal_after_insert(self, mapper, connection, target):
218-
self._record(mapper, target, 'insert')
219-
220-
def mapper_signal_after_update(self, mapper, connection, target):
221-
self._record(mapper, target, 'update')
211+
@classmethod
212+
def register(cls, mapper):
213+
event.listen(mapper, 'after_delete', cls.mapper_signal_after_delete)
214+
event.listen(mapper, 'after_insert', cls.mapper_signal_after_insert)
215+
event.listen(mapper, 'after_update', cls.mapper_signal_after_update)
216+
217+
@classmethod
218+
def unregister(cls, mapper):
219+
event.remove(mapper, 'after_delete', cls.mapper_signal_after_delete)
220+
event.remove(mapper, 'after_insert', cls.mapper_signal_after_insert)
221+
event.remove(mapper, 'after_update', cls.mapper_signal_after_update)
222+
223+
@classmethod
224+
def mapper_signal_after_delete(cls, mapper, connection, target):
225+
cls._record(mapper, target, 'delete')
226+
227+
@classmethod
228+
def mapper_signal_after_insert(cls, mapper, connection, target):
229+
cls._record(mapper, target, 'insert')
230+
231+
@classmethod
232+
def mapper_signal_after_update(cls, mapper, connection, target):
233+
cls._record(mapper, target, 'update')
222234

223235
@staticmethod
224236
def _record(mapper, target, operation):
225237
s = orm.object_session(target)
226-
if isinstance(s, SignallingSession) and s.emit_modification_signals:
238+
if isinstance(s, SignallingSession):
227239
pk = tuple(mapper.primary_key_from_instance(target))
228240
s._model_changes[pk] = (target, operation)
229241

230242

231-
232243
class _EngineDebuggingSignalEvents(object):
233244
"""Sets up handlers for two events that let us track the execution time of queries."""
234245

@@ -237,8 +248,8 @@ def __init__(self, engine, import_name):
237248
self.app_package = import_name
238249

239250
def register(self):
240-
listen(self.engine, 'before_cursor_execute', self.before_cursor_execute)
241-
listen(self.engine, 'after_cursor_execute', self.after_cursor_execute)
251+
event.listen(self.engine, 'before_cursor_execute', self.before_cursor_execute)
252+
event.listen(self.engine, 'after_cursor_execute', self.after_cursor_execute)
242253

243254
def before_cursor_execute(self, conn, cursor, statement,
244255
parameters, context, executemany):
@@ -680,32 +691,21 @@ class User(db.Model):
680691
a custom function which will define the SQLAlchemy session's scoping.
681692
"""
682693

683-
def __init__(self, app=None,
684-
use_native_unicode=True,
685-
session_options=None):
686-
self.use_native_unicode = use_native_unicode
687-
694+
def __init__(self, app=None, use_native_unicode=True, session_options=None):
688695
if session_options is None:
689696
session_options = {}
690697

691-
session_options.setdefault(
692-
'scopefunc', connection_stack.__ident_func__
693-
)
694-
698+
session_options.setdefault('scopefunc', connection_stack.__ident_func__)
699+
self.use_native_unicode = use_native_unicode
695700
self.session = self.create_scoped_session(session_options)
701+
self.Query = BaseQuery
696702
self.Model = self.make_declarative_base()
697703
self._engine_lock = Lock()
704+
self.app = app
705+
_include_sqlalchemy(self)
698706

699707
if app is not None:
700-
self.app = app
701708
self.init_app(app)
702-
else:
703-
self.app = None
704-
705-
_include_sqlalchemy(self)
706-
_MapperSignalEvents(self.mapper).register()
707-
_SessionSignalEvents().register()
708-
self.Query = BaseQuery
709709

710710
@property
711711
def metadata(self):
@@ -753,7 +753,14 @@ def init_app(self, app):
753753
app.config.setdefault('SQLALCHEMY_POOL_RECYCLE', None)
754754
app.config.setdefault('SQLALCHEMY_MAX_OVERFLOW', None)
755755
app.config.setdefault('SQLALCHEMY_COMMIT_ON_TEARDOWN', False)
756-
app.config.setdefault('SQLALCHEMY_TRACK_MODIFICATIONS', True)
756+
track_modifications = app.config.setdefault('SQLALCHEMY_TRACK_MODIFICATIONS', None)
757+
758+
if track_modifications is None:
759+
track_modifications = True
760+
warnings.warn('SQLALCHEMY_TRACK_MODIFICATIONS adds significant overhead and will be disabled by default in the future. Set it to True to suppress this warning.')
761+
762+
if track_modifications:
763+
_MapperSignalEvents.register(self.mapper)
757764

758765
if not hasattr(app, 'extensions'):
759766
app.extensions = {}

0 commit comments

Comments
 (0)