5
5
import logging as std_logging
6
6
import operator
7
7
import unittest
8
+ from contextlib import contextmanager
8
9
from decimal import Decimal
9
10
10
11
import elasticsearch
@@ -229,6 +230,24 @@ def test_kwargs_are_passed_on(self):
229
230
self .assertEqual (backend .conn .transport .max_retries , 42 )
230
231
231
232
233
+ class ElasticSearchMockUnifiedIndex (UnifiedIndex ):
234
+
235
+ spy_args = None
236
+
237
+ def get_index (self , model_klass ):
238
+ if self .spy_args is not None :
239
+ self .spy_args .setdefault ('get_index' , []).append (model_klass )
240
+ return super (ElasticSearchMockUnifiedIndex , self ).get_index (model_klass )
241
+
242
+ @contextmanager
243
+ def spy (self ):
244
+ try :
245
+ self .spy_args = {}
246
+ yield self .spy_args
247
+ finally :
248
+ self .spy_args = None
249
+
250
+
232
251
class ElasticsearchSearchBackendTestCase (TestCase ):
233
252
def setUp (self ):
234
253
super (ElasticsearchSearchBackendTestCase , self ).setUp ()
@@ -239,7 +258,7 @@ def setUp(self):
239
258
240
259
# Stow.
241
260
self .old_ui = connections ['elasticsearch' ].get_unified_index ()
242
- self .ui = UnifiedIndex ()
261
+ self .ui = ElasticSearchMockUnifiedIndex ()
243
262
self .smmi = ElasticsearchMockSearchIndex ()
244
263
self .smmidni = ElasticsearchMockSearchIndexWithSkipDocument ()
245
264
self .smtmmi = ElasticsearchMaintainTypeMockSearchIndex ()
@@ -412,6 +431,13 @@ def test_clear(self):
412
431
self .sb .clear ([AnotherMockModel , MockModel ])
413
432
self .assertEqual (self .raw_search ('*:*' ).get ('hits' , {}).get ('total' , 0 ), 0 )
414
433
434
+ def test_results_ask_for_index_per_entry (self ):
435
+ # Test that index class is obtained per result entry, not per every entry field
436
+ self .sb .update (self .smmi , self .sample_objs )
437
+ with self .ui .spy () as spy :
438
+ self .sb .search ('*:*' , limit_to_registered_models = False )
439
+ self .assertEqual (len (spy .get ('get_index' , [])), len (self .sample_objs ))
440
+
415
441
def test_search (self ):
416
442
self .sb .update (self .smmi , self .sample_objs )
417
443
self .assertEqual (self .raw_search ('*:*' )['hits' ]['total' ], 3 )
0 commit comments