Skip to content

Commit c98b33a

Browse files
authored
Merge pull request #69 from zrlay/master
features: ConnectionField customization, LazyReferences, better field descriptions
2 parents a1de06c + 349d5de commit c98b33a

File tree

12 files changed

+444
-139
lines changed

12 files changed

+444
-139
lines changed

graphene_mongo/converter.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@
1616
import mongoengine
1717

1818
from .advanced_types import PointFieldType, MultiPolygonFieldType
19-
from .fields import MongoengineConnectionField
20-
from .utils import import_single_dispatch
19+
from .utils import import_single_dispatch, get_field_description
2120

2221
singledispatch = import_single_dispatch()
2322

2423

24+
class MongoEngineConversionError(Exception):
25+
pass
26+
27+
2528
@singledispatch
2629
def convert_mongoengine_field(field, registry=None):
27-
raise Exception(
30+
raise MongoEngineConversionError(
2831
"Don't know how to convert the MongoEngine field %s (%s)" %
2932
(field, field.__class__))
3033

@@ -33,36 +36,36 @@ def convert_mongoengine_field(field, registry=None):
3336
@convert_mongoengine_field.register(mongoengine.StringField)
3437
@convert_mongoengine_field.register(mongoengine.URLField)
3538
def convert_field_to_string(field, registry=None):
36-
return String(description=field.db_field, required=field.required)
39+
return String(description=get_field_description(field, registry), required=field.required)
3740

3841

3942
@convert_mongoengine_field.register(mongoengine.UUIDField)
4043
@convert_mongoengine_field.register(mongoengine.ObjectIdField)
4144
def convert_field_to_id(field, registry=None):
42-
return ID(description=field.db_field, required=field.required)
45+
return ID(description=get_field_description(field, registry), required=field.required)
4346

4447

4548
@convert_mongoengine_field.register(mongoengine.IntField)
4649
@convert_mongoengine_field.register(mongoengine.LongField)
4750
def convert_field_to_int(field, registry=None):
48-
return Int(description=field.db_field, required=field.required)
51+
return Int(description=get_field_description(field, registry), required=field.required)
4952

5053

5154
@convert_mongoengine_field.register(mongoengine.BooleanField)
5255
def convert_field_to_boolean(field, registry=None):
53-
return Boolean(description=field.db_field, required=field.required)
56+
return Boolean(description=get_field_description(field, registry), required=field.required)
5457

5558

5659
@convert_mongoengine_field.register(mongoengine.DecimalField)
5760
@convert_mongoengine_field.register(mongoengine.FloatField)
5861
def convert_field_to_float(field, registry=None):
59-
return Float(description=field.db_field, required=field.required)
62+
return Float(description=get_field_description(field, registry), required=field.required)
6063

6164

6265
@convert_mongoengine_field.register(mongoengine.DictField)
6366
@convert_mongoengine_field.register(mongoengine.MapField)
6467
def convert_dict_to_jsonstring(field, registry=None):
65-
return JSONString(description=field.db_field, required=field.required)
68+
return JSONString(description=get_field_description(field, registry), required=field.required)
6669

6770

6871
@convert_mongoengine_field.register(mongoengine.PointField)
@@ -77,7 +80,7 @@ def convert_multipolygon_to_field(field, register=None):
7780

7881
@convert_mongoengine_field.register(mongoengine.DateTimeField)
7982
def convert_field_to_datetime(field, registry=None):
80-
return DateTime(description=field.db_field, required=field.required)
83+
return DateTime(description=get_field_description(field, registry), required=field.required)
8184

8285

8386
@convert_mongoengine_field.register(mongoengine.ListField)
@@ -91,15 +94,15 @@ def convert_field_to_list(field, registry=None):
9194
base_type = base_type._type
9295

9396
if is_node(base_type):
94-
return MongoengineConnectionField(base_type)
97+
return base_type._meta.connection_field_class(base_type)
9598

9699
# Non-relationship field
97100
relations = (mongoengine.ReferenceField, mongoengine.EmbeddedDocumentField)
98101
if not isinstance(base_type, (List, NonNull)) \
99102
and not isinstance(field.field, relations):
100103
base_type = type(base_type)
101104

102-
return List(base_type, description=field.db_field, required=field.required)
105+
return List(base_type, description=get_field_description(field, registry), required=field.required)
103106

104107

105108
@convert_mongoengine_field.register(mongoengine.EmbeddedDocumentField)
@@ -111,6 +114,23 @@ def dynamic_type():
111114
_type = registry.get_type_for_model(model)
112115
if not _type:
113116
return None
114-
return Field(_type)
117+
return Field(_type, description=get_field_description(field, registry))
118+
119+
return Dynamic(dynamic_type)
120+
121+
122+
@convert_mongoengine_field.register(mongoengine.LazyReferenceField)
123+
def convert_lazy_field_to_dynamic(field, registry=None):
124+
model = field.document_type
125+
126+
def lazy_resolver(root, *args, **kwargs):
127+
if getattr(root, field.name or field.db_name):
128+
return getattr(root, field.name or field.db_name).fetch()
129+
130+
def dynamic_type():
131+
_type = registry.get_type_for_model(model)
132+
if not _type:
133+
return None
134+
return Field(_type, resolver=lazy_resolver, description=get_field_description(field, registry))
115135

116136
return Dynamic(dynamic_type)

graphene_mongo/fields.py

Lines changed: 87 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
11
from __future__ import absolute_import
22

3-
import mongoengine
43
from collections import OrderedDict
54
from functools import partial, reduce
65

6+
import mongoengine
7+
from graphene import PageInfo
78
from graphene.relay import ConnectionField
8-
from graphene.relay.connection import PageInfo
9-
from graphql_relay.connection.arrayconnection import connection_from_list_slice
10-
from graphql_relay.node.node import from_global_id
119
from graphene.types.argument import to_arguments
1210
from graphene.types.dynamic import Dynamic
13-
from graphene.types.structures import Structure
11+
from graphene.types.structures import Structure, List
12+
from graphql_relay.connection.arrayconnection import connection_from_list_slice
1413

1514
from .advanced_types import PointFieldType, MultiPolygonFieldType
16-
from .utils import get_model_reference_fields
15+
from .converter import convert_mongoengine_field, MongoEngineConversionError
16+
from .registry import get_global_registry
17+
from .utils import get_model_reference_fields, get_node_from_global_id
1718

1819

1920
class MongoengineConnectionField(ConnectionField):
2021

2122
def __init__(self, type, *args, **kwargs):
23+
get_queryset = kwargs.pop('get_queryset', None)
24+
if get_queryset:
25+
assert callable(get_queryset), "Attribute `get_queryset` on {} must be callable.".format(self)
26+
self._get_queryset = get_queryset
2227
super(MongoengineConnectionField, self).__init__(
2328
type,
2429
*args,
@@ -43,6 +48,10 @@ def node_type(self):
4348
def model(self):
4449
return self.node_type._meta.model
4550

51+
@property
52+
def registry(self):
53+
return getattr(self.node_type._meta, 'registry', get_global_registry())
54+
4655
@property
4756
def args(self):
4857
return to_arguments(
@@ -55,12 +64,19 @@ def args(self, args):
5564
self._base_args = args
5665

5766
def _field_args(self, items):
58-
def is_filterable(v):
59-
if isinstance(v, (ConnectionField, Dynamic)):
67+
def is_filterable(k):
68+
if not hasattr(self.model, k):
69+
return False
70+
if isinstance(getattr(self.model, k), property):
6071
return False
61-
# FIXME: Skip PointTypeField at this moment.
62-
if not isinstance(v.type, Structure) \
63-
and isinstance(v.type(), (PointFieldType, MultiPolygonFieldType)):
72+
try:
73+
converted = convert_mongoengine_field(getattr(self.model, k), self.registry)
74+
except MongoEngineConversionError:
75+
return False
76+
if isinstance(converted, (ConnectionField, Dynamic, List)):
77+
return False
78+
if callable(getattr(converted, 'type', None)) and isinstance(converted.type(),
79+
(PointFieldType, MultiPolygonFieldType)):
6480
return False
6581
return True
6682

@@ -69,7 +85,7 @@ def get_type(v):
6985
return v.type.of_type()
7086
return v.type()
7187

72-
return {k: get_type(v) for k, v in items if is_filterable(v)}
88+
return {k: get_type(v) for k, v in items if is_filterable(k)}
7389

7490
@property
7591
def field_args(self):
@@ -78,102 +94,82 @@ def field_args(self):
7894
@property
7995
def reference_args(self):
8096
def get_reference_field(r, kv):
81-
if callable(getattr(kv[1], 'get_type', None)):
82-
node = kv[1].get_type()._type._meta
83-
if not issubclass(node.model, mongoengine.EmbeddedDocument):
84-
r.update({kv[0]: node.fields['id']._type.of_type()})
97+
field = kv[1]
98+
mongo_field = getattr(self.model, kv[0], None)
99+
if isinstance(mongo_field, (mongoengine.LazyReferenceField, mongoengine.ReferenceField)):
100+
field = convert_mongoengine_field(mongo_field, self.registry)
101+
if callable(getattr(field, 'get_type', None)):
102+
_type = field.get_type()
103+
if _type:
104+
node = _type._type._meta
105+
if 'id' in node.fields and not issubclass(node.model, mongoengine.EmbeddedDocument):
106+
r.update({kv[0]: node.fields['id']._type.of_type()})
85107
return r
108+
86109
return reduce(get_reference_field, self.fields.items(), {})
87110

88111
@property
89112
def fields(self):
90113
return self._type._meta.fields
91114

92-
@classmethod
93-
def get_query(cls, model, info, **args):
115+
def get_queryset(self, model, info, **args):
94116

95-
if not callable(getattr(model, 'objects', None)):
96-
return [], 0
97-
98-
objs = model.objects()
99117
if args:
100-
reference_fields = get_model_reference_fields(model)
101-
reference_args = {}
118+
reference_fields = get_model_reference_fields(self.model)
119+
hydrated_references = {}
102120
for arg_name, arg in args.copy().items():
103121
if arg_name in reference_fields:
104-
reference_model = model._fields[arg_name]
105-
pk = from_global_id(args.pop(arg_name))[-1]
106-
reference_obj = reference_model.document_type_obj.objects(pk=pk).get()
107-
reference_args[arg_name] = reference_obj
108-
109-
args.update(reference_args)
110-
first = args.pop('first', None)
111-
last = args.pop('last', None)
112-
id = args.pop('id', None)
113-
before = args.pop('before', None)
114-
after = args.pop('after', None)
115-
116-
if id is not None:
117-
# https://github.com/graphql-python/graphene/issues/124
118-
args['pk'] = from_global_id(id)[-1]
119-
120-
objs = objs.filter(**args)
121-
122-
# https://github.com/graphql-python/graphene-mongo/issues/21
123-
if after is not None:
124-
_after = int(from_global_id(after)[-1])
125-
objs = objs[_after:]
126-
127-
if before is not None:
128-
_before = int(from_global_id(before)[-1])
129-
objs = objs[:_before]
130-
122+
reference_obj = get_node_from_global_id(reference_fields[arg_name], info, args.pop(arg_name))
123+
hydrated_references[arg_name] = reference_obj
124+
args.update(hydrated_references)
125+
if self._get_queryset:
126+
queryset_or_filters = self._get_queryset(model, info, **args)
127+
if isinstance(queryset_or_filters, mongoengine.QuerySet):
128+
return queryset_or_filters
129+
else:
130+
args.update(queryset_or_filters)
131+
return model.objects(**args)
132+
133+
def default_resolver(self, _root, info, **args):
134+
args = args or {}
135+
136+
connection_args = {
137+
'first': args.pop('first', None),
138+
'last': args.pop('last', None),
139+
'before': args.pop('before', None),
140+
'after': args.pop('after', None)
141+
}
142+
143+
_id = args.pop('id', None)
144+
145+
if _id is not None:
146+
objs = [get_node_from_global_id(self.node_type, info, _id)]
147+
list_length = 1
148+
elif callable(getattr(self.model, 'objects', None)):
149+
objs = self.get_queryset(self.model, info, **args)
131150
list_length = objs.count()
132-
133-
if first is not None:
134-
objs = objs[:first]
135-
if last is not None:
136-
# https://github.com/graphql-python/graphene-mongo/issues/20
137-
objs = objs[max(0, list_length - last):]
138151
else:
139-
list_length = objs.count()
140-
141-
return objs, list_length
142-
143-
# noqa
144-
@classmethod
145-
def merge_querysets(cls, default_queryset, queryset):
146-
return queryset & default_queryset
147-
148-
"""
149-
Notes: Not sure how does this work :(
150-
"""
151-
@classmethod
152-
def connection_resolver(cls, resolver, connection, model, root, info, **args):
153-
iterable = resolver(root, info, **args)
154-
155-
if not iterable:
156-
iterable, _len = cls.get_query(model, info, **args)
157-
158-
if root:
159-
# If we have a root, we must be at least 1 layer in, right?
160-
_len = 0
161-
else:
162-
_len = len(iterable)
152+
objs = []
153+
list_length = 0
163154

164155
connection = connection_from_list_slice(
165-
iterable,
166-
args,
167-
slice_start=0,
168-
list_length=_len,
169-
list_slice_length=_len,
170-
connection_type=connection,
156+
list_slice=objs,
157+
args=connection_args,
158+
list_length=list_length,
159+
connection_type=self.type,
160+
edge_type=self.type.Edge,
171161
pageinfo_type=PageInfo,
172-
edge_type=connection.Edge,
173162
)
174-
connection.iterable = iterable
175-
connection.length = _len
163+
connection.iterable = objs
176164
return connection
177165

166+
def chained_resolver(self, resolver, root, info, **args):
167+
resolved = resolver(root, info, **args)
168+
if resolved is not None:
169+
return resolved
170+
return self.default_resolver(root, info, **args)
171+
178172
def get_resolver(self, parent_resolver):
179-
return partial(self.connection_resolver, parent_resolver, self.type, self.model)
173+
super_resolver = self.resolver or parent_resolver
174+
resolver = partial(self.chained_resolver, super_resolver)
175+
return partial(self.connection_resolver, resolver, self.type)

0 commit comments

Comments
 (0)