Skip to content

Fix bbq quantization algorithm but for differently distributed components #126778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/126778.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126778
summary: Fix bbq quantization algorithm but for differently distributed components
area: Vector Search
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destina
assert bits[i] > 0 && bits[i] <= 8;
int points = (1 << bits[i]);
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][0] + vecMean) * vecStd, min, max);
intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][1] + vecMean) * vecStd, min, max);
intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][0] * vecStd + vecMean, min, max);
intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][1] * vecStd + vecMean, min, max);
optimizeIntervals(intervalScratch, vector, norm2, points);
float nSteps = ((1 << bits[i]) - 1);
float a = intervalScratch[0];
Expand Down Expand Up @@ -128,8 +128,8 @@ public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byt
vecVar /= vector.length;
double vecStd = Math.sqrt(vecVar);
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][0] + vecMean) * vecStd, min, max);
intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][1] + vecMean) * vecStd, min, max);
intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][0] * vecStd + vecMean, min, max);
intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][1] * vecStd + vecMean, min, max);
optimizeIntervals(intervalScratch, vector, norm2, points);
float nSteps = ((1 << bits) - 1);
// Now we have the optimized intervals, quantize the vector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,62 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {

static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };

static float[] deQuantize(byte[] quantized, byte bits, float[] interval, float[] centroid) {
float[] dequantized = new float[quantized.length];
float a = interval[0];
float b = interval[1];
int nSteps = (1 << bits) - 1;
double step = (b - a) / nSteps;
for (int h = 0; h < quantized.length; h++) {
double xi = (double) (quantized[h] & 0xFF) * step + a;
dequantized[h] = (float) (xi + centroid[h]);
}
return dequantized;
}

public void testQuantizationQuality() {
int dims = 16;
int numVectors = 32;
float[][] vectors = new float[numVectors][];
float[] centroid = new float[dims];
for (int i = 0; i < numVectors; ++i) {
vectors[i] = new float[dims];
for (int j = 0; j < dims; ++j) {
vectors[i][j] = randomFloat();
centroid[j] += vectors[i][j];
}
}
for (int j = 0; j < dims; ++j) {
centroid[j] /= numVectors;
}
// similarity doesn't matter for this test
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT);
float[] scratch = new float[dims];
for (byte bit : ALL_BITS) {
float eps = (1f / (float) (1 << (bit)));
byte[] destination = new byte[dims];
for (int i = 0; i < numVectors; ++i) {
System.arraycopy(vectors[i], 0, scratch, 0, dims);
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(scratch, destination, bit, centroid);
assertValidResults(result);
assertValidQuantizedRange(destination, bit);

float[] dequantized = deQuantize(
destination,
bit,
new float[] { result.lowerInterval(), result.upperInterval() },
centroid
);
float mae = 0;
for (int k = 0; k < dims; ++k) {
mae += Math.abs(dequantized[k] - vectors[i][k]);
}
mae /= dims;
assertTrue("bits: " + bit + " mae: " + mae + " > eps: " + eps, mae <= eps);
}
}
}

public void testAbusiveEdgeCases() {
// large zero array
for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) {
Expand Down
Loading