Skip to content

Commit fd593cd

Browse files
craigdsacdha
authored and committed
fix: queryset slicing and reduced code duplication
Now pagination will not lazy-load all earlier pages before returning the result. Thanks to @craigds for the patch. Closes django-haystack#1269. Closes django-haystack#960.
1 parent 20cee68 commit fd593cd

File tree

4 files changed

+109
-239
lines changed

4 files changed

+109
-239
lines changed

haystack/query.py

Lines changed: 90 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -154,36 +154,6 @@ def _manual_iter(self):
154154
if not self._fill_cache(current_position, current_position + ITERATOR_LOAD_PER_QUERY):
155155
raise StopIteration
156156

157-
def _fill_cache(self, start, end, **kwargs):
158-
# Tell the query where to start from and how many we'd like.
159-
self.query._reset()
160-
self.query.set_limits(start, end)
161-
results = self.query.get_results(**kwargs)
162-
163-
if results is None or len(results) == 0:
164-
return False
165-
166-
# Setup the full cache now that we know how many results there are.
167-
# We need the ``None``s as placeholders to know what parts of the
168-
# cache we have/haven't filled.
169-
# Using ``None`` like this takes up very little memory. In testing,
170-
# an array of 100,000 ``None``s consumed less than .5 Mb, which ought
171-
# to be an acceptable loss for consistent and more efficient caching.
172-
if len(self._result_cache) == 0:
173-
self._result_cache = [None] * self.query.get_count()
174-
175-
if start is None:
176-
start = 0
177-
178-
if end is None:
179-
end = self.query.get_count()
180-
181-
to_cache = self.post_process_results(results)
182-
183-
# Assign by slice.
184-
self._result_cache[start:start + len(to_cache)] = to_cache
185-
return True
186-
187157
def post_process_results(self, results):
188158
to_cache = []
189159

@@ -198,15 +168,7 @@ def post_process_results(self, results):
198168

199169
# Load the objects for each model in turn.
200170
for model in models_pks:
201-
try:
202-
ui = connections[self.query._using].get_unified_index()
203-
index = ui.get_index(model)
204-
objects = index.read_queryset(using=self.query._using)
205-
loaded_objects[model] = objects.in_bulk(models_pks[model])
206-
except NotHandled:
207-
self.log.warning("Model '%s' not handled by the routers", model)
208-
# Revert to old behaviour
209-
loaded_objects[model] = model._default_manager.in_bulk(models_pks[model])
171+
loaded_objects[model] = self._load_model_objects(model, models_pks[model])
210172

211173
for result in results:
212174
if self._load_all:
@@ -223,12 +185,86 @@ def post_process_results(self, results):
223185
# The object was either deleted since we indexed or should
224186
# be ignored; fail silently.
225187
self._ignored_result_count += 1
188+
189+
# avoid an unfilled None at the end of the result cache
190+
self._result_cache.pop()
226191
continue
227192

228193
to_cache.append(result)
229194

230195
return to_cache
231196

