Skip to content

Commit 2a450d8

Browse files
Terracdha
authored andcommitted
new: Support ManyToManyFields in model_attr lookups
Thanks to @Terr for the patch
1 parent cd7380e commit 2a450d8

File tree

5 files changed

+195
-28
lines changed

5 files changed

+195
-28
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,4 @@ Thanks to
111111
* Claude Paroz (@claudep) for Django 1.9 support
112112
* Chris Brooke (@chrisbrooke) for patching around a backwards-incompatible change in ElasticSearch 2
113113
* Gilad Beeri (@giladbeeri) for adding retries when updating a backend
114+
* Arjen Verstoep (@terr) for a patch that allows attribute lookups through Django ManyToManyField relationships

haystack/fields.py

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -80,40 +80,80 @@ def prepare(self, obj):
8080
if self.use_template:
8181
return self.prepare_template(obj)
8282
elif self.model_attr is not None:
83-
# Check for `__` in the field for looking through the relation.
84-
attrs = self.model_attr.split('__')
85-
current_object = obj
86-
87-
for attr in attrs:
88-
if not hasattr(current_object, attr):
89-
raise SearchFieldError("The model '%s' does not have a model_attr '%s'." % (repr(current_object), attr))
90-
91-
current_object = getattr(current_object, attr, None)
92-
93-
if current_object is None:
94-
if self.has_default():
95-
current_object = self._default
96-
# Fall out of the loop, given any further attempts at
97-
# accesses will fail miserably.
98-
break
99-
elif self.null:
100-
current_object = None
101-
# Fall out of the loop, given any further attempts at
102-
# accesses will fail miserably.
103-
break
104-
else:
105-
raise SearchFieldError("The model '%s' combined with model_attr '%s' returned None, but doesn't allow a default or null value." % (repr(obj), self.model_attr))
83+
attrs = self.split_model_attr_lookups()
84+
current_objects = [obj]
10685

107-
if callable(current_object):
108-
return current_object()
86+
values = self.resolve_attributes_lookup(current_objects, attrs)
10987

110-
return current_object
88+
if len(values) == 1:
89+
return values[0]
90+
else:
91+
return values
11192

11293
if self.has_default():
11394
return self.default
11495
else:
11596
return None
11697

