@@ -68,34 +68,34 @@ def _make_table(*args, **kwargs):
68
68
return _make_table
69
69
70
70
71
- def _set_default_query_class (d ):
71
+ def _set_default_query_class (d , cls ):
72
72
if 'query_class' not in d :
73
- d ['query_class' ] = BaseQuery
73
+ d ['query_class' ] = cls
74
74
75
75
76
- def _wrap_with_default_query_class (fn ):
76
+ def _wrap_with_default_query_class (fn , cls ):
77
77
@functools .wraps (fn )
78
78
def newfn (* args , ** kwargs ):
79
- _set_default_query_class (kwargs )
79
+ _set_default_query_class (kwargs , cls )
80
80
if "backref" in kwargs :
81
81
backref = kwargs ['backref' ]
82
82
if isinstance (backref , string_types ):
83
83
backref = (backref , {})
84
- _set_default_query_class (backref [1 ])
84
+ _set_default_query_class (backref [1 ], cls )
85
85
return fn (* args , ** kwargs )
86
86
return newfn
87
87
88
88
89
- def _include_sqlalchemy (obj ):
89
+ def _include_sqlalchemy (obj , cls ):
90
90
for module in sqlalchemy , sqlalchemy .orm :
91
91
for key in module .__all__ :
92
92
if not hasattr (obj , key ):
93
93
setattr (obj , key , getattr (module , key ))
94
94
# Note: obj.Table does not attempt to be a SQLAlchemy Table class.
95
95
obj .Table = _make_table (obj )
96
- obj .relationship = _wrap_with_default_query_class (obj .relationship )
97
- obj .relation = _wrap_with_default_query_class (obj .relation )
98
- obj .dynamic_loader = _wrap_with_default_query_class (obj .dynamic_loader )
96
+ obj .relationship = _wrap_with_default_query_class (obj .relationship , cls )
97
+ obj .relation = _wrap_with_default_query_class (obj .relation , cls )
98
+ obj .dynamic_loader = _wrap_with_default_query_class (obj .dynamic_loader , cls )
99
99
obj .event = event
100
100
101
101
@@ -730,19 +730,21 @@ class User(db.Model):
730
730
naming conventions among other, non-trivial things.
731
731
"""
732
732
733
- def __init__ (self , app = None , use_native_unicode = True , session_options = None , metadata = None ):
733
+ def __init__ (self , app = None , use_native_unicode = True , session_options = None ,
734
+ metadata = None , query_class = BaseQuery , model_class = Model ):
734
735
735
736
if session_options is None :
736
737
session_options = {}
737
738
738
739
session_options .setdefault ('scopefunc' , connection_stack .__ident_func__ )
740
+ session_options .setdefault ('query_cls' , query_class )
739
741
self .use_native_unicode = use_native_unicode
740
742
self .session = self .create_scoped_session (session_options )
741
- self .Model = self . make_declarative_base ( metadata )
742
- self .Query = BaseQuery
743
+ self .Query = query_class
744
+ self .Model = self . make_declarative_base ( model_class , metadata )
743
745
self ._engine_lock = Lock ()
744
746
self .app = app
745
- _include_sqlalchemy (self )
747
+ _include_sqlalchemy (self , query_class )
746
748
747
749
if app is not None :
748
750
self .init_app (app )
@@ -770,11 +772,15 @@ def create_session(self, options):
770
772
"""
771
773
return SignallingSession (self , ** options )
772
774
773
- def make_declarative_base (self , metadata = None ):
775
+ def make_declarative_base (self , model , metadata = None ):
774
776
"""Creates the declarative base."""
775
- base = declarative_base (cls = Model , name = 'Model' ,
777
+ base = declarative_base (cls = model , name = 'Model' ,
776
778
metadata = metadata ,
777
779
metaclass = _BoundDeclarativeMeta )
780
+
781
+ if not getattr (base , 'query_class' , None ):
782
+ base .query_class = self .Query
783
+
778
784
base .query = _QueryProperty (self )
779
785
return base
780
786
0 commit comments