197+
def _load_model_objects(self, model, pks):
198+
try:
199+
ui = connections[self.query._using].get_unified_index()
200+
index = ui.get_index(model)
201+
objects = index.read_queryset(using=self.query._using)
202+
return objects.in_bulk(pks)
203+
except NotHandled:
204+
self.log.warning("Model '%s' not handled by the routers.", model)
205+
# Revert to old behaviour
206+
return model._default_manager.in_bulk(pks)
207+
208+
def _fill_cache(self, start, end, **kwargs):
209+
# Tell the query where to start from and how many we'd like.
210+
self.query._reset()
211+
212+
if start is None:
213+
start = 0
214+
215+
query_start = start
216+
query_start += self._ignored_result_count
217+
query_end = end
218+
if query_end is not None:
219+
query_end += self._ignored_result_count
220+
221+
self.query.set_limits(query_start, query_end)
222+
results = self.query.get_results(**kwargs)
223+
224+
if results is None or len(results) == 0:
225+
# trim missing stuff from the result cache
226+
self._result_cache = self._result_cache[:start]
227+
return False
228+
229+
# Setup the full cache now that we know how many results there are.
230+
# We need the ``None``s as placeholders to know what parts of the
231+
# cache we have/haven't filled.
232+
# Using ``None`` like this takes up very little memory. In testing,
233+
# an array of 100,000 ``None``s consumed less than .5 Mb, which ought
234+
# to be an acceptable loss for consistent and more efficient caching.
235+
if len(self._result_cache) == 0:
236+
self._result_cache = [None] * self.query.get_count()
237+
238+
fill_start, fill_end = start, end
239+
if fill_end is None:
240+
fill_end = self.query.get_count()
241+
cache_start = fill_start
242+
243+
while True:
244+
to_cache = self.post_process_results(results)
245+
246+
# Assign by slice.
247+
self._result_cache[cache_start:cache_start + len(to_cache)] = to_cache
248+
249+
if None in self._result_cache[start:end]:
250+
fill_start = fill_end
251+
fill_end += ITERATOR_LOAD_PER_QUERY
252+
cache_start += len(to_cache)
253+
254+
# Tell the query where to start from and how many we'd like.
255+
self.query._reset()
256+
self.query.set_limits(fill_start, fill_end)
257+
results = self.query.get_results()
258+
259+
if results is None or len(results) == 0:
260+
# No more results. Trim missing stuff from the result cache
261+
self._result_cache = self._result_cache[:cache_start]
262+
break
263+
else:
264+
break
265+
266+
return True
267+
232268
def __getitem__(self, k):
233269
"""
234270
Retrieves an item or slice from the set of results.
@@ -665,151 +701,30 @@ def post_process_results(self, results):
665701
class RelatedSearchQuerySet(SearchQuerySet):
666702
"""
667703
A variant of the SearchQuerySet that can handle `load_all_queryset`s.
668-
669-
This is predominantly different in the `_fill_cache` method, as it is
670-
far less efficient but needs to fill the cache before it to maintain
671-
consistency.
672704
"""
673705

674706
def __init__(self, *args, **kwargs):
675707
super(RelatedSearchQuerySet, self).__init__(*args, **kwargs)
676708
self._load_all_querysets = {}
677709
self._result_cache = []
678710

679-
def _cache_is_full(self):
680-
return len(self._result_cache) >= len(self)
681-
682-
def _manual_iter(self):
683-
# If we're here, our cache isn't fully populated.
684-
# For efficiency, fill the cache as we go if we run out of results.
685-
# Also, this can't be part of the __iter__ method due to Python's rules
686-
# about generator functions.
687-
current_position = 0
688-
current_cache_max = 0
689-
690-
while True:
691-
current_cache_max = len(self._result_cache)
692-
693-
while current_position < current_cache_max:
694-
yield self._result_cache[current_position]
695-
current_position += 1
696-
697-
if self._cache_is_full():
698-
raise StopIteration
699-
700-
# We've run out of results and haven't hit our limit.
701-
# Fill more of the cache.
702-
start = current_position + self._ignored_result_count
703-
704-
if not self._fill_cache(start, start + ITERATOR_LOAD_PER_QUERY):
705-
raise StopIteration
706-
707-
def _fill_cache(self, start, end):
708-
# Tell the query where to start from and how many we'd like.
709-
self.query._reset()
710-
self.query.set_limits(start, end)
711-
results = self.query.get_results()
712-
713-
if len(results) == 0:
714-
return False
715-
716-
if start is None:
717-
start = 0
718-
719-
if end is None:
720-
end = self.query.get_count()
721-
722-
# Check if we wish to load all objects.
723-
if self._load_all:
724-
models_pks = {}
725-
loaded_objects = {}
726-
727-
# Remember the search position for each result so we don't have to resort later.
728-
for result in results:
729-
models_pks.setdefault(result.model, []).append(result.pk)
730-
731-
# Load the objects for each model in turn.
732-
for model in models_pks:
733-
if model in self._load_all_querysets:
734-
# Use the overriding queryset.
735-
loaded_objects[model] = self._load_all_querysets[model].in_bulk(models_pks[model])
736-
else:
737-
# Check the SearchIndex for the model for an override.
738-
try:
739-
index = connections[self.query._using].get_unified_index().get_index(model)
740-
qs = index.load_all_queryset()
741-
loaded_objects[model] = qs.in_bulk(models_pks[model])
742-
except NotHandled:
743-
# The model returned doesn't seem to be handled by the
744-
# routers. We should silently fail and populate
745-
# nothing for those objects.
746-
loaded_objects[model] = []
747-
748-
if len(results) + len(self._result_cache) < len(self) and len(results) < ITERATOR_LOAD_PER_QUERY:
749-
self._ignored_result_count += ITERATOR_LOAD_PER_QUERY - len(results)
750-
751-
for result in results:
752-
if self._load_all:
753-
# We have to deal with integer keys being cast from strings; if this
754-
# fails we've got a character pk.
755-
try:
756-
result.pk = int(result.pk)
757-
except ValueError:
758-
pass
759-
try:
760-
result._object = loaded_objects[result.model][result.pk]
761-
except (KeyError, IndexError):
762-
# The object was either deleted since we indexed or should
763-
# be ignored; fail silently.
764-
self._ignored_result_count += 1
765-
continue
766-
767-
self._result_cache.append(result)
768-
769-
return True
770-
771-
def __getitem__(self, k):
772-
"""
773-
Retrieves an item or slice from the set of results.
774-
"""
775-
if not isinstance(k, (slice, six.integer_types)):
776-
raise TypeError
777-
778-
assert ((not isinstance(k, slice) and (k >= 0))
779-
or (isinstance(k, slice) and (k.start is None or k.start >= 0)
780-
and (k.stop is None or k.stop >= 0))), \
781-
"Negative indexing is not supported."
782-
783-
# Remember if it's a slice or not. We're going to treat everything as
784-
# a slice to simply the logic and will `.pop()` at the end as needed.
785-
if isinstance(k, slice):
786-
is_slice = True
787-
start = k.start
788-
789-
if k.stop is not None:
790-
bound = int(k.stop)
791-
else:
792-
bound = None
711+
def _load_model_objects(self, model, pks):
712+
if model in self._load_all_querysets:
713+
# Use the overriding queryset.
714+
return self._load_all_querysets[model].in_bulk(pks)
793715
else:
794-
is_slice = False
795-
start = k
796-
bound = k + 1
716+
# Check the SearchIndex for the model for an override.
797717

798-
# We need check to see if we need to populate more of the cache.
799-
if len(self._result_cache) <= 0 or not self._cache_is_full():
800718
try:
801-
while len(self._result_cache) < bound and not self._cache_is_full():
802-
current_max = len(self._result_cache) + self._ignored_result_count
803-
self._fill_cache(current_max, current_max + ITERATOR_LOAD_PER_QUERY)
804-
except StopIteration:
805-
# There's nothing left, even though the bound is higher.
806-
pass
807-
808-
# Cache should be full enough for our needs.
809-
if is_slice:
810-
return self._result_cache[start:bound]
811-
else:
812-
return self._result_cache[start]
719+
ui = connections[self.query._using].get_unified_index()
720+
index = ui.get_index(model)
721+
qs = index.load_all_queryset()
722+
return qs.in_bulk(pks)
723+
except NotHandled:
724+
# The model returned doesn't seem to be handled by the
725+
# routers. We should silently fail and populate
726+
# nothing for those objects.
727+
return {}
813728

814729
def load_all_queryset(self, model, queryset):
815730
"""
@@ -824,11 +739,6 @@ def load_all_queryset(self, model, queryset):
824739
return clone
825740