98+
def resolve_attributes_lookup(self, current_objects, attributes):
99+
"""
100+
Recursive method that looks, for one or more objects, for an attribute that can be multiple
101+
objects (relations) deep.
102+
"""
103+
values = []
104+
105+
for current_object in current_objects:
106+
if not hasattr(current_object, attributes[0]):
107+
raise SearchFieldError(
108+
"The model '%s' does not have a model_attr '%s'." % (repr(current_object), attributes[0])
109+
)
110+
111+
if len(attributes) > 1:
112+
current_objects_in_attr = self.get_iterable_objects(getattr(current_object, attributes[0]))
113+
114+
return self.resolve_attributes_lookup(current_objects_in_attr, attributes[1:])
115+
116+
current_object = getattr(current_object, attributes[0])
117+
118+
if current_object is None:
119+
if self.has_default():
120+
current_object = self._default
121+
elif self.null:
122+
current_object = None
123+
else:
124+
raise SearchFieldError(
125+
"The model '%s' combined with model_attr '%s' returned None, but doesn't allow "
126+
"a default or null value." % (repr(current_object), self.model_attr)
127+
)
128+
129+
if callable(current_object):
130+
values.append(current_object())
131+
else:
132+
values.append(current_object)
133+
134+
return values
135+
136+
def split_model_attr_lookups(self):
137+
"""Returns list of nested attributes for looking through the relation."""
138+
return self.model_attr.split('__')
139+
140+
@classmethod
141+
def get_iterable_objects(cls, current_objects):
142+
"""
143+
Returns iterable of objects that contain data. For example, resolves Django ManyToMany relationship
144+
so the attributes of the related models can then be accessed.
145+
"""
146+
if current_objects is None:
147+
return []
148+
149+
if hasattr(current_objects, 'all'):
150+
# i.e, Django ManyToMany relationships
151+
current_objects = current_objects.all()
152+
elif not hasattr(current_objects, '__iter__'):
153+
current_objects = [current_objects]
154+
155+
return current_objects
156+
117157
def prepare_template(self, obj):
118158
"""
119159
Flattens an object for indexing.

test_haystack/core/models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,22 @@ class ScoreMockModel(models.Model):
7979

8080
def __unicode__(self):
8181
return self.score
82+
83+
84+
class ManyToManyLeftSideModel(models.Model):
85+
related_models = models.ManyToManyField('ManyToManyRightSideModel')
86+
87+
88+
class ManyToManyRightSideModel(models.Model):
89+
name = models.CharField(max_length=32, default='Default name')
90+
91+
def __unicode__(self):
92+
return self.name
93+
94+
95+
class OneToManyLeftSideModel(models.Model):
96+
pass
97+
98+
99+
class OneToManyRightSideModel(models.Model):
100+
left_side = models.ForeignKey(OneToManyLeftSideModel, related_name='right_side')

test_haystack/test_fields.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,87 @@
55
import datetime
66
from decimal import Decimal
77

8+
from mock import Mock
9+
810
from django.template import TemplateDoesNotExist
911
from django.test import TestCase
10-
from test_haystack.core.models import MockModel, MockTag
12+
from test_haystack.core.models import MockModel, MockTag, ManyToManyLeftSideModel, ManyToManyRightSideModel, \
13+
OneToManyLeftSideModel, OneToManyRightSideModel
1114

1215
from haystack.fields import *
1316

1417

18+
class SearchFieldTestCase(TestCase):
19+
def test_get_iterable_objects_with_none(self):
20+
self.assertEqual([], SearchField.get_iterable_objects(None))
21+
22+
def test_get_iterable_objects_with_single_non_iterable_object(self):
23+
obj = object()
24+
expected = [obj]
25+
26+
self.assertEqual(expected, SearchField.get_iterable_objects(obj))
27+
28+
def test_get_iterable_objects_with_list_stays_the_same(self):
29+
objects = [object(), object()]
30+
31+
self.assertIs(objects, SearchField.get_iterable_objects(objects))
32+
33+
def test_get_iterable_objects_with_django_manytomany_rel(self):
34+
left_model = ManyToManyLeftSideModel.objects.create()
35+
right_model_1 = ManyToManyRightSideModel.objects.create(name='Right side 1')
36+
right_model_2 = ManyToManyRightSideModel.objects.create()
37+
left_model.related_models.add(right_model_1)
38+
left_model.related_models.add(right_model_2)
39+
40+
result = SearchField.get_iterable_objects(left_model.related_models)
41+
42+
self.assertTrue(right_model_1 in result)
43+
self.assertTrue(right_model_2 in result)
44+
45+
def test_get_iterable_objects_with_django_onetomany_rel(self):
46+
left_model = OneToManyLeftSideModel.objects.create()
47+
right_model_1 = OneToManyRightSideModel.objects.create(left_side=left_model)
48+
right_model_2 = OneToManyRightSideModel.objects.create(left_side=left_model)
49+
50+
result = SearchField.get_iterable_objects(left_model.right_side)
51+
52+
self.assertTrue(right_model_1 in result)
53+
self.assertTrue(right_model_2 in result)
54+
55+
def test_resolve_attributes_lookup_with_field_that_points_to_none(self):
56+
related = Mock(spec=['none_field'], none_field=None)
57+
obj = Mock(spec=['related'], related=[related])
58+
59+
field = SearchField(null=False)
60+
61+
self.assertRaises(SearchFieldError, field.resolve_attributes_lookup, [obj], ['related', 'none_field'])
62+
63+
def test_resolve_attributes_lookup_with_field_that_points_to_none_but_is_allowed_to_be_null(self):
64+
related = Mock(spec=['none_field'], none_field=None)
65+
obj = Mock(spec=['related'], related=[related])
66+
67+
field = SearchField(null=True)
68+
69+
self.assertEqual([None], field.resolve_attributes_lookup([obj], ['related', 'none_field']))
70+
71+
def test_resolve_attributes_lookup_with_field_that_points_to_none_but_has_default(self):
72+
related = Mock(spec=['none_field'], none_field=None)
73+
obj = Mock(spec=['related'], related=[related])
74+
75+
field = SearchField(default='Default value')
76+
77+
self.assertEqual(['Default value'], field.resolve_attributes_lookup([obj], ['related', 'none_field']))
78+
79+
def test_resolve_attributes_lookup_with_deep_relationship(self):
80+
related_lvl_2 = Mock(spec=['value'], value=1)
81+
related = Mock(spec=['related'], related=[related_lvl_2, related_lvl_2])
82+
obj = Mock(spec=['related'], related=[related])
83+
84+
field = SearchField()
85+
86+
self.assertEqual([1, 1], field.resolve_attributes_lookup([obj], ['related', 'related', 'value']))
87+
88+
1589
class CharFieldTestCase(TestCase):
1690
def test_init(self):
1791
try:

test_haystack/test_indexes.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
from django.test import TestCase
1010
from django.utils.six.moves import queue
11-
from test_haystack.core.models import AFifthMockModel, AThirdMockModel, MockModel
11+
from test_haystack.core.models import (AFifthMockModel, AThirdMockModel, MockModel, ManyToManyLeftSideModel,
12+
ManyToManyRightSideModel)
1213

1314
from haystack import connection_router, connections, indexes
1415
from haystack.exceptions import SearchFieldError
@@ -134,6 +135,14 @@ class MROFieldsSearchChild(MROFieldsSearchIndexA, MROFieldsSearchIndexB):
134135
pass
135136

136137

138+
class ModelWithManyToManyFieldAndAttributeLookupSearchIndex(indexes.SearchIndex, indexes.Indexable):
139+
text = indexes.CharField(document=True)
140+
related_models = indexes.MultiValueField(model_attr='related_models__name')
141+
142+
def get_model(self):
143+
return ManyToManyLeftSideModel
144+
145+
137146
class SearchIndexTestCase(TestCase):
138147
fixtures = ['base_data']
139148

@@ -650,3 +659,27 @@ def test_float_integer_fields(self):
650659
self.assertTrue(isinstance(self.yabmsi.fields['average_delay'], indexes.FloatField))
651660
self.assertEqual(self.yabmsi.fields['average_delay'].null, False)
652661
self.assertEqual(self.yabmsi.fields['average_delay'].index_fieldname, 'average_delay')
662+
663+
664+
class ModelWithManyToManyFieldAndAttributeLookupSearchIndexTestCase(TestCase):
665+
def test_full_prepare(self):
666+
index = ModelWithManyToManyFieldAndAttributeLookupSearchIndex()
667+
668+
left_model = ManyToManyLeftSideModel.objects.create()
669+
right_model_1 = ManyToManyRightSideModel.objects.create(name='Right side 1')
670+
right_model_2 = ManyToManyRightSideModel.objects.create()
671+
left_model.related_models.add(right_model_1)
672+
left_model.related_models.add(right_model_2)
673+
674+
result = index.full_prepare(left_model)
675+
676+
self.assertDictEqual(
677+
result,
678+
{
679+
'django_ct': 'core.manytomanyleftsidemodel',
680+
'django_id': '1',
681+
'text': None,
682+
'id': 'core.manytomanyleftsidemodel.1',
683+
'related_models': ['Right side 1', 'Default name'],
684+
}
685+
)

0 commit comments

Comments
 (0)