IVF Hierarchical KMeans Flush & Merge #128675

Merged
54 commits merged, Jun 10, 2025

Changes from 1 commit

Commits
206022a
added classes related to running hierarchical kmeans as a clustering …
john-wagster May 30, 2025
85e4d8f
[CI] Auto commit changes from spotless
elasticsearchmachine May 30, 2025
6578e87
Merge branch 'main' into ivf_hkmeans
john-wagster May 30, 2025
4280682
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 2, 2025
58c5991
iter
john-wagster Jun 2, 2025
651efdf
[CI] Auto commit changes from spotless
elasticsearchmachine Jun 2, 2025
5743d59
bringing back some interfaces
john-wagster Jun 2, 2025
47e5d8e
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 2, 2025
786e4f1
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 2, 2025
5ca53d3
accidentally remove suppressforbidden
john-wagster Jun 2, 2025
b1f9ae4
migrated from short to int and fixed IOUtils copy/paste errors
john-wagster Jun 2, 2025
075e2ce
no longer allocating larger arrays for slices that are the entire set…
john-wagster Jun 2, 2025
5fb98ff
[CI] Auto commit changes from spotless
elasticsearchmachine Jun 2, 2025
523c2ca
iter on fvvs
john-wagster Jun 2, 2025
bb4531b
Merge branch 'ivf_hkmeans' of github.com:john-wagster/elasticsearch i…
john-wagster Jun 2, 2025
44b0aa9
iter on fvvs
john-wagster Jun 2, 2025
f5f0538
fixing comment
john-wagster Jun 2, 2025
3893098
switched to reservoir sampling
john-wagster Jun 3, 2025
1f2d053
switched to reservoir sampling
john-wagster Jun 3, 2025
6cda6a6
switched to reservoir sampling
john-wagster Jun 3, 2025
4cd94cf
missed a few short to int in tests
john-wagster Jun 3, 2025
b6d61fa
removed sorting on writeCentroids
john-wagster Jun 3, 2025
c82d719
migrated CentroidAssignments to a class to hide default constructor, …
john-wagster Jun 3, 2025
4bd2c9c
only getting the vector value on sampling when necessary
john-wagster Jun 3, 2025
1d61944
* stepLloyd now passes nextCentroids to prevent creating and rec…
john-wagster Jun 4, 2025
f05a541
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 4, 2025
26698d7
[CI] Auto commit changes from spotless
elasticsearchmachine Jun 4, 2025
5112408
bug fixes around printing cluster metrics; still refactoring this
john-wagster Jun 4, 2025
dd61ba5
split kmeansresult into two classes, updated centroid assignments int…
john-wagster Jun 5, 2025
762839e
comibned kmeans and kmeanslocal classes into one class, and fixed vis…
john-wagster Jun 5, 2025
e5746a1
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 5, 2025
44d0f24
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
e82af9c
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
cf7c6b3
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
2c96c82
added trimtosize and fixed a spot where we should be returning KMeans…
john-wagster Jun 6, 2025
cc5570a
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
3fec326
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
5dffeea
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
904f52d
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 7, 2025
aad4b3b
minor test fixes and edge cases
john-wagster Jun 8, 2025
1f93921
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 8, 2025
968f539
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 9, 2025
1048f7f
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 9, 2025
490946f
Merge remote-tracking branch 'upstream/main' into ivf_hkmeans
benwtrent Jun 9, 2025
12a1207
fixing bugs
benwtrent Jun 9, 2025
8ca12bc
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 9, 2025
69a0b4e
merge
john-wagster Jun 9, 2025
ab5a61c
removed unnecessary int[]
john-wagster Jun 10, 2025
93ca452
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 10, 2025
f935144
removed null checking for ffvslice for now because it's extra cruft; …
john-wagster Jun 10, 2025
fc41d7d
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 10, 2025
ff0fad4
making constructor private to reduce confusion
john-wagster Jun 10, 2025
b48bfcf
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 10, 2025
e05ac74
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 10, 2025
Commit b1f9ae4cf807882e9009aefd5f680c04a81c536d: migrated from short to int and fixed IOUtils copy/paste errors
john-wagster committed Jun 2, 2025
@@ -9,13 +9,13 @@

 package org.elasticsearch.index.codec.vectors;

-record CentroidAssignments(int numCentroids, float[][] cachedCentroids, short[] assignments, short[] soarAssignments) {
+record CentroidAssignments(int numCentroids, float[][] cachedCentroids, int[] assignments, int[] soarAssignments) {

-    CentroidAssignments(float[][] centroids, short[] assignments, short[] soarAssignments) {
+    CentroidAssignments(float[][] centroids, int[] assignments, int[] soarAssignments) {
         this(centroids.length, centroids, assignments, soarAssignments);
     }

-    CentroidAssignments(int numCentroids, short[] assignments, short[] soarAssignments) {
+    CentroidAssignments(int numCentroids, int[] assignments, int[] soarAssignments) {
         this(numCentroids, null, assignments, soarAssignments);
     }
 }
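The widening in this record is the point of the whole commit: a centroid ordinal stored as `short` silently wraps past `Short.MAX_VALUE` (32,767), so a segment with enough clusters would corrupt its assignments. A standalone sketch of the failure mode (class and method names are hypothetical, not from the PR):

```java
// Demonstrates why centroid ordinals were widened from short to int:
// a short-valued ordinal wraps to a negative number past 32,767.
public class OrdinalWidthDemo {
    static short asShortOrdinal(int ord) {
        return (short) ord; // the old, lossy representation
    }

    public static void main(String[] args) {
        int ord = 40_000;                  // plausible centroid count for a large segment
        short wrapped = asShortOrdinal(ord);
        System.out.println(wrapped);       // -25536: silent corruption
        System.out.println(ord);           // 40000: int keeps the ordinal intact
    }
}
```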
@@ -64,8 +64,8 @@ long[] buildAndWritePostingsLists(
         BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
         DocIdsWriter docIdsWriter = new DocIdsWriter();

-        short[] assignments = centroidAssignments.assignments();
-        short[] soarAssignments = centroidAssignments.soarAssignments();
+        int[] assignments = centroidAssignments.assignments();
+        int[] soarAssignments = centroidAssignments.soarAssignments();

         int[][] clustersForMetrics = null;
         if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
@@ -298,8 +298,8 @@ CentroidAssignments calculateAndWriteCentroids(
         // TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
         KMeansResult kMeansResult = new HierarchicalKMeans().cluster(floatVectorValues, vectorPerCluster);
         float[][] centroids = kMeansResult.centroids();
-        short[] assignments = kMeansResult.assignments();
-        short[] soarAssignments = kMeansResult.soarAssignments();
+        int[] assignments = kMeansResult.assignments();
+        int[] soarAssignments = kMeansResult.soarAssignments();

         // TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
         // preliminary tests suggest recall is good using only centroids but need to do further evaluation
@@ -292,7 +292,7 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
                 success = true;
             } finally {
                 if (success == false && centroidTempName != null) {
-                    org.apache.lucene.util.IOUtils.closeWhileHandlingException(centroidTemp);
+                    IOUtils.closeWhileHandlingException(centroidTemp);
                     org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName);
                 }
             }
@@ -301,11 +301,11 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
                 centroidOffset = ivfCentroids.getFilePointer();
                 writeMeta(fieldInfo, centroidOffset, 0, new long[0], null);
                 CodecUtil.writeFooter(centroidTemp);
-                org.apache.lucene.util.IOUtils.close(centroidTemp);
+                IOUtils.close(centroidTemp);
                 return;
             }
             CodecUtil.writeFooter(centroidTemp);
-            org.apache.lucene.util.IOUtils.close(centroidTemp);
+            IOUtils.close(centroidTemp);
             centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
             try (IndexInput centroidsInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
                 ivfCentroids.copyBytes(centroidsInput, centroidsInput.length() - CodecUtil.footerLength());
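The finally block in mergeOneField follows the standard temp-file hygiene pattern: on failure, close the temp output while suppressing any secondary exception so the original failure stays primary, then best-effort delete the temp file. A generic sketch of the same pattern with plain java.io (hypothetical helper names, no Lucene utilities):

```java
import java.io.Closeable;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Generic version of the cleanup in mergeOneField: close quietly so a
// close() failure cannot mask the exception already in flight, then
// delete the temp file ignoring any deletion error.
public class TempFileHygiene {
    static void closeWhileHandlingException(Closeable c) {
        try {
            if (c != null) c.close();
        } catch (IOException suppressed) {
            // swallowed: the original failure remains the primary exception
        }
    }

    static void deleteIgnoringExceptions(Path p) {
        try {
            Files.deleteIfExists(p);
        } catch (IOException ignored) {
        }
    }
}
```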
@@ -25,13 +25,13 @@ public class HierarchicalKMeans {

     final int maxIterations;
     final int samplesPerCluster;
-    final short clustersPerNeighborhood;
+    final int clustersPerNeighborhood;

     public HierarchicalKMeans() {
-        this(MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, (short) MAXK);
+        this(MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, MAXK);
     }

-    HierarchicalKMeans(int maxIterations, int samplesPerCluster, short clustersPerNeighborhood) {
+    HierarchicalKMeans(int maxIterations, int samplesPerCluster, int clustersPerNeighborhood) {
         this.maxIterations = maxIterations;
         this.samplesPerCluster = samplesPerCluster;
         this.clustersPerNeighborhood = clustersPerNeighborhood;
@@ -56,7 +56,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
         if (vectors.size() < targetSize) {
             float[] centroid = new float[vectors.dimension()];
             System.arraycopy(vectors.vectorValue(0), 0, centroid, 0, vectors.dimension());
-            return new KMeansResult(new float[][] { centroid }, new short[vectors.size()]);
+            return new KMeansResult(new float[][] { centroid }, new int[vectors.size()]);
         }

         // partition the space
@@ -80,7 +80,7 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
         int m = Math.min(k * samplesPerCluster, vectors.size());

         // TODO: instead of creating a sub-cluster assignments reuse the parent array each time
-        short[] assignments = new short[vectors.size()];
+        int[] assignments = new int[vectors.size()];

         KMeans kmeans = new KMeans(m, maxIterations);
         float[][] centroids = KMeans.pickInitialCentroids(vectors, m, k);
@@ -95,9 +95,9 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
         float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
         for (int i = 0; i < vectors.size(); i++) {
             float smallest = Float.MAX_VALUE;
-            short centroidIdx = -1;
+            int centroidIdx = -1;
             float[] vector = vectors.vectorValue(i);
-            for (short j = 0; j < centroids.length; j++) {
+            for (int j = 0; j < centroids.length; j++) {
                 float[] centroid = centroids[j];
                 float d = VectorUtil.squareDistance(vector, centroid);
                 if (d < smallest) {
@@ -122,7 +122,7 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
             }
         }

-        short effectiveK = 0;
+        int effectiveK = 0;
         for (int i = 0; i < clusterSizes.length; i++) {
             if (clusterSizes[i] > 0) {
                 effectiveK++;
@@ -138,7 +138,7 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
             return kMeansResult;
         }

-        for (short c = 0; c < clusterSizes.length; c++) {
+        for (int c = 0; c < clusterSizes.length; c++) {
             // Recurse for each cluster which is larger than targetSize
             // Give ourselves 30% margin for the target size
             if (100 * clusterSizes[c] > 134 * targetSize) {
@@ -152,7 +152,7 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
         return kMeansResult;
     }

-    static FloatVectorValuesSlice createClusterSlice(int clusterSize, int cluster, FloatVectorValuesSlice vectors, short[] assignments) {
+    static FloatVectorValuesSlice createClusterSlice(int clusterSize, int cluster, FloatVectorValuesSlice vectors, int[] assignments) {
         int[] slice = new int[clusterSize];
         int idx = 0;
         for (int i = 0; i < assignments.length; i++) {
@@ -165,7 +165,7 @@ static FloatVectorValuesSlice createClusterSlice(int clusterSize, int cluster, F
         return new FloatVectorValuesSlice(vectors, slice);
     }

-    static void updateAssignmentsWithRecursiveSplit(KMeansResult current, short cluster, KMeansResult splitClusters) {
+    static void updateAssignmentsWithRecursiveSplit(KMeansResult current, int cluster, KMeansResult splitClusters) {
         int orgCentroidsSize = current.centroids().length;

         // update based on the outcomes from the split clusters recursion
@@ -175,7 +175,7 @@ static void updateAssignmentsWithRecursiveSplit(KMeansResult current, short clus
         System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);

         // replace the original cluster
-        short origCentroidOrd = 0;
+        int origCentroidOrd = 0;
         newCentroids[cluster] = splitClusters.centroids()[0];

         // append the remainder
@@ -188,7 +188,7 @@ static void updateAssignmentsWithRecursiveSplit(KMeansResult current, short clus
             if (splitClusters.assignments()[i] != origCentroidOrd) {
                 int parentOrd = splitClusters.assignmentOrds()[i];
                 assert current.assignments()[parentOrd] == cluster;
-                current.assignments()[parentOrd] = (short) (splitClusters.assignments()[i] + orgCentroidsSize - 1);
+                current.assignments()[parentOrd] = splitClusters.assignments()[i] + orgCentroidsSize - 1;
             }
         }
     }
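The recursion guard in kMeansHierarchical uses integer arithmetic to avoid floating point: a cluster is split only when it exceeds the target size by roughly a third (the comment says 30%; 134/100 is actually a 34% margin). A minimal sketch of that predicate in isolation (class and method names hypothetical):

```java
// Mirrors the guard in kMeansHierarchical: recurse into a cluster only when
// its size exceeds targetSize by the ~34% margin, using integer math
// (100 * size > 134 * target  <=>  size > 1.34 * target).
public class SplitGuard {
    static boolean shouldSplit(int clusterSize, int targetSize) {
        return 100 * clusterSize > 134 * targetSize;
    }

    public static void main(String[] args) {
        System.out.println(shouldSplit(134, 100)); // false: exactly at the margin
        System.out.println(shouldSplit(135, 100)); // true: just over the margin
    }
}
```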
@@ -70,21 +70,16 @@ public static float[][] pickInitialCentroids(FloatVectorValues vectors, int samp
         return centroids;
     }

-    private boolean stepLloyd(
-        FloatVectorValues vectors,
-        float[][] centroids,
-        short[] assignments,
-        int sampleSize,
-        ClusteringAugment augment
-    ) throws IOException {
+    private boolean stepLloyd(FloatVectorValues vectors, float[][] centroids, int[] assignments, int sampleSize, ClusteringAugment augment)
+        throws IOException {
         boolean changed = false;
         int dim = vectors.dimension();
         long[] centroidCounts = new long[centroids.length];
         float[][] nextCentroids = new float[centroids.length][dim];

         for (int i = 0; i < sampleSize; i++) {
             float[] vector = vectors.vectorValue(i);
-            short bestCentroidOffset = getBestCentroidOffset(centroids, vector, i, augment);
+            int bestCentroidOffset = getBestCentroidOffset(centroids, vector, i, augment);
             if (assignments[i] != bestCentroidOffset) {
                 changed = true;
             }
@@ -98,7 +93,7 @@ private boolean stepLloyd(
         for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
             if (centroidCounts[clusterIdx] > 0) {
                 float countF = (float) centroidCounts[clusterIdx];
-                for (int d = 0; d < dim; d++) {
+                for (short d = 0; d < dim; d++) {
                     centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF;
                 }
             }
@@ -107,10 +102,10 @@ private boolean stepLloyd(
         return changed;
     }

-    short getBestCentroidOffset(float[][] centroids, float[] vector, int vectorIdx, ClusteringAugment augment) {
-        short bestCentroidOffset = -1;
+    int getBestCentroidOffset(float[][] centroids, float[] vector, int vectorIdx, ClusteringAugment augment) {
+        int bestCentroidOffset = -1;
         float minDsq = Float.MAX_VALUE;
-        for (short j = 0; j < centroids.length; j++) {
+        for (int j = 0; j < centroids.length; j++) {
             float dsq = VectorUtil.squareDistance(vector, centroids[j]);
             if (dsq < minDsq) {
                 minDsq = dsq;
@@ -142,7 +137,7 @@ void cluster(FloatVectorValues vectors, KMeansResult kMeansResult, ClusteringAug
             return;
         }

-        short[] assignments = new short[n];
+        int[] assignments = new int[n];
         for (int i = 0; i < maxIterations; i++) {
             if (stepLloyd(vectors, centroids, assignments, sampleSize, augment) == false) {
                 break;
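stepLloyd above is one iteration of Lloyd's algorithm: assign each sampled vector to its nearest centroid by squared Euclidean distance, then recompute each centroid as the mean of its members. A self-contained sketch over plain float[][] data, mirroring the shape of stepLloyd without the Lucene types (all names hypothetical):

```java
// One Lloyd iteration over plain arrays: returns true if any assignment
// changed, and updates centroids in place to the mean of their members.
public class LloydStep {
    static boolean step(float[][] vectors, float[][] centroids, int[] assignments) {
        boolean changed = false;
        int dim = vectors[0].length;
        long[] counts = new long[centroids.length];
        float[][] next = new float[centroids.length][dim];

        for (int i = 0; i < vectors.length; i++) {
            // find the nearest centroid by squared Euclidean distance
            int best = -1;
            float minDsq = Float.MAX_VALUE;
            for (int j = 0; j < centroids.length; j++) {
                float dsq = 0;
                for (int d = 0; d < dim; d++) {
                    float diff = vectors[i][d] - centroids[j][d];
                    dsq += diff * diff;
                }
                if (dsq < minDsq) { minDsq = dsq; best = j; }
            }
            if (assignments[i] != best) changed = true;
            assignments[i] = best;
            counts[best]++;
            for (int d = 0; d < dim; d++) next[best][d] += vectors[i][d];
        }
        // recompute each non-empty centroid as the mean of its members
        for (int j = 0; j < centroids.length; j++) {
            if (counts[j] > 0) {
                for (int d = 0; d < dim; d++) centroids[j][d] = next[j][d] / counts[j];
            }
        }
        return changed;
    }

    public static void main(String[] args) {
        float[][] vectors = { { 0f }, { 1f }, { 10f }, { 11f } };
        float[][] centroids = { { 0f }, { 9f } };
        int[] assignments = new int[4];
        step(vectors, centroids, assignments);
        // vectors 0,1 -> centroid 0 (mean 0.5); vectors 10,11 -> centroid 1 (mean 10.5)
        System.out.println(centroids[0][0] + " " + centroids[1][0]);
    }
}
```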
@@ -25,9 +25,9 @@
  */
 class KMeansLocal extends KMeans {

-    final short clustersPerNeighborhood;
+    final int clustersPerNeighborhood;

-    KMeansLocal(int sampleSize, int maxIterations, short clustersPerNeighborhood) {
+    KMeansLocal(int sampleSize, int maxIterations, int clustersPerNeighborhood) {
         super(sampleSize, maxIterations);
         this.clustersPerNeighborhood = clustersPerNeighborhood;
     }
@@ -66,27 +66,27 @@ private void computeNeighborhoods(
     }

     @Override
-    short getBestCentroidOffset(float[][] centroids, float[] vector, int vectorIdx, ClusteringAugment augment) {
+    int getBestCentroidOffset(float[][] centroids, float[] vector, int vectorIdx, ClusteringAugment augment) {
         assert augment instanceof NeighborsClusteringAugment;

-        short centroidIdx = ((NeighborsClusteringAugment) augment).getCentroidIdx(vectorIdx);
+        int centroidIdx = ((NeighborsClusteringAugment) augment).getCentroidIdx(vectorIdx);
         List<int[]> neighborhoods = ((NeighborsClusteringAugment) augment).neighborhoods;

-        short bestCentroidOffset = centroidIdx;
+        int bestCentroidOffset = centroidIdx;
         float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);

         int[] neighborOffsets = neighborhoods.get(centroidIdx);
         for (int neighborOffset : neighborOffsets) {
             float dsq = VectorUtil.squareDistance(vector, centroids[neighborOffset]);
             if (dsq < minDsq) {
                 minDsq = dsq;
-                bestCentroidOffset = (short) neighborOffset;
+                bestCentroidOffset = neighborOffset;
             }
         }
         return bestCentroidOffset;
     }

-    private short[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, short[] assignments)
+    private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, int[] assignments)
         throws IOException {
         // SOAR uses an adjusted distance for assigning spilled documents which is
         // given by:
@@ -97,15 +97,15 @@ private short[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoo
         // centroid the document was assigned to. The document is assigned to the
         // cluster with the smallest soar(x, c).

-        short[] spilledAssignments = new short[assignments.length];
+        int[] spilledAssignments = new int[assignments.length];

         float[] diffs = new float[vectors.dimension()];
         for (int i = 0; i < vectors.size(); i++) {
             float[] vector = vectors.vectorValue(i);

-            short currAssignment = assignments[i];
+            int currAssignment = assignments[i];
             float[] currentCentroid = centroids[currAssignment];
-            for (int j = 0; j < vectors.dimension(); j++) {
+            for (short j = 0; j < vectors.dimension(); j++) {
                 float diff = vector[j] - currentCentroid[j];
                 diffs[j] = diff;
             }
@@ -128,7 +128,7 @@ private short[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoo
                 }
             }

-            spilledAssignments[i] = (short) bestAssignment;
+            spilledAssignments[i] = bestAssignment;
         }

         return spilledAssignments;
@@ -156,7 +156,7 @@ private float distanceSoar(float[] residual, float[] vector, float[] centroid, f
     @Override
     void cluster(FloatVectorValues vectors, KMeansResult kMeansResult) throws IOException {
         float[][] centroids = kMeansResult.centroids();
-        short[] assignments = kMeansResult.assignments();
+        int[] assignments = kMeansResult.assignments();

         assert assignments != null;
         assert assignments.length == vectors.size();
@@ -174,14 +174,14 @@ void cluster(FloatVectorValues vectors, KMeansResult kMeansResult) throws IOExce

     static class NeighborsClusteringAugment extends ClusteringAugment {
         final List<int[]> neighborhoods;
-        final short[] assignments;
+        final int[] assignments;

-        NeighborsClusteringAugment(short[] assignments, List<int[]> neighborhoods) {
+        NeighborsClusteringAugment(int[] assignments, List<int[]> neighborhoods) {
             this.neighborhoods = neighborhoods;
             this.assignments = assignments;
         }

-        public short getCentroidIdx(int vectorIdx) {
+        public int getCentroidIdx(int vectorIdx) {
             return this.assignments[vectorIdx];
         }
     }
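The partially collapsed comment in assignSpilled refers to SOAR's adjusted distance for secondary (spilled) assignments: the score of a candidate centroid c, given the residual x - c1 to the primary centroid c1, is ||x - c||^2 plus a term that penalizes candidates lying along the residual direction. A standalone sketch of that scoring, not the PR's distanceSoar implementation (names and the lambda value are hypothetical):

```java
// Hedged sketch of a SOAR-style adjusted distance:
//   soar(x, c) = ||x - c||^2 + lambda * ((x - c1)^T (x - c))^2 / ||x - c1||^2
// where c1 is the primary centroid. Candidates parallel to the residual
// (x - c1) are penalized, steering the spilled assignment toward a
// direction orthogonal to the primary one.
public class SoarDistance {
    static float soar(float[] x, float[] c, float[] residual, float lambda) {
        float dsq = 0, proj = 0, rnorm = 0;
        for (int d = 0; d < x.length; d++) {
            float diff = x[d] - c[d];      // x - c
            dsq += diff * diff;
            proj += residual[d] * diff;    // (x - c1)^T (x - c)
            rnorm += residual[d] * residual[d];
        }
        if (rnorm == 0) return dsq;        // degenerate: x sits on its primary centroid
        return dsq + lambda * proj * proj / rnorm;
    }

    public static void main(String[] args) {
        float[] x = { 1f, 0f };
        float[] residual = { 1f, 0f };     // x - c1, with c1 at the origin
        float[] parallel = { 3f, 0f };     // candidate along the residual direction
        float[] orthogonal = { 1f, 2f };   // candidate offset orthogonally
        // both candidates are at squared distance 4, but the parallel one is penalized
        System.out.println(soar(x, parallel, residual, 1f));   // 8.0
        System.out.println(soar(x, orthogonal, residual, 1f)); // 4.0
    }
}
```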
@@ -14,11 +14,11 @@
  */
 public class KMeansResult {
     private float[][] centroids;
-    private final short[] assignments;
+    private final int[] assignments;
     private final int[] assignmentOrds;
-    private short[] soarAssignments;
+    private int[] soarAssignments;

-    KMeansResult(float[][] centroids, short[] assignments, int[] assignmentOrds, short[] soarAssignments) {
+    KMeansResult(float[][] centroids, int[] assignments, int[] assignmentOrds, int[] soarAssignments) {
         assert centroids != null;
         assert assignments != null;
         assert assignmentOrds != null;
@@ -29,20 +29,20 @@ public class KMeansResult {
         this.soarAssignments = soarAssignments;
     }

-    KMeansResult(float[][] centroids, short[] assignments, int[] assignmentOrdinals) {
-        this(centroids, assignments, assignmentOrdinals, new short[0]);
+    KMeansResult(float[][] centroids, int[] assignments, int[] assignmentOrdinals) {
+        this(centroids, assignments, assignmentOrdinals, new int[0]);
     }

     KMeansResult() {
-        this(new float[0][0], new short[0], new int[0], new short[0]);
+        this(new float[0][0], new int[0], new int[0], new int[0]);
     }

     KMeansResult(float[][] centroids) {
-        this(centroids, new short[0], new int[0], new short[0]);
+        this(centroids, new int[0], new int[0], new int[0]);
     }

-    KMeansResult(float[][] centroids, short[] assignments) {
-        this(centroids, assignments, new int[0], new short[0]);
+    KMeansResult(float[][] centroids, int[] assignments) {
+        this(centroids, assignments, new int[0], new int[0]);
     }

     public float[][] centroids() {
@@ -53,19 +53,19 @@ public void setCentroids(float[][] centroids) {
         this.centroids = centroids;
     }

-    public short[] assignments() {
+    public int[] assignments() {
         return assignments;
     }

     public int[] assignmentOrds() {
         return assignmentOrds;
     }

-    public short[] soarAssignments() {
+    public int[] soarAssignments() {
         return soarAssignments;
     }

-    public void setSoarAssignments(short[] soarAssignments) {
+    public void setSoarAssignments(int[] soarAssignments) {
         this.soarAssignments = soarAssignments;
     }