826741
def _clone(self, klass=None):
827-
if klass is None:
828-
klass = self.__class__
829-
830-
query = self.query._clone()
831-
clone = klass(query=query)
832-
clone._load_all = self._load_all
742+
clone = super(RelatedSearchQuerySet, self)._clone(klass=klass)
833743
clone._load_all_querysets = self._load_all_querysets
834744
return clone

test_haystack/elasticsearch_tests/test_elasticsearch_backend.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -921,26 +921,26 @@ def test_related_iter(self):
921921
sqs = self.rsqs.all()
922922
results = set([int(result.pk) for result in sqs])
923923
self.assertEqual(results, set([2, 7, 12, 17, 1, 6, 11, 16, 23, 5, 10, 15, 22, 4, 9, 14, 19, 21, 3, 8, 13, 18, 20]))
924-
self.assertEqual(len(connections['elasticsearch'].queries), 4)
924+
self.assertEqual(len(connections['elasticsearch'].queries), 3)
925925

926926
def test_related_slice(self):
927927
reset_search_queries()
928928
self.assertEqual(len(connections['elasticsearch'].queries), 0)
929929
results = self.rsqs.all().order_by('pub_date')
930930
self.assertEqual([int(result.pk) for result in results[1:11]], [3, 2, 4, 5, 6, 7, 8, 9, 10, 11])
931-
self.assertEqual(len(connections['elasticsearch'].queries), 3)
931+
self.assertEqual(len(connections['elasticsearch'].queries), 1)
932932

933933
reset_search_queries()
934934
self.assertEqual(len(connections['elasticsearch'].queries), 0)
935935
results = self.rsqs.all().order_by('pub_date')
936936
self.assertEqual(int(results[21].pk), 22)
937-
self.assertEqual(len(connections['elasticsearch'].queries), 4)
937+
self.assertEqual(len(connections['elasticsearch'].queries), 1)
938938

939939
reset_search_queries()
940940
self.assertEqual(len(connections['elasticsearch'].queries), 0)
941941
results = self.rsqs.all().order_by('pub_date')
942942
self.assertEqual(set([int(result.pk) for result in results[20:30]]), set([21, 22, 23]))
943-
self.assertEqual(len(connections['elasticsearch'].queries), 4)
943+
self.assertEqual(len(connections['elasticsearch'].queries), 1)
944944

945945
def test_related_manual_iter(self):
946946
results = self.rsqs.all()
@@ -949,7 +949,7 @@ def test_related_manual_iter(self):
949949
self.assertEqual(len(connections['elasticsearch'].queries), 0)
950950
results = sorted([int(result.pk) for result in results._manual_iter()])
951951
self.assertEqual(results, list(range(1, 24)))
952-
self.assertEqual(len(connections['elasticsearch'].queries), 4)
952+
self.assertEqual(len(connections['elasticsearch'].queries), 3)
953953

954954
def test_related_fill_cache(self):
955955
reset_search_queries()
@@ -971,7 +971,7 @@ def test_related_cache_is_full(self):
971971
results = self.rsqs.all()
972972
fire_the_iterator_and_fill_cache = [result for result in results]
973973
self.assertEqual(results._cache_is_full(), True)
974-
self.assertEqual(len(connections['elasticsearch'].queries), 5)
974+
self.assertEqual(len(connections['elasticsearch'].queries), 3)
975975

976976
def test_quotes_regression(self):
977977
sqs = self.sqs.auto_query(u"44°48'40''N 20°28'32''E")

0 commit comments

Comments
 (0)