Skip to content

Commit ffea6ca

Browse files
iverase and benwtrent authored
Introduce an int4 off-heap vector scorer (#129824)
* Introduce an int4 off-heap vector scorer * iter * Update server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java Co-authored-by: Benjamin Trent <[email protected]> --------- Co-authored-by: Benjamin Trent <[email protected]>
1 parent 321a397 commit ffea6ca

File tree

11 files changed

+506
-72
lines changed

11 files changed

+506
-72
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.benchmark.vector;
10+
11+
import org.apache.lucene.store.Directory;
12+
import org.apache.lucene.store.IOContext;
13+
import org.apache.lucene.store.IndexInput;
14+
import org.apache.lucene.store.IndexOutput;
15+
import org.apache.lucene.store.MMapDirectory;
16+
import org.apache.lucene.util.VectorUtil;
17+
import org.elasticsearch.common.logging.LogConfigurator;
18+
import org.elasticsearch.core.IOUtils;
19+
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
20+
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
21+
import org.openjdk.jmh.annotations.Benchmark;
22+
import org.openjdk.jmh.annotations.BenchmarkMode;
23+
import org.openjdk.jmh.annotations.Fork;
24+
import org.openjdk.jmh.annotations.Measurement;
25+
import org.openjdk.jmh.annotations.Mode;
26+
import org.openjdk.jmh.annotations.OutputTimeUnit;
27+
import org.openjdk.jmh.annotations.Param;
28+
import org.openjdk.jmh.annotations.Scope;
29+
import org.openjdk.jmh.annotations.Setup;
30+
import org.openjdk.jmh.annotations.State;
31+
import org.openjdk.jmh.annotations.TearDown;
32+
import org.openjdk.jmh.annotations.Warmup;
33+
import org.openjdk.jmh.infra.Blackhole;
34+
35+
import java.io.IOException;
36+
import java.nio.file.Files;
37+
import java.util.concurrent.ThreadLocalRandom;
38+
import java.util.concurrent.TimeUnit;
39+
40+
@BenchmarkMode(Mode.Throughput)
41+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
42+
@State(Scope.Benchmark)
43+
// first iteration is complete garbage, so make sure we really warmup
44+
@Warmup(iterations = 4, time = 1)
45+
// real iterations. not useful to spend tons of time here, better to fork more
46+
@Measurement(iterations = 5, time = 1)
47+
// engage some noise reduction
48+
@Fork(value = 1)
49+
public class Int4ScorerBenchmark {
50+
51+
static {
52+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
53+
}
54+
55+
@Param({ "384", "702", "1024" })
56+
int dims;
57+
58+
int numVectors = 200;
59+
int numQueries = 10;
60+
61+
byte[] scratch;
62+
byte[][] binaryVectors;
63+
byte[][] binaryQueries;
64+
65+
ES91Int4VectorsScorer scorer;
66+
Directory dir;
67+
IndexInput in;
68+
69+
@Setup
70+
public void setup() throws IOException {
71+
binaryVectors = new byte[numVectors][dims];
72+
dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
73+
try (IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT)) {
74+
for (byte[] binaryVector : binaryVectors) {
75+
for (int i = 0; i < dims; i++) {
76+
// 4-bit quantization
77+
binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(16);
78+
}
79+
out.writeBytes(binaryVector, 0, binaryVector.length);
80+
}
81+
}
82+
83+
in = dir.openInput("vectors", IOContext.DEFAULT);
84+
binaryQueries = new byte[numVectors][dims];
85+
for (byte[] binaryVector : binaryVectors) {
86+
for (int i = 0; i < dims; i++) {
87+
// 4-bit quantization
88+
binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(16);
89+
}
90+
}
91+
92+
scratch = new byte[dims];
93+
scorer = ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(in, dims);
94+
}
95+
96+
@TearDown
97+
public void teardown() throws IOException {
98+
IOUtils.close(dir, in);
99+
}
100+
101+
@Benchmark
102+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
103+
public void scoreFromArray(Blackhole bh) throws IOException {
104+
for (int j = 0; j < numQueries; j++) {
105+
in.seek(0);
106+
for (int i = 0; i < numVectors; i++) {
107+
in.readBytes(scratch, 0, dims);
108+
bh.consume(VectorUtil.int4DotProduct(binaryQueries[j], scratch));
109+
}
110+
}
111+
}
112+
113+
@Benchmark
114+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
115+
public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
116+
for (int j = 0; j < numQueries; j++) {
117+
in.seek(0);
118+
for (int i = 0; i < numVectors; i++) {
119+
bh.consume(scorer.int4DotProduct(binaryQueries[j]));
120+
}
121+
}
122+
}
123+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.simdvec;
10+
11+
import org.apache.lucene.store.IndexInput;
12+
13+
import java.io.IOException;
14+
15+
/** Scorer for quantized vectors stored as an {@link IndexInput}.
16+
* <p>
17+
* Similar to {@link org.apache.lucene.util.VectorUtil#int4DotProduct(byte[], byte[])} but
18+
* one value is read directly from an {@link IndexInput}.
19+
*
20+
* */
21+
public class ES91Int4VectorsScorer {
22+
23+
/** The wrapper {@link IndexInput}. */
24+
protected final IndexInput in;
25+
protected final int dimensions;
26+
protected byte[] scratch;
27+
28+
/** Sole constructor, called by sub-classes. */
29+
public ES91Int4VectorsScorer(IndexInput in, int dimensions) {
30+
this.in = in;
31+
this.dimensions = dimensions;
32+
scratch = new byte[dimensions];
33+
}
34+
35+
public long int4DotProduct(byte[] b) throws IOException {
36+
in.readBytes(scratch, 0, dimensions);
37+
int total = 0;
38+
for (int i = 0; i < dimensions; i++) {
39+
total += scratch[i] * b[i];
40+
}
41+
return total;
42+
}
43+
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int
4747
return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension);
4848
}
4949

