28
28
import org .apache .lucene .store .IOContext ;
29
29
import org .apache .lucene .store .IndexInput ;
30
30
import org .apache .lucene .store .IndexOutput ;
31
+ import org .apache .lucene .store .RandomAccessInput ;
31
32
import org .apache .lucene .util .VectorUtil ;
32
33
import org .elasticsearch .core .IOUtils ;
33
34
import org .elasticsearch .core .SuppressForbidden ;
@@ -237,36 +238,60 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
237
238
private void mergeOneFieldIVF (FieldInfo fieldInfo , MergeState mergeState ) throws IOException {
238
239
final int numVectors ;
239
240
String tempRawVectorsFileName = null ;
241
+ String docsFileName = null ;
240
242
boolean success = false ;
241
243
// build a float vector values with random access. In order to do that we dump the vectors to
242
- // a temporary file
243
- // and write the docID follow by the vector
244
- try (IndexOutput out = mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "ivf_" , IOContext .DEFAULT )) {
245
- tempRawVectorsFileName = out .getName ();
246
- // TODO do this better, we shouldn't have to write to a temp file, we should be able to
247
- // to just from the merged vector values, the tricky part is the random access.
248
- numVectors = writeFloatVectorValues (fieldInfo , out , MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState ));
249
- CodecUtil .writeFooter (out );
250
- success = true ;
244
+ // a temporary file and if the segment is not dense, the docs to another file/
245
+ try (
246
+ IndexOutput vectorsOut = mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "ivfvec_" , IOContext .DEFAULT )
247
+ ) {
248
+ tempRawVectorsFileName = vectorsOut .getName ();
249
+ FloatVectorValues mergedFloatVectorValues = MergedVectorValues .mergeFloatVectorValues (fieldInfo , mergeState );
250
+ // if the segment is dense, we don't need to do anything with docIds.
251
+ boolean dense = mergedFloatVectorValues .size () == mergeState .segmentInfo .maxDoc ();
252
+ try (
253
+ IndexOutput docsOut = dense
254
+ ? null
255
+ : mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "ivfdoc_" , IOContext .DEFAULT )
256
+ ) {
257
+ if (docsOut != null ) {
258
+ docsFileName = docsOut .getName ();
259
+ }
260
+ // TODO do this better, we shouldn't have to write to a temp file, we should be able to
261
+ // to just from the merged vector values, the tricky part is the random access.
262
+ numVectors = writeFloatVectorValues (fieldInfo , docsOut , vectorsOut , mergedFloatVectorValues );
263
+ CodecUtil .writeFooter (vectorsOut );
264
+ if (docsOut != null ) {
265
+ CodecUtil .writeFooter (docsOut );
266
+ }
267
+ success = true ;
268
+ }
251
269
} finally {
252
- if (success == false && tempRawVectorsFileName != null ) {
253
- org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
270
+ if (success == false ) {
271
+ if (tempRawVectorsFileName != null ) {
272
+ org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
273
+ }
274
+ if (docsFileName != null ) {
275
+ org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , docsFileName );
276
+ }
254
277
}
255
278
}
256
- try (IndexInput in = mergeState .segmentInfo .dir .openInput (tempRawVectorsFileName , IOContext .DEFAULT )) {
257
- float [] calculatedGlobalCentroid = new float [fieldInfo .getVectorDimension ()];
258
- final FloatVectorValues floatVectorValues = getFloatVectorValues (fieldInfo , in , numVectors );
279
+ try (
280
+ IndexInput vectors = mergeState .segmentInfo .dir .openInput (tempRawVectorsFileName , IOContext .DEFAULT );
281
+ IndexInput docs = docsFileName == null ? null : mergeState .segmentInfo .dir .openInput (docsFileName , IOContext .DEFAULT )
282
+ ) {
283
+ final FloatVectorValues floatVectorValues = getFloatVectorValues (fieldInfo , docs , vectors , numVectors );
259
284
success = false ;
260
285
long centroidOffset ;
261
286
long centroidLength ;
262
287
String centroidTempName = null ;
263
288
int numCentroids ;
264
289
IndexOutput centroidTemp = null ;
265
290
CentroidAssignments centroidAssignments ;
291
+ float [] calculatedGlobalCentroid = new float [fieldInfo .getVectorDimension ()];
266
292
try {
267
293
centroidTemp = mergeState .segmentInfo .dir .createTempOutput (mergeState .segmentInfo .name , "civf_" , IOContext .DEFAULT );
268
294
centroidTempName = centroidTemp .getName ();
269
-
270
295
centroidAssignments = calculateAndWriteCentroids (
271
296
fieldInfo ,
272
297
floatVectorValues ,
@@ -318,28 +343,34 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
318
343
writeMeta (fieldInfo , centroidOffset , centroidLength , offsets , calculatedGlobalCentroid );
319
344
}
320
345
} finally {
346
+ org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , centroidTempName );
347
+ }
348
+ } finally {
349
+ if (docsFileName != null ) {
321
350
org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (
322
351
mergeState .segmentInfo .dir ,
323
352
tempRawVectorsFileName ,
324
- centroidTempName
353
+ docsFileName
325
354
);
355
+ } else {
356
+ org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
326
357
}
327
- } finally {
328
- org .apache .lucene .util .IOUtils .deleteFilesIgnoringExceptions (mergeState .segmentInfo .dir , tempRawVectorsFileName );
329
358
}
330
359
}
331
360
332
- private static FloatVectorValues getFloatVectorValues (FieldInfo fieldInfo , IndexInput randomAccessInput , int numVectors ) {
361
+ private static FloatVectorValues getFloatVectorValues (FieldInfo fieldInfo , IndexInput docs , IndexInput vectors , int numVectors )
362
+ throws IOException {
333
363
if (numVectors == 0 ) {
334
364
return FloatVectorValues .fromFloats (List .of (), fieldInfo .getVectorDimension ());
335
365
}
336
- final long length = (long ) Float .BYTES * fieldInfo .getVectorDimension () + Integer . BYTES ;
366
+ final long vectorLength = (long ) Float .BYTES * fieldInfo .getVectorDimension ();
337
367
final float [] vector = new float [fieldInfo .getVectorDimension ()];
368
+ final RandomAccessInput randomDocs = docs == null ? null : docs .randomAccessSlice (0 , docs .length ());
338
369
return new FloatVectorValues () {
339
370
@ Override
340
371
public float [] vectorValue (int ord ) throws IOException {
341
- randomAccessInput .seek (ord * length + Integer . BYTES );
342
- randomAccessInput .readFloats (vector , 0 , vector .length );
372
+ vectors .seek (ord * vectorLength );
373
+ vectors .readFloats (vector , 0 , vector .length );
343
374
return vector ;
344
375
}
345
376
@@ -360,27 +391,34 @@ public int size() {
360
391
361
392
@ Override
362
393
public int ordToDoc (int ord ) {
394
+ if (randomDocs == null ) {
395
+ return ord ;
396
+ }
363
397
try {
364
- randomAccessInput .seek (ord * length );
365
- return randomAccessInput .readInt ();
398
+ return randomDocs .readInt ((long ) ord * Integer .BYTES );
366
399
} catch (IOException e ) {
367
400
throw new UncheckedIOException (e );
368
401
}
369
402
}
370
403
};
371
404
}
372
405
373
- private static int writeFloatVectorValues (FieldInfo fieldInfo , IndexOutput out , FloatVectorValues floatVectorValues )
374
- throws IOException {
406
+ private static int writeFloatVectorValues (
407
+ FieldInfo fieldInfo ,
408
+ IndexOutput docsOut ,
409
+ IndexOutput vectorsOut ,
410
+ FloatVectorValues floatVectorValues
411
+ ) throws IOException {
375
412
int numVectors = 0 ;
376
413
final ByteBuffer buffer = ByteBuffer .allocate (fieldInfo .getVectorDimension () * Float .BYTES ).order (ByteOrder .LITTLE_ENDIAN );
377
414
final KnnVectorValues .DocIndexIterator iterator = floatVectorValues .iterator ();
378
415
for (int docV = iterator .nextDoc (); docV != NO_MORE_DOCS ; docV = iterator .nextDoc ()) {
379
416
numVectors ++;
380
- float [] vector = floatVectorValues .vectorValue (iterator .index ());
381
- out .writeInt (iterator .docID ());
382
- buffer .asFloatBuffer ().put (vector );
383
- out .writeBytes (buffer .array (), buffer .array ().length );
417
+ buffer .asFloatBuffer ().put (floatVectorValues .vectorValue (iterator .index ()));
418
+ vectorsOut .writeBytes (buffer .array (), buffer .array ().length );
419
+ if (docsOut != null ) {
420
+ docsOut .writeInt (iterator .docID ());
421
+ }
384
422
}
385
423
return numVectors ;
386
424
}
0 commit comments