Leverage One Layer of Hierarchical KMeans Structure on Read #129950

Draft · wants to merge 43 commits into main from ivf_hkmeans_struc2

Changes from all commits (43 commits):
8727f3a  wip on sorting centroids and assignments (john-wagster, Jun 12, 2025)
5741a0d  wip on sorting centroids and assignments (john-wagster, Jun 13, 2025)
0cc1f67  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jun 17, 2025)
188446a  got everything working e-to-e w single file centroids, still needs se… (john-wagster, Jun 21, 2025)
ae53b11  merging (john-wagster, Jun 23, 2025)
0dab43d  merging w new bulk score api (john-wagster, Jun 23, 2025)
fc73af3  bug fixes and cleanup (john-wagster, Jun 24, 2025)
470245b  cleanup (john-wagster, Jun 24, 2025)
8d25046  iter (john-wagster, Jun 24, 2025)
3e88c51  iter (john-wagster, Jun 24, 2025)
b089f8d  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jun 24, 2025)
6940131  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jun 24, 2025)
839f0e9  Merge branch 'ivf_hkmeans_struc2' of github.com:john-wagster/elastics… (john-wagster, Jun 24, 2025)
fff6650  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jun 25, 2025)
a718197  added exploration for a percentage of parents (john-wagster, Jun 26, 2025)
c333dcf  Merge branch 'ivf_hkmeans_struc2' of github.com:john-wagster/elastics… (john-wagster, Jun 26, 2025)
ce16153  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jun 26, 2025)
13a0007  [CI] Auto commit changes from spotless (elasticsearchmachine, Jun 26, 2025)
9c3c0fb  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jun 26, 2025)
618d1ab  minor cleanup (john-wagster, Jun 26, 2025)
33b778f  Merge branch 'ivf_hkmeans_struc2' of github.com:john-wagster/elastics… (john-wagster, Jun 26, 2025)
325a21e  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jun 26, 2025)
42d2135  fix for small data usecase (john-wagster, Jun 26, 2025)
2e04623  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jun 26, 2025)
2f5bfd3  merge (john-wagster, Jun 27, 2025)
5910259  iterated on better mechanism for utilizing parent centroids (john-wagster, Jun 29, 2025)
ef9402c  scores not distances + added some diagnostics to be removed subsequently (john-wagster, Jun 30, 2025)
7d5bb39  merge (john-wagster, Jun 30, 2025)
77f020c  [CI] Auto commit changes from spotless (elasticsearchmachine, Jun 30, 2025)
052f431  merge (john-wagster, Jul 1, 2025)
d7f10ea  Merge branch 'ivf_hkmeans_struc2' of github.com:john-wagster/elastics… (john-wagster, Jul 1, 2025)
89e5699  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jul 1, 2025)
ddaf610  merge (john-wagster, Jul 2, 2025)
ed48801  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jul 3, 2025)
ffe2929  Merge branch 'main' into ivf_hkmeans_struc2 (john-wagster, Jul 5, 2025)
3a2ba61  using full hkmeans to gen parents (john-wagster, Jul 7, 2025)
7277b0a  [CI] Auto commit changes from spotless (elasticsearchmachine, Jul 7, 2025)
60ef44d  merging (john-wagster, Jul 7, 2025)
866ac5a  Merge branch 'ivf_hkmeans_struc2' of github.com:john-wagster/elastics… (john-wagster, Jul 7, 2025)
402c767  cleanup (john-wagster, Jul 7, 2025)
0260eac  assert fix and minor cleanup (john-wagster, Jul 7, 2025)
3b0764e  [CI] Auto commit changes from spotless (elasticsearchmachine, Jul 7, 2025)
5aa2682  fixed 1 off error and other cleanup (john-wagster, Jul 8, 2025)

AssignmentArraySorter.java (new file)
@@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.util.IntroSorter;

// Sorts three parallel arrays in tandem (the centroids, their ordinals, and the
// sort keys), ordering all three by the values in sortOrdering.
class AssignmentArraySorter extends IntroSorter {
private int pivot = -1;
private final float[][] centroids;
private final int[] centroidsOrds;
private final int[] sortOrdering;

AssignmentArraySorter(float[][] centroids, int[] centroidsOrds, int[] sortOrdering) {
this.centroids = centroids;
this.centroidsOrds = centroidsOrds;
this.sortOrdering = sortOrdering;
}

@Override
protected void setPivot(int i) {
pivot = sortOrdering[i];
}

@Override
protected int comparePivot(int j) {
return Integer.compare(pivot, sortOrdering[j]);
}

@Override
protected void swap(int i, int j) {
final float[] tmpC = centroids[i];
centroids[i] = centroids[j];
centroids[j] = tmpC;

final int tmpA = centroidsOrds[i];
centroidsOrds[i] = centroidsOrds[j];
centroidsOrds[j] = tmpA;

final int tmpSort = sortOrdering[i];
sortOrdering[i] = sortOrdering[j];
sortOrdering[j] = tmpSort;
}
}
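
Lucene's IntroSorter drives setPivot/comparePivot/swap through the sort(from, to) method it inherits from Sorter. A minimal sketch of how this class might be driven (values hypothetical; the snippet would need to live in the same package, since the class is package-private):

// Hypothetical demo: sort three parallel arrays by the keys in sortKeys.
float[][] centroids = { { 0.3f, 0.1f }, { 0.9f, 0.2f }, { 0.5f, 0.7f } };
int[] centroidOrds = { 0, 1, 2 };
int[] sortKeys = { 2, 0, 1 }; // e.g. the parent cluster assigned to each centroid

new AssignmentArraySorter(centroids, centroidOrds, sortKeys).sort(0, sortKeys.length);
// sortKeys is now { 0, 1, 2 }; the other arrays moved in lockstep,
// so centroidOrds is now { 1, 2, 0 }.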

CentroidAssignments.java
@@ -9,10 +9,10 @@

package org.elasticsearch.index.codec.vectors;

-record CentroidAssignments(int numCentroids, float[][] centroids, int[][] assignmentsByCluster) {
+record CentroidAssignments(int numParentCentroids, int numCentroids, float[][] centroids, int[][] assignmentsByCluster) {

-CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) {
-this(centroids.length, centroids, assignmentsByCluster);
+CentroidAssignments(int numParentCentroids, float[][] centroids, int[][] assignmentsByCluster) {
+this(numParentCentroids, centroids.length, centroids, assignmentsByCluster);
assert centroids.length == assignmentsByCluster.length;
}
}
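
The compact constructor derives numCentroids from the centroid array itself, so callers pass only the parent count explicitly. A brief sketch (values hypothetical):

// Two parent centroids over four leaf centroids.
float[][] centroids = new float[4][64];
int[][] assignments = new int[4][];
CentroidAssignments ca = new CentroidAssignments(2, centroids, assignments);
// ca.numCentroids() == 4, filled in by the delegating constructor;
// the assert guarantees one assignment list per centroid.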

DefaultIVFVectorsReader.java
@@ -47,9 +47,50 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader)
super(state, rawVectorsReader);
}

private abstract static class BaseCentroidQueryScorer implements CentroidQueryScorer {

// TODO can we do this in off-heap blocks?
float int4QuantizedScore(
float qcDist,
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
int dims,
float[] targetCorrections,
int targetComponentSum,
float centroidDp,
VectorSimilarityFunction similarityFunction
) {
float ax = targetCorrections[0];
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
float ay = queryCorrections.lowerInterval();
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
float y1 = queryCorrections.quantizedComponentSum();
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
if (similarityFunction == EUCLIDEAN) {
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
return Math.max(1 / (1f + score), 0);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
return VectorUtil.scaleMaxInnerProductScore(score);
}
return Math.max((1f + score) / 2f, 0);
}
}
}
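
// For reference (informal notation, not part of the change): the four-term
// expression above estimates the dot product of the two dequantized vectors.
// With each component reconstructed as a + l * q, expanding the product gives
//   <x, y> ~= sum_i (ax + lx * qx_i) * (ay + ly * qy_i)
//           = ax*ay*dims + ay*lx*sum(qx) + ax*ly*sum(qy) + lx*ly*<qx, qy>
// where sum(qx) is targetComponentSum, sum(qy) is the query's
// quantizedComponentSum (y1), and <qx, qy> is the int4 dot product qcDist.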

private abstract static class ParentCentroidQueryScorer extends BaseCentroidQueryScorer implements CentroidWChildrenQueryScorer {}

@Override
-CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
-    throws IOException {
+CentroidQueryScorer getCentroidScorer(
+    FieldInfo fieldInfo,
+    int numParentCentroids,
+    int numCentroids,
+    IndexInput centroids,
+    float[] targetQuery
+) throws IOException {
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
final float globalCentroidDp = fieldEntry.globalCentroidDp();
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
@@ -65,11 +106,14 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
quantized[i] = (byte) scratch[i];
}
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
-return new CentroidQueryScorer() {
+return new BaseCentroidQueryScorer() {
int currentCentroid = -1;
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
private final float[] centroidCorrectiveValues = new float[3];
-private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES);
+private final long quantizedVectorByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
+private final long parentNodeByteSize = quantizedVectorByteSize + 2 * Integer.BYTES;
+private final long quantizedCentroidsOffset = numParentCentroids * parentNodeByteSize;
+private final long rawCentroidsOffset = numParentCentroids * parentNodeByteSize + numCentroids * quantizedVectorByteSize;
private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension();

@Override
@@ -87,10 +131,14 @@ public float[] centroid(int centroidOrdinal) throws IOException {
return centroid;
}

-public void bulkScore(NeighborQueue queue) throws IOException {
+@Override
+public void bulkScore(NeighborQueue queue, int start, int end) throws IOException {
// TODO: bulk score centroids like we do with posting lists
-centroids.seek(0L);
-for (int i = 0; i < numCentroids; i++) {
+assert start >= 0;
+assert end > 0;
+assert end <= numCentroids;
+centroids.seek(quantizedCentroidsOffset + quantizedVectorByteSize * start);
+for (int i = start; i < end; i++) {
queue.add(i, score());
}
}
@@ -109,46 +157,125 @@ private float score() throws IOException {
fieldInfo.getVectorSimilarityFunction()
);
}
};
}

-// removed: private int4QuantizedScore(...), moved verbatim into BaseCentroidQueryScorer above
@Override
ParentCentroidQueryScorer getParentCentroidScorer(
FieldInfo fieldInfo,
int numParentCentroids,
IndexInput centroids,
float[] targetQuery
) throws IOException {
FieldEntry fieldEntry = fields.get(fieldInfo.number);
float globalCentroidDp = fieldEntry.globalCentroidDp();
OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
final int[] scratch = new int[targetQuery.length];
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
ArrayUtil.copyArray(targetQuery),
scratch,
(byte) 4,
fieldEntry.globalCentroid()
);
final byte[] quantized = new byte[targetQuery.length];
for (int i = 0; i < quantized.length; i++) {
quantized[i] = (byte) scratch[i];
}
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
return new ParentCentroidQueryScorer() {
int currentCentroid = -1;
private final float[] centroidCorrectiveValues = new float[3];
private final long quantizedVectorByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
private final long parentNodeByteSize = quantizedVectorByteSize + 2 * Integer.BYTES;

private int childCentroidStart;
private int childCount;

@Override
public int size() {
return numParentCentroids;
}

@Override
public float[] centroid(int centroidOrdinal) throws IOException {
throw new UnsupportedOperationException("can't score at the parent level");
}

private void readChildDetails(int centroidOrdinal) throws IOException {
if (centroidOrdinal == currentCentroid) {
return;
}
centroids.seek(parentNodeByteSize * centroidOrdinal + quantizedVectorByteSize);
childCentroidStart = centroids.readInt();
childCount = centroids.readInt();
currentCentroid = centroidOrdinal;
}

@Override
public int getChildCentroidStart(int centroidOrdinal) throws IOException {
readChildDetails(centroidOrdinal);
return childCentroidStart;
}

@Override
public int getChildCount(int centroidOrdinal) throws IOException {
readChildDetails(centroidOrdinal);
return childCount;
}

@Override
public void bulkScore(NeighborQueue queue, int start, int end) throws IOException {
assert start >= 0;
assert end > 0;
assert end <= numParentCentroids;
// TODO: bulk score centroids like we do with posting lists
centroids.seek(parentNodeByteSize * start);
for (int i = start; i < end; i++) {
queue.add(i, score());
}
}

private float score() throws IOException {
final float qcDist = scorer.int4DotProduct(quantized);
centroids.readFloats(centroidCorrectiveValues, 0, 3);
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());

// TODO: should we consider a different format such as moving these to the beginning of the file to benefit bulk read
// TODO: cache these at this point when scoring since we'll likely read many of them?
// child partition start, child partition count
centroids.skipBytes(Integer.BYTES * 2);

return int4QuantizedScore(
qcDist,
queryParams,
fieldInfo.getVectorDimension(),
centroidCorrectiveValues,
quantizedCentroidComponentSum,
globalCentroidDp,
fieldInfo.getVectorSimilarityFunction()
);
}
};
}
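
// Centroid-file layout implied by the seek arithmetic in the two scorers above
// (inferred, not an authoritative format spec; dims = vector dimension):
//
//   [parent nodes]              numParentCentroids entries, each:
//                                 int4 vector (dims bytes)
//                                 + 3 correction floats + component sum (short)
//                                 + child start (int) + child count (int)
//   [quantized child centroids] numCentroids entries, each:
//                                 int4 vector + 3 correction floats + short
//   [raw centroids]             numCentroids entries, each dims * Float.BYTES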

@Override
NeighborQueue scorePostingLists(
FieldInfo fieldInfo,
KnnCollector knnCollector,
CentroidQueryScorer centroidQueryScorer,
int nProbe,
int start,
int count
) throws IOException {
NeighborQueue neighborQueue = new NeighborQueue(count, true);
centroidQueryScorer.bulkScore(neighborQueue, start, start + count);
return neighborQueue;
}

@Override
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
throws IOException {
-NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
-centroidQueryScorer.bulkScore(neighborQueue);
-return neighborQueue;
+return scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe, 0, centroidQueryScorer.size());
}
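
Illustrative only, not part of this diff: a sketch of how the pieces above compose into a two-level search. Parents are bulk-scored first; the best parents are popped, and each one's contiguous child range is then scored through the child-level scorer. The helper name and the explored-parent count are hypothetical:

private static void twoLevelSearchSketch(
    CentroidWChildrenQueryScorer parents,
    CentroidQueryScorer children,
    int parentsToExplore
) throws IOException {
    // Score every parent centroid against the query.
    NeighborQueue parentQueue = new NeighborQueue(parents.size(), true);
    parents.bulkScore(parentQueue, 0, parents.size());

    // Expand only the best parents: each owns a contiguous child range.
    NeighborQueue childQueue = new NeighborQueue(children.size(), true);
    for (int p = 0; p < parentsToExplore && parentQueue.size() > 0; p++) {
        int parentOrd = parentQueue.pop();
        int start = parents.getChildCentroidStart(parentOrd);
        int count = parents.getChildCount(parentOrd);
        children.bulkScore(childQueue, start, start + count);
    }
    // childQueue now holds candidate posting lists ranked by centroid score.
}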