50+
/**
 * Returns an {@link ES91Int4VectorsScorer} over {@code input}, obtained from the
 * {@link ESVectorizationProvider} singleton.
 *
 * @param input     the input handed to the provider for the new scorer
 * @param dimension the vector dimension handed to the provider for the new scorer
 * @throws IOException if the provider fails to create the scorer
 */
public static ES91Int4VectorsScorer getES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
51+
return ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(input, dimension);
52+
}
53+
5054
public static long ipByteBinByte(byte[] q, byte[] d) {
5155
if (q.length != d.length * B_QUERY) {
5256
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.simdvec.internal.vectorization;
1111

1212
import org.apache.lucene.store.IndexInput;
13+
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
1314
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
1415

1516
import java.io.IOException;
@@ -30,4 +31,9 @@ public ESVectorUtilSupport getVectorUtilSupport() {
3031
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
3132
return new ES91OSQVectorsScorer(input, dimension);
3233
}
34+
35+
/** Returns a plain {@link ES91Int4VectorsScorer} (the base class itself) over {@code input}. */
@Override
36+
public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
37+
return new ES91Int4VectorsScorer(input, dimension);
38+
}
3339
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.simdvec.internal.vectorization;
1111

1212
import org.apache.lucene.store.IndexInput;
13+
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
1314
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
1415

1516
import java.io.IOException;
@@ -31,6 +32,9 @@ public static ESVectorizationProvider getInstance() {
3132
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
3233
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
3334

35+
/**
 * Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}.
 *
 * @param input     the input the returned scorer is created over
 * @param dimension the vector dimension for the returned scorer
 *                  (NOTE(review): presumably the number of bytes per stored vector; confirm against implementations)
 * @throws IOException if the implementation fails while creating the scorer
 */
36+
public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
37+
3438
// visible for tests
3539
static ESVectorizationProvider lookup(boolean testMode) {
3640
return new DefaultESVectorizationProvider();

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.apache.lucene.util.Constants;
1414
import org.elasticsearch.logging.LogManager;
1515
import org.elasticsearch.logging.Logger;
16+
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
1617
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
1718

1819
import java.io.IOException;
@@ -38,6 +39,9 @@ public static ESVectorizationProvider getInstance() {
3839
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
3940
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
4041

42+
/**
 * Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}.
 *
 * @param input     the input the returned scorer is created over
 * @param dimension the vector dimension for the returned scorer
 *                  (NOTE(review): presumably the number of bytes per stored vector; confirm against implementations)
 * @throws IOException if the implementation fails while creating the scorer
 */
43+
public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
44+
4145
// visible for tests
4246
static ESVectorizationProvider lookup(boolean testMode) {
4347
final int runtimeVersion = Runtime.version().feature();

0 commit comments

Comments
 (0)