Skip to content

Commit e9d478e

Browse files
committed
Merge branch 'master' of http://github.com/cyberdelia/mongoengine into v0.4
2 parents d6cb5b9 + 9c99036 commit e9d478e

File tree

5 files changed

+78
-7
lines changed

5 files changed

+78
-7
lines changed

docs/guide/document-instances.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ you may still use :attr:`id` to access the primary key if you want::
5959
>>> bob.id == bob.email == '[email protected]'
6060
True
6161

62+
You can also access the document's "primary key" using the :attr:`pk` field; in
63+
is an alias to :attr:`id`::
64+
65+
>>> page = Page(title="Another Test Page")
66+
>>> page.save()
67+
>>> page.id == page.pk
68+
6269
.. note::
6370
If you define your own primary key field, the field implicitly becomes
6471
required, so a :class:`ValidationError` will be thrown if you don't provide

mongoengine/base.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,14 +330,17 @@ class BaseDocument(object):
330330

331331
def __init__(self, **values):
332332
self._data = {}
333+
# Assign default values to instance
334+
for attr_name in self._fields.keys():
335+
# Use default value if present
336+
value = getattr(self, attr_name, None)
337+
setattr(self, attr_name, value)
333338
# Assign initial values to instance
334-
for attr_name, attr_value in self._fields.items():
335-
if attr_name in values:
339+
for attr_name in values.keys():
340+
try:
336341
setattr(self, attr_name, values.pop(attr_name))
337-
else:
338-
# Use default value if present
339-
value = getattr(self, attr_name, None)
340-
setattr(self, attr_name, value)
342+
except AttributeError:
343+
pass
341344

342345
def validate(self):
343346
"""Ensure that all fields' values are valid and that required fields
@@ -373,6 +376,16 @@ def _get_subclasses(cls):
373376
all_subclasses.update(subclass._get_subclasses())
374377
return all_subclasses
375378

379+
@apply
380+
def pk():
381+
"""Primary key alias
382+
"""
383+
def fget(self):
384+
return getattr(self, self._meta['id_field'])
385+
def fset(self, value):
386+
return setattr(self, self._meta['id_field'], value)
387+
return property(fget, fset)
388+
376389
def __iter__(self):
377390
return iter(self._fields)
378391

mongoengine/queryset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,9 @@ def _lookup_field(cls, document, parts):
448448
for field_name in parts:
449449
if field is None:
450450
# Look up first field from the document
451+
if field_name == 'pk':
452+
# Deal with "primary key" alias
453+
field_name = document._meta['id_field']
451454
field = document._fields[field_name]
452455
else:
453456
# Look up subfield on the previous field

tests/document.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,12 +386,26 @@ class EmailUser(User):
386386

387387
user_obj = User.objects.first()
388388
self.assertEqual(user_obj.id, 'test')
389+
self.assertEqual(user_obj.pk, 'test')
389390

390391
user_son = User.objects._collection.find_one()
391392
self.assertEqual(user_son['_id'], 'test')
392393
self.assertTrue('username' not in user_son['_id'])
393394

394395
User.drop_collection()
396+
397+
user = User(pk='mongo', name='mongo user')
398+
user.save()
399+
400+
user_obj = User.objects.first()
401+
self.assertEqual(user_obj.id, 'mongo')
402+
self.assertEqual(user_obj.pk, 'mongo')
403+
404+
user_son = User.objects._collection.find_one()
405+
self.assertEqual(user_son['_id'], 'mongo')
406+
self.assertTrue('username' not in user_son['_id'])
407+
408+
User.drop_collection()
395409

396410
def test_creation(self):
397411
"""Ensure that document may be created using keyword arguments.
@@ -510,6 +524,18 @@ def test_save_custom_id(self):
510524
collection = self.db[self.Person._meta['collection']]
511525
person_obj = collection.find_one({'name': 'Test User'})
512526
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
527+
528+
def test_save_custom_pk(self):
529+
"""Ensure that a document may be saved with a custom _id using pk alias.
530+
"""
531+
# Create person object and save it to the database
532+
person = self.Person(name='Test User', age=30,
533+
pk='497ce96f395f2f052a494fd4')
534+
person.save()
535+
# Ensure that the object is in the database with the correct _id
536+
collection = self.db[self.Person._meta['collection']]
537+
person_obj = collection.find_one({'name': 'Test User'})
538+
self.assertEqual(str(person_obj['_id']), '497ce96f395f2f052a494fd4')
513539

514540
def test_save_list(self):
515541
"""Ensure that a list field may be properly saved.

tests/queryset.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1106,20 +1106,42 @@ class BlogPost(Document):
11061106
BlogPost.drop_collection()
11071107

11081108
data = {'title': 'Post 1', 'comments': [Comment(content='test')]}
1109-
BlogPost(**data).save()
1109+
post = BlogPost(**data)
1110+
post.save()
11101111

11111112
self.assertTrue('postTitle' in
11121113
BlogPost.objects(title=data['title'])._query)
11131114
self.assertFalse('title' in
11141115
BlogPost.objects(title=data['title'])._query)
11151116
self.assertEqual(len(BlogPost.objects(title=data['title'])), 1)
11161117

1118+
self.assertTrue('_id' in BlogPost.objects(pk=post.id)._query)
1119+
self.assertEqual(len(BlogPost.objects(pk=post.id)), 1)
1120+
11171121
self.assertTrue('postComments.commentContent' in
11181122
BlogPost.objects(comments__content='test')._query)
11191123
self.assertEqual(len(BlogPost.objects(comments__content='test')), 1)
11201124

11211125
BlogPost.drop_collection()
11221126

1127+
def test_query_pk_field_name(self):
1128+
"""Ensure that the correct "primary key" field name is used when querying
1129+
"""
1130+
class BlogPost(Document):
1131+
title = StringField(primary_key=True, db_field='postTitle')
1132+
1133+
BlogPost.drop_collection()
1134+
1135+
data = { 'title':'Post 1' }
1136+
post = BlogPost(**data)
1137+
post.save()
1138+
1139+
self.assertTrue('_id' in BlogPost.objects(pk=data['title'])._query)
1140+
self.assertTrue('_id' in BlogPost.objects(title=data['title'])._query)
1141+
self.assertEqual(len(BlogPost.objects(pk=data['title'])), 1)
1142+
1143+
BlogPost.drop_collection()
1144+
11231145
def test_query_value_conversion(self):
11241146
"""Ensure that query values are properly converted when necessary.
11251147
"""

0 commit comments

Comments
 (0)