|
55 | 55 | import org.apache.lucene.document.FieldType;
|
56 | 56 | import org.apache.lucene.document.IntField;
|
57 | 57 | import org.apache.lucene.document.IntPoint;
|
| 58 | +import org.apache.lucene.document.KnnByteVectorField; |
58 | 59 | import org.apache.lucene.document.KnnVectorField;
|
59 | 60 | import org.apache.lucene.document.LongField;
|
60 | 61 | import org.apache.lucene.document.LongPoint;
|
@@ -364,7 +365,8 @@ public static final class DocState {
|
364 | 365 | final Field timeSec;
|
365 | 366 | // Necessary for "old style" wiki line files:
|
366 | 367 | final SimpleDateFormat dateParser = new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.US);
|
367 |
| - final KnnVectorField vectorField; |
| 368 | + final KnnVectorField floatVectorField; |
| 369 | + final KnnByteVectorField byteVectorField; |
368 | 370 |
|
369 | 371 | // For just y/m/day:
|
370 | 372 | //final SimpleDateFormat dateParser = new SimpleDateFormat("y/M/d", Locale.US);
|
@@ -452,14 +454,18 @@ public static final class DocState {
|
452 | 454 | doc.add(timeSec);
|
453 | 455 |
|
454 | 456 | if (vectorDimension > 0) {
|
455 |
| - // create a throwaway vector so the field's type gets the proper dimension and similarity |
456 |
| - vectorField = switch (vectorEncoding) { |
457 |
| - case BYTE -> new KnnVectorField("vector", new BytesRef(new byte[vectorDimension]), VectorSimilarityFunction.DOT_PRODUCT); |
458 |
| - case FLOAT32 -> new KnnVectorField("vector", new float[vectorDimension], VectorSimilarityFunction.DOT_PRODUCT); |
459 |
| - }; |
460 |
| - doc.add(vectorField); |
| 457 | + if (vectorEncoding == VectorEncoding.FLOAT32) { |
| 458 | + floatVectorField = new KnnVectorField("vector", new float[vectorDimension], VectorSimilarityFunction.DOT_PRODUCT); |
| 459 | + doc.add(floatVectorField); |
| 460 | + byteVectorField = null; |
| 461 | + } else { |
| 462 | + byteVectorField = new KnnByteVectorField("vector", new BytesRef(new byte[vectorDimension]), VectorSimilarityFunction.DOT_PRODUCT); |
| 463 | + doc.add(byteVectorField); |
| 464 | + floatVectorField = null; |
| 465 | + } |
461 | 466 | } else {
|
462 |
| - vectorField = null; |
| 467 | + floatVectorField = null; |
| 468 | + byteVectorField = null; |
463 | 469 | }
|
464 | 470 | }
|
465 | 471 | }
|
@@ -616,7 +622,11 @@ public Document nextDoc(DocState doc, boolean expected) throws IOException {
|
616 | 622 | line = null;
|
617 | 623 |
|
618 | 624 | if (lfd.vector != null) {
|
619 |
| - lfd.getVector(doc.vectorField.vectorValue()); |
| 625 | + if (doc.floatVectorField != null) { |
| 626 | + lfd.getVector(doc.floatVectorField.vectorValue()); |
| 627 | + } else { |
| 628 | + lfd.getVector(doc.byteVectorField.vectorValue().bytes); |
| 629 | + } |
620 | 630 | }
|
621 | 631 |
|
622 | 632 | } else {
|
@@ -672,8 +682,10 @@ public Document nextDoc(DocState doc, boolean expected) throws IOException {
|
672 | 682 | doc.dateCal.setTime(date);
|
673 | 683 | msecSinceEpoch = doc.dateCal.getTimeInMillis();
|
674 | 684 | timeSec = doc.dateCal.get(Calendar.HOUR_OF_DAY)*3600 + doc.dateCal.get(Calendar.MINUTE)*60 + doc.dateCal.get(Calendar.SECOND);
|
675 |
| - if (doc.vectorField != null) { |
676 |
| - doc.vectorField.setVectorValue((float[]) lfd.vector.array()); |
| 685 | + if (doc.floatVectorField != null) { |
| 686 | + doc.floatVectorField.setVectorValue((float[]) lfd.vector.array()); |
| 687 | + } else if (doc.byteVectorField != null) { |
| 688 | + doc.byteVectorField.setVectorValue(new BytesRef((byte[]) lfd.vector.array())); |
677 | 689 | }
|
678 | 690 | }
|
679 | 691 |
|
|
0 commit comments