Skip to content

Commit df7f4e0

Browse files
committed
Merge pull request pallets-eco#328 from justanr/custom_query_class_feature
Implement custom query and model class options
2 parents 357ec7b + 3ff4c64 commit df7f4e0

File tree

2 files changed

+92
-17
lines changed

2 files changed

+92
-17
lines changed

flask_sqlalchemy/__init__.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,34 +68,34 @@ def _make_table(*args, **kwargs):
6868
return _make_table
6969

7070

71-
def _set_default_query_class(d):
71+
def _set_default_query_class(d, cls):
7272
if 'query_class' not in d:
73-
d['query_class'] = BaseQuery
73+
d['query_class'] = cls
7474

7575

76-
def _wrap_with_default_query_class(fn):
76+
def _wrap_with_default_query_class(fn, cls):
7777
@functools.wraps(fn)
7878
def newfn(*args, **kwargs):
79-
_set_default_query_class(kwargs)
79+
_set_default_query_class(kwargs, cls)
8080
if "backref" in kwargs:
8181
backref = kwargs['backref']
8282
if isinstance(backref, string_types):
8383
backref = (backref, {})
84-
_set_default_query_class(backref[1])
84+
_set_default_query_class(backref[1], cls)
8585
return fn(*args, **kwargs)
8686
return newfn
8787

8888

89-
def _include_sqlalchemy(obj):
89+
def _include_sqlalchemy(obj, cls):
9090
for module in sqlalchemy, sqlalchemy.orm:
9191
for key in module.__all__:
9292
if not hasattr(obj, key):
9393
setattr(obj, key, getattr(module, key))
9494
# Note: obj.Table does not attempt to be a SQLAlchemy Table class.
9595
obj.Table = _make_table(obj)
96-
obj.relationship = _wrap_with_default_query_class(obj.relationship)
97-
obj.relation = _wrap_with_default_query_class(obj.relation)
98-
obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader)
96+
obj.relationship = _wrap_with_default_query_class(obj.relationship, cls)
97+
obj.relation = _wrap_with_default_query_class(obj.relation, cls)
98+
obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader, cls)
9999
obj.event = event
100100

101101

@@ -730,19 +730,21 @@ class User(db.Model):
730730
naming conventions among other, non-trivial things.
731731
"""
732732

733-
def __init__(self, app=None, use_native_unicode=True, session_options=None, metadata=None):
733+
def __init__(self, app=None, use_native_unicode=True, session_options=None,
734+
metadata=None, query_class=BaseQuery, model_class=Model):
734735

735736
if session_options is None:
736737
session_options = {}
737738

738739
session_options.setdefault('scopefunc', connection_stack.__ident_func__)
740+
session_options.setdefault('query_cls', query_class)
739741
self.use_native_unicode = use_native_unicode
740742
self.session = self.create_scoped_session(session_options)
741-
self.Model = self.make_declarative_base(metadata)
742-
self.Query = BaseQuery
743+
self.Query = query_class
744+
self.Model = self.make_declarative_base(model_class, metadata)
743745
self._engine_lock = Lock()
744746
self.app = app
745-
_include_sqlalchemy(self)
747+
_include_sqlalchemy(self, query_class)
746748

747749
if app is not None:
748750
self.init_app(app)
@@ -770,11 +772,15 @@ def create_session(self, options):
770772
"""
771773
return SignallingSession(self, **options)
772774

773-
def make_declarative_base(self, metadata=None):
775+
def make_declarative_base(self, model, metadata=None):
774776
"""Creates the declarative base."""
775-
base = declarative_base(cls=Model, name='Model',
777+
base = declarative_base(cls=model, name='Model',
776778
metadata=metadata,
777779
metaclass=_BoundDeclarativeMeta)
780+
781+
if not getattr(base, 'query_class', None):
782+
base.query_class = self.Query
783+
778784
base.query = _QueryProperty(self)
779785
return base
780786

test_sqlalchemy.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,18 +467,87 @@ def test_default_query_class(self):
467467

468468
class Parent(db.Model):
469469
id = db.Column(db.Integer, primary_key=True)
470-
children = db.relationship("Child", backref = "parents", lazy='dynamic')
470+
children = db.relationship("Child", backref = "parent", lazy='dynamic')
471+
471472
class Child(db.Model):
472473
id = db.Column(db.Integer, primary_key=True)
473474
parent_id = db.Column(db.Integer, db.ForeignKey('parent.id'))
475+
474476
p = Parent()
475477
c = Child()
476478
c.parent = p
479+
477480
self.assertEqual(type(Parent.query), sqlalchemy.BaseQuery)
478481
self.assertEqual(type(Child.query), sqlalchemy.BaseQuery)
479482
self.assertTrue(isinstance(p.children, sqlalchemy.BaseQuery))
480-
#self.assertTrue(isinstance(c.parents, sqlalchemy.BaseQuery))
483+
self.assertTrue(isinstance(db.session.query(Parent), sqlalchemy.BaseQuery))
484+
485+
486+
class CustomQueryClassTestCase(unittest.TestCase):
487+
488+
def test_custom_query_class(self):
489+
class CustomQueryClass(sqlalchemy.BaseQuery):
490+
pass
491+
492+
class MyModelClass(object):
493+
pass
494+
495+
app = flask.Flask(__name__)
496+
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
497+
app.config['TESTING'] = True
498+
db = sqlalchemy.SQLAlchemy(app, query_class=CustomQueryClass,
499+
model_class=MyModelClass)
500+
501+
class Parent(db.Model):
502+
id = db.Column(db.Integer, primary_key=True)
503+
children = db.relationship("Child", backref = "parent", lazy='dynamic')
504+
505+
class Child(db.Model):
506+
id = db.Column(db.Integer, primary_key=True)
507+
parent_id = db.Column(db.Integer, db.ForeignKey('parent.id'))
508+
509+
p = Parent()
510+
c = Child()
511+
c.parent = p
512+
513+
self.assertEqual(type(Parent.query), CustomQueryClass)
514+
self.assertEqual(type(Child.query), CustomQueryClass)
515+
self.assertTrue(isinstance(p.children, CustomQueryClass))
516+
self.assertEqual(db.Query, CustomQueryClass)
517+
self.assertEqual(db.Model.query_class, CustomQueryClass)
518+
self.assertTrue(isinstance(db.session.query(Parent), CustomQueryClass))
519+
520+
521+
def test_dont_override_model_default(self):
522+
class CustomQueryClass(sqlalchemy.BaseQuery):
523+
pass
524+
525+
app = flask.Flask(__name__)
526+
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
527+
app.config['TESTING'] = True
528+
db = sqlalchemy.SQLAlchemy(app, query_class=CustomQueryClass)
529+
530+
class SomeModel(db.Model):
531+
id = db.Column(db.Integer, primary_key=True)
532+
533+
self.assertEqual(type(SomeModel.query), sqlalchemy.BaseQuery)
534+
535+
536+
class CustomModelClassTestCase(unittest.TestCase):
537+
538+
def test_custom_query_class(self):
539+
class CustomModelClass(sqlalchemy.Model):
540+
pass
541+
542+
app = flask.Flask(__name__)
543+
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
544+
app.config['TESTING'] = True
545+
db = sqlalchemy.SQLAlchemy(app, model_class=CustomModelClass)
546+
547+
class SomeModel(db.Model):
548+
id = db.Column(db.Integer, primary_key=True)
481549

550+
self.assertTrue(isinstance(SomeModel(), CustomModelClass))
482551

483552
class SQLAlchemyIncludesTestCase(unittest.TestCase):
484553

0 commit comments

Comments
 (0)