1
1
from __future__ import absolute_import
2
2
3
- import mongoengine
4
3
from collections import OrderedDict
5
4
from functools import partial , reduce
6
5
6
+ import mongoengine
7
+ from graphene import PageInfo
7
8
from graphene .relay import ConnectionField
8
- from graphene .relay .connection import PageInfo
9
- from graphql_relay .connection .arrayconnection import connection_from_list_slice
10
- from graphql_relay .node .node import from_global_id
11
9
from graphene .types .argument import to_arguments
12
10
from graphene .types .dynamic import Dynamic
13
- from graphene .types .structures import Structure
11
+ from graphene .types .structures import Structure , List
12
+ from graphql_relay .connection .arrayconnection import connection_from_list_slice
14
13
15
14
from .advanced_types import PointFieldType , MultiPolygonFieldType
16
- from .utils import get_model_reference_fields
15
+ from .converter import convert_mongoengine_field , MongoEngineConversionError
16
+ from .registry import get_global_registry
17
+ from .utils import get_model_reference_fields , get_node_from_global_id
17
18
18
19
19
20
class MongoengineConnectionField (ConnectionField ):
20
21
21
22
def __init__ (self , type , * args , ** kwargs ):
23
+ get_queryset = kwargs .pop ('get_queryset' , None )
24
+ if get_queryset :
25
+ assert callable (get_queryset ), "Attribute `get_queryset` on {} must be callable." .format (self )
26
+ self ._get_queryset = get_queryset
22
27
super (MongoengineConnectionField , self ).__init__ (
23
28
type ,
24
29
* args ,
@@ -43,6 +48,10 @@ def node_type(self):
43
48
def model (self ):
44
49
return self .node_type ._meta .model
45
50
51
+ @property
52
+ def registry (self ):
53
+ return getattr (self .node_type ._meta , 'registry' , get_global_registry ())
54
+
46
55
@property
47
56
def args (self ):
48
57
return to_arguments (
@@ -55,12 +64,19 @@ def args(self, args):
55
64
self ._base_args = args
56
65
57
66
def _field_args (self , items ):
58
- def is_filterable (v ):
59
- if isinstance (v , (ConnectionField , Dynamic )):
67
+ def is_filterable (k ):
68
+ if not hasattr (self .model , k ):
69
+ return False
70
+ if isinstance (getattr (self .model , k ), property ):
60
71
return False
61
- # FIXME: Skip PointTypeField at this moment.
62
- if not isinstance (v .type , Structure ) \
63
- and isinstance (v .type (), (PointFieldType , MultiPolygonFieldType )):
72
+ try :
73
+ converted = convert_mongoengine_field (getattr (self .model , k ), self .registry )
74
+ except MongoEngineConversionError :
75
+ return False
76
+ if isinstance (converted , (ConnectionField , Dynamic , List )):
77
+ return False
78
+ if callable (getattr (converted , 'type' , None )) and isinstance (converted .type (),
79
+ (PointFieldType , MultiPolygonFieldType )):
64
80
return False
65
81
return True
66
82
@@ -69,7 +85,7 @@ def get_type(v):
69
85
return v .type .of_type ()
70
86
return v .type ()
71
87
72
- return {k : get_type (v ) for k , v in items if is_filterable (v )}
88
+ return {k : get_type (v ) for k , v in items if is_filterable (k )}
73
89
74
90
@property
75
91
def field_args (self ):
@@ -78,102 +94,82 @@ def field_args(self):
78
94
@property
79
95
def reference_args (self ):
80
96
def get_reference_field (r , kv ):
81
- if callable (getattr (kv [1 ], 'get_type' , None )):
82
- node = kv [1 ].get_type ()._type ._meta
83
- if not issubclass (node .model , mongoengine .EmbeddedDocument ):
84
- r .update ({kv [0 ]: node .fields ['id' ]._type .of_type ()})
97
+ field = kv [1 ]
98
+ mongo_field = getattr (self .model , kv [0 ], None )
99
+ if isinstance (mongo_field , (mongoengine .LazyReferenceField , mongoengine .ReferenceField )):
100
+ field = convert_mongoengine_field (mongo_field , self .registry )
101
+ if callable (getattr (field , 'get_type' , None )):
102
+ _type = field .get_type ()
103
+ if _type :
104
+ node = _type ._type ._meta
105
+ if 'id' in node .fields and not issubclass (node .model , mongoengine .EmbeddedDocument ):
106
+ r .update ({kv [0 ]: node .fields ['id' ]._type .of_type ()})
85
107
return r
108
+
86
109
return reduce (get_reference_field , self .fields .items (), {})
87
110
88
111
@property
89
112
def fields (self ):
90
113
return self ._type ._meta .fields
91
114
92
- @classmethod
93
- def get_query (cls , model , info , ** args ):
115
+ def get_queryset (self , model , info , ** args ):
94
116
95
- if not callable (getattr (model , 'objects' , None )):
96
- return [], 0
97
-
98
- objs = model .objects ()
99
117
if args :
100
- reference_fields = get_model_reference_fields (model )
101
- reference_args = {}
118
+ reference_fields = get_model_reference_fields (self . model )
119
+ hydrated_references = {}
102
120
for arg_name , arg in args .copy ().items ():
103
121
if arg_name in reference_fields :
104
- reference_model = model ._fields [arg_name ]
105
- pk = from_global_id (args .pop (arg_name ))[- 1 ]
106
- reference_obj = reference_model .document_type_obj .objects (pk = pk ).get ()
107
- reference_args [arg_name ] = reference_obj
108
-
109
- args .update (reference_args )
110
- first = args .pop ('first' , None )
111
- last = args .pop ('last' , None )
112
- id = args .pop ('id' , None )
113
- before = args .pop ('before' , None )
114
- after = args .pop ('after' , None )
115
-
116
- if id is not None :
117
- # https://github.com/graphql-python/graphene/issues/124
118
- args ['pk' ] = from_global_id (id )[- 1 ]
119
-
120
- objs = objs .filter (** args )
121
-
122
- # https://github.com/graphql-python/graphene-mongo/issues/21
123
- if after is not None :
124
- _after = int (from_global_id (after )[- 1 ])
125
- objs = objs [_after :]
126
-
127
- if before is not None :
128
- _before = int (from_global_id (before )[- 1 ])
129
- objs = objs [:_before ]
130
-
122
+ reference_obj = get_node_from_global_id (reference_fields [arg_name ], info , args .pop (arg_name ))
123
+ hydrated_references [arg_name ] = reference_obj
124
+ args .update (hydrated_references )
125
+ if self ._get_queryset :
126
+ queryset_or_filters = self ._get_queryset (model , info , ** args )
127
+ if isinstance (queryset_or_filters , mongoengine .QuerySet ):
128
+ return queryset_or_filters
129
+ else :
130
+ args .update (queryset_or_filters )
131
+ return model .objects (** args )
132
+
133
+ def default_resolver (self , _root , info , ** args ):
134
+ args = args or {}
135
+
136
+ connection_args = {
137
+ 'first' : args .pop ('first' , None ),
138
+ 'last' : args .pop ('last' , None ),
139
+ 'before' : args .pop ('before' , None ),
140
+ 'after' : args .pop ('after' , None )
141
+ }
142
+
143
+ _id = args .pop ('id' , None )
144
+
145
+ if _id is not None :
146
+ objs = [get_node_from_global_id (self .node_type , info , _id )]
147
+ list_length = 1
148
+ elif callable (getattr (self .model , 'objects' , None )):
149
+ objs = self .get_queryset (self .model , info , ** args )
131
150
list_length = objs .count ()
132
-
133
- if first is not None :
134
- objs = objs [:first ]
135
- if last is not None :
136
- # https://github.com/graphql-python/graphene-mongo/issues/20
137
- objs = objs [max (0 , list_length - last ):]
138
151
else :
139
- list_length = objs .count ()
140
-
141
- return objs , list_length
142
-
143
- # noqa
144
- @classmethod
145
- def merge_querysets (cls , default_queryset , queryset ):
146
- return queryset & default_queryset
147
-
148
- """
149
- Notes: Not sure how does this work :(
150
- """
151
- @classmethod
152
- def connection_resolver (cls , resolver , connection , model , root , info , ** args ):
153
- iterable = resolver (root , info , ** args )
154
-
155
- if not iterable :
156
- iterable , _len = cls .get_query (model , info , ** args )
157
-
158
- if root :
159
- # If we have a root, we must be at least 1 layer in, right?
160
- _len = 0
161
- else :
162
- _len = len (iterable )
152
+ objs = []
153
+ list_length = 0
163
154
164
155
connection = connection_from_list_slice (
165
- iterable ,
166
- args ,
167
- slice_start = 0 ,
168
- list_length = _len ,
169
- list_slice_length = _len ,
170
- connection_type = connection ,
156
+ list_slice = objs ,
157
+ args = connection_args ,
158
+ list_length = list_length ,
159
+ connection_type = self .type ,
160
+ edge_type = self .type .Edge ,
171
161
pageinfo_type = PageInfo ,
172
- edge_type = connection .Edge ,
173
162
)
174
- connection .iterable = iterable
175
- connection .length = _len
163
+ connection .iterable = objs
176
164
return connection
177
165
166
+ def chained_resolver (self , resolver , root , info , ** args ):
167
+ resolved = resolver (root , info , ** args )
168
+ if resolved is not None :
169
+ return resolved
170
+ return self .default_resolver (root , info , ** args )
171
+
178
172
def get_resolver (self , parent_resolver ):
179
- return partial (self .connection_resolver , parent_resolver , self .type , self .model )
173
+ super_resolver = self .resolver or parent_resolver
174
+ resolver = partial (self .chained_resolver , super_resolver )
175
+ return partial (self .connection_resolver , resolver , self .type )
0 commit comments