Skip to content

Commit 72b488c

Browse files
authored
[IVF] Improve the format of the tmp file written during merging (#129828)
This commit separe vector and docIds on the tmp file.
1 parent b1741e8 commit 72b488c

File tree

1 file changed

+68
-30
lines changed

1 file changed

+68
-30
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.lucene.store.IOContext;
2929
import org.apache.lucene.store.IndexInput;
3030
import org.apache.lucene.store.IndexOutput;
31+
import org.apache.lucene.store.RandomAccessInput;
3132
import org.apache.lucene.util.VectorUtil;
3233
import org.elasticsearch.core.IOUtils;
3334
import org.elasticsearch.core.SuppressForbidden;
@@ -237,36 +238,60 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
237238
private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
238239
final int numVectors;
239240
String tempRawVectorsFileName = null;
241+
String docsFileName = null;
240242
boolean success = false;
241243
// build a float vector values with random access. In order to do that we dump the vectors to
242-
// a temporary file
243-
// and write the docID follow by the vector
244-
try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) {
245-
tempRawVectorsFileName = out.getName();
246-
// TODO do this better, we shouldn't have to write to a temp file, we should be able to
247-
// to just from the merged vector values, the tricky part is the random access.
248-
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
249-
CodecUtil.writeFooter(out);
250-
success = true;
244+
// a temporary file and if the segment is not dense, the docs to another file/
245+
try (
246+
IndexOutput vectorsOut = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivfvec_", IOContext.DEFAULT)
247+
) {
248+
tempRawVectorsFileName = vectorsOut.getName();
249+
FloatVectorValues mergedFloatVectorValues = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
250+
// if the segment is dense, we don't need to do anything with docIds.
251+
boolean dense = mergedFloatVectorValues.size() == mergeState.segmentInfo.maxDoc();
252+
try (
253+
IndexOutput docsOut = dense
254+
? null
255+
: mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivfdoc_", IOContext.DEFAULT)
256+
) {
257+
if (docsOut != null) {
258+
docsFileName = docsOut.getName();
259+
}
260+
// TODO do this better, we shouldn't have to write to a temp file, we should be able to
261+
// to just from the merged vector values, the tricky part is the random access.
262+
numVectors = writeFloatVectorValues(fieldInfo, docsOut, vectorsOut, mergedFloatVectorValues);
263+
CodecUtil.writeFooter(vectorsOut);
264+
if (docsOut != null) {
265+
CodecUtil.writeFooter(docsOut);
266+
}
267+
success = true;
268+
}
251269
} finally {
252-
if (success == false && tempRawVectorsFileName != null) {
253-
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
270+
if (success == false) {
271+
if (tempRawVectorsFileName != null) {
272+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
273+
}
274+
if (docsFileName != null) {
275+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, docsFileName);
276+
}
254277
}
255278
}
256-
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
257-
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
258-
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
279+
try (
280+
IndexInput vectors = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT);
281+
IndexInput docs = docsFileName == null ? null : mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT)
282+
) {
283+
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);
259284
success = false;
260285
long centroidOffset;
261286
long centroidLength;
262287
String centroidTempName = null;
263288
int numCentroids;
264289
IndexOutput centroidTemp = null;
265290
CentroidAssignments centroidAssignments;
291+
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
266292
try {
267293
centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT);
268294
centroidTempName = centroidTemp.getName();
269-
270295
centroidAssignments = calculateAndWriteCentroids(
271296
fieldInfo,
272297
floatVectorValues,
@@ -318,28 +343,34 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
318343
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
319344
}
320345
} finally {
346+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName);
347+
}
348+
} finally {
349+
if (docsFileName != null) {
321350
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(
322351
mergeState.segmentInfo.dir,
323352
tempRawVectorsFileName,
324-
centroidTempName
353+
docsFileName
325354
);
355+
} else {
356+
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
326357
}
327-
} finally {
328-
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
329358
}
330359
}
331360

332-
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) {
361+
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput docs, IndexInput vectors, int numVectors)
362+
throws IOException {
333363
if (numVectors == 0) {
334364
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
335365
}
336-
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES;
366+
final long vectorLength = (long) Float.BYTES * fieldInfo.getVectorDimension();
337367
final float[] vector = new float[fieldInfo.getVectorDimension()];
368+
final RandomAccessInput randomDocs = docs == null ? null : docs.randomAccessSlice(0, docs.length());
338369
return new FloatVectorValues() {
339370
@Override
340371
public float[] vectorValue(int ord) throws IOException {
341-
randomAccessInput.seek(ord * length + Integer.BYTES);
342-
randomAccessInput.readFloats(vector, 0, vector.length);
372+
vectors.seek(ord * vectorLength);
373+
vectors.readFloats(vector, 0, vector.length);
343374
return vector;
344375
}
345376

@@ -360,27 +391,34 @@ public int size() {
360391

361392
@Override
362393
public int ordToDoc(int ord) {
394+
if (randomDocs == null) {
395+
return ord;
396+
}
363397
try {
364-
randomAccessInput.seek(ord * length);
365-
return randomAccessInput.readInt();
398+
return randomDocs.readInt((long) ord * Integer.BYTES);
366399
} catch (IOException e) {
367400
throw new UncheckedIOException(e);
368401
}
369402
}
370403
};
371404
}
372405

373-
private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues)
374-
throws IOException {
406+
private static int writeFloatVectorValues(
407+
FieldInfo fieldInfo,
408+
IndexOutput docsOut,
409+
IndexOutput vectorsOut,
410+
FloatVectorValues floatVectorValues
411+
) throws IOException {
375412
int numVectors = 0;
376413
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
377414
final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
378415
for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
379416
numVectors++;
380-
float[] vector = floatVectorValues.vectorValue(iterator.index());
381-
out.writeInt(iterator.docID());
382-
buffer.asFloatBuffer().put(vector);
383-
out.writeBytes(buffer.array(), buffer.array().length);
417+
buffer.asFloatBuffer().put(floatVectorValues.vectorValue(iterator.index()));
418+
vectorsOut.writeBytes(buffer.array(), buffer.array().length);
419+
if (docsOut != null) {
420+
docsOut.writeInt(iterator.docID());
421+
}
384422
}
385423
return numVectors;
386424
}

0 commit comments

Comments
 (0)