Add binary index support for Lucene engine #2292

Status: Open. Wants to merge 1 commit into base: main.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
### Features
- Add binary index support for Lucene engine. [#2292](https://github.com/opensearch-project/k-NN/pull/2292)
### Enhancements
- Introduced a writing layer in native engines which relies on the writing interface to process IO. [#2241](https://github.com/opensearch-project/k-NN/pull/2241)
### Bug Fixes
@@ -29,7 +29,8 @@ public float compare(byte[] v1, byte[] v2) {

@Override
public VectorSimilarityFunction getVectorSimilarityFunction() {
throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space");
// This is not used in the binary case
return VectorSimilarityFunction.EUCLIDEAN;
Comment on lines +32 to +33 (Collaborator): I am not sure about this. It can lead to unknown behavior in the future, so I would suggest not doing this.
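For illustration only, one alternative consistent with this concern would be to keep failing fast rather than returning a placeholder similarity (a sketch, not the PR's code):

    // Sketch: keep the existing exception so an accidental use of the similarity
    // function on the Hamming path surfaces immediately instead of silently
    // scoring with EUCLIDEAN.
    @Override
    public VectorSimilarityFunction getVectorSimilarityFunction() {
        throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space");
    }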

}
};

@@ -40,7 +40,7 @@ public enum VectorDataType {

@Override
public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
throw new IllegalStateException("Unsupported method");
return KnnByteVectorField.createFieldType(dimension / 8, vectorSimilarityFunction);
Collaborator: Instead of the literal 8, use Byte.SIZE.
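A minimal sketch of the suggested change (binary vectors are bit-packed, so the byte dimension is the bit dimension divided by Byte.SIZE):

    @Override
    public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
        // Byte.SIZE (8) makes the bits-to-bytes conversion explicit.
        return KnnByteVectorField.createFieldType(dimension / Byte.SIZE, vectorSimilarityFunction);
    }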

}

@Override
@@ -114,7 +114,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
}
}

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth,
knnMethodContext.getSpaceType()
);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
@@ -0,0 +1,107 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

import java.io.IOException;

public class KNN990BinaryVectorScorer implements FlatVectorsScorer {
Collaborator: Let's add Javadoc on all the public functions and classes.
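A sketch of what class-level Javadoc could look like, placed immediately above the class declaration (illustrative wording only, not part of the PR):

    /**
     * A {@link FlatVectorsScorer} for bit-packed binary vectors stored as byte[].
     * Candidates are scored from the Hamming distance computed with
     * {@link VectorUtil#xorBitCount(byte[], byte[])}.
     */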

@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues
) throws IOException {
assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
Collaborator: This assert makes the IllegalArgumentException line unreachable. Should we remove the assert statement?
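For illustration, the method body without the assert, so the IllegalArgumentException stays reachable regardless of whether assertions are enabled (a sketch of the suggestion):

    if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
        return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues);
    }
    // Reached for any non-byte vector values; no assert trips first when -ea is set.
    throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");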

if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues);
}
throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
float[] queryVector
) throws IOException {
throw new IllegalArgumentException("binary vectors do not support float[] targets");
}

@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
byte[] queryVector
) throws IOException {
assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) {
return new BinaryRandomVectorScorer((RandomAccessVectorValues.Bytes) randomAccessVectorValues, queryVector);
}
throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes");
}

static class BinaryRandomVectorScorer implements RandomVectorScorer {
private final RandomAccessVectorValues.Bytes vectorValues;
private final int bitDimensions;
private final byte[] queryVector;

BinaryRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) {
this.queryVector = query;
this.bitDimensions = vectorValues.dimension() * Byte.SIZE;
this.vectorValues = vectorValues;
}

@Override
public float score(int node) throws IOException {
return (bitDimensions - VectorUtil.xorBitCount(queryVector, vectorValues.vectorValue(node))) / (float) bitDimensions;
Collaborator: I thought Hamming distance is all about counting bits by doing XOR. Why do we need the dimensions here?

Collaborator: And looking at this formula, I think we just need to do this:

1 / (1 + VectorUtil.xorBitCount(queryVector, vectorValues.vectorValue(node)))

https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
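A sketch of score(int node) using the 1 / (1 + Hamming distance) formula from the linked spaces documentation (illustrating the suggestion, not what the PR currently does):

    @Override
    public float score(int node) throws IOException {
        // Hamming distance = number of differing bits between the packed byte[] vectors.
        int hammingDistance = VectorUtil.xorBitCount(queryVector, vectorValues.vectorValue(node));
        // OpenSearch Hamming space score: higher means closer, bounded by (0, 1].
        return 1.0f / (1.0f + hammingDistance);
    }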

}

@Override
public int maxOrd() {
return vectorValues.size();
}

@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
}

@Override
public Bits getAcceptOrds(Bits acceptDocs) {
return vectorValues.getAcceptOrds(acceptDocs);
}
}

static class BinaryRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
protected final RandomAccessVectorValues.Bytes vectorValues;
protected final RandomAccessVectorValues.Bytes vectorValues1;
protected final RandomAccessVectorValues.Bytes vectorValues2;

public BinaryRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) throws IOException {
this.vectorValues = vectorValues;
this.vectorValues1 = vectorValues.copy();
this.vectorValues2 = vectorValues.copy();
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] queryVector = vectorValues1.vectorValue(ord);
return new BinaryRandomVectorScorer(vectorValues2, queryVector);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new BinaryRandomVectorScorerSupplier(vectorValues.copy());
}
}
}
@@ -0,0 +1,97 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;

import java.io.IOException;
import java.util.concurrent.ExecutorService;

public class KNN990HnswBinaryVectorsFormat extends KnnVectorsFormat {
Collaborator: Should this be a 990 format, or should it be a 9120 format? Because Lucene has moved to version 9.12, I think we need to rename this class to 9120. @jmazanec15 what do you think?


private final int maxConn;
private final int beamWidth;
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(new KNN990BinaryVectorScorer());
private final int numMergeWorkers;
private final TaskExecutor mergeExec;

private static final String NAME = "KNN990HnswBinaryVectorsFormat";

public KNN990HnswBinaryVectorsFormat() {
this(16, 100, 1, (ExecutorService) null);
Collaborator: There are defaults for all these values. Please pick those defaults.
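A sketch of the no-arg constructor delegating to Lucene's defaults, assuming the DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, and DEFAULT_NUM_MERGE_WORKER constants exposed by Lucene99HnswVectorsFormat (which would need to be imported):

    public KNN990HnswBinaryVectorsFormat() {
        this(
            Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,          // 16
            Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,        // 100
            Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER,  // 1
            (ExecutorService) null
        );
    }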

}

public KNN990HnswBinaryVectorsFormat(int maxConn, int beamWidth) {
this(maxConn, beamWidth, 1, (ExecutorService) null);
}

public KNN990HnswBinaryVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
super(NAME);
if (maxConn > 0 && maxConn <= 512) {
if (beamWidth > 0 && beamWidth <= 3200) {
Comment on lines +42 to +43 (Collaborator): For magic numbers like 512 and 3200, it's better to create constants describing what they represent. Also, please add comments explaining why we are adding these conditions.
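A sketch of named limits plus guard clauses (local constants here; Lucene99HnswVectorsFormat's MAXIMUM_MAX_CONN and MAXIMUM_BEAM_WIDTH could be reused instead if preferred):

    // Upper bounds on the HNSW graph parameters, mirroring Lucene's HNSW limits.
    private static final int MAX_ALLOWED_MAX_CONN = 512;     // max connections per graph node
    private static final int MAX_ALLOWED_BEAM_WIDTH = 3200;  // max candidate queue size during graph build

    if (maxConn <= 0 || maxConn > MAX_ALLOWED_MAX_CONN) {
        throw new IllegalArgumentException("maxConn must be positive and <= " + MAX_ALLOWED_MAX_CONN + "; maxConn=" + maxConn);
    }
    if (beamWidth <= 0 || beamWidth > MAX_ALLOWED_BEAM_WIDTH) {
        throw new IllegalArgumentException("beamWidth must be positive and <= " + MAX_ALLOWED_BEAM_WIDTH + "; beamWidth=" + beamWidth);
    }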

this.maxConn = maxConn;
this.beamWidth = beamWidth;
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
} else {
this.numMergeWorkers = numMergeWorkers;
if (mergeExec != null) {
this.mergeExec = new TaskExecutor(mergeExec);
} else {
this.mergeExec = null;
}

}
} else {
throw new IllegalArgumentException("beamWidth must be positive and less than or equal to 3200; beamWidth=" + beamWidth);
}
} else {
throw new IllegalArgumentException("maxConn must be positive and less than or equal to 512; maxConn=" + maxConn);
}
}

@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(
state,
this.maxConn,
this.beamWidth,
flatVectorsFormat.fieldsWriter(state),
this.numMergeWorkers,
this.mergeExec
);
}

@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}

@Override
public int getMaxDimensions(String fieldName) {
return 1024;
Collaborator: This value has been changed for the k-NN plugin; the plugin's max dimension is 16k. Ref:

private static Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
    KNNEngine.NMSLIB, 16_000,
    KNNEngine.FAISS, 16_000,
    KNNEngine.LUCENE, 16_000
);
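A sketch of the suggested change (MAX_DIMENSIONS here is a hypothetical local constant matching the 16,000 limit quoted above):

    // The k-NN plugin allows up to 16,000 dimensions for the Lucene engine.
    private static final int MAX_DIMENSIONS = 16_000;

    @Override
    public int getMaxDimensions(String fieldName) {
        return MAX_DIMENSIONS;
    }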

}

@Override
public String toString() {
return "KNN990HnswBinaryVectorsFormat(name=KNN990HnswBinaryVectorsFormat, maxConn="
+ this.maxConn
+ ", beamWidth="
+ this.beamWidth
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
}
}
@@ -8,6 +8,7 @@
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
import org.opensearch.knn.index.engine.KNNEngine;

@@ -24,11 +25,17 @@ public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
mapperService,
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene99HnswVectorsFormat(),
knnVectorsFormatParams -> new Lucene99HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
Lucene99HnswVectorsFormat::new,
knnVectorsFormatParams -> {
if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) {
return new KNN990HnswBinaryVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
);
} else {
return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth());
}
},
knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
@@ -7,6 +7,7 @@

import lombok.Getter;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;

import java.util.Map;

@@ -17,10 +18,16 @@
public class KNNVectorsFormatParams {
private int maxConnections;
private int beamWidth;
private final SpaceType spaceType;

public KNNVectorsFormatParams(final Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth) {
this(params, defaultMaxConnections, defaultBeamWidth, SpaceType.UNDEFINED);
}

public KNNVectorsFormatParams(final Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth, SpaceType spaceType) {
initMaxConnections(params, defaultMaxConnections);
initBeamWidth(params, defaultBeamWidth);
this.spaceType = spaceType;
}
Comment on lines 23 to 31 (Collaborator): Do we need both of these constructors? Can we not just keep one?
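A sketch of keeping only the four-argument constructor; existing callers would then pass SpaceType.UNDEFINED explicitly:

    public KNNVectorsFormatParams(final Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth, SpaceType spaceType) {
        initMaxConnections(params, defaultMaxConnections);
        initBeamWidth(params, defaultBeamWidth);
        this.spaceType = spaceType;
    }

    // Call-site example:
    // new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth, SpaceType.UNDEFINED)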


public boolean validate(final Map<String, Object> params) {
@@ -30,13 +30,18 @@
*/
public class LuceneHNSWMethod extends AbstractKNNMethod {

private static final Set<VectorDataType> SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BYTE);
private static final Set<VectorDataType> SUPPORTED_DATA_TYPES = ImmutableSet.of(
VectorDataType.FLOAT,
VectorDataType.BYTE,
VectorDataType.BINARY
);

public final static List<SpaceType> SUPPORTED_SPACES = Arrays.asList(
SpaceType.UNDEFINED,
SpaceType.L2,
SpaceType.COSINESIMIL,
SpaceType.INNER_PRODUCT
SpaceType.INNER_PRODUCT,
SpaceType.HAMMING
);

final static Encoder SQ_ENCODER = new LuceneSQEncoder();
@@ -106,6 +106,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
case BINARY:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
@@ -10,3 +10,4 @@
#

org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat
org.opensearch.knn.index.codec.KNN990Codec.KNN990HnswBinaryVectorsFormat
@@ -13,7 +13,6 @@
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.util.BytesRef;
@@ -109,14 +108,6 @@ private void createKNNByteVectorDocument(Directory directory) throws IOException
writer.close();
}

public void testCreateKnnVectorFieldType_whenBinary_thenException() {
Exception ex = expectThrows(
IllegalStateException.class,
() -> VectorDataType.BINARY.createKnnVectorFieldType(1, VectorSimilarityFunction.EUCLIDEAN)
);
assertTrue(ex.getMessage().contains("Unsupported method"));
}

public void testGetVectorFromBytesRef_whenBinary_thenException() {
byte[] vector = { 1, 2, 3 };
float[] expected = { 1, 2, 3 };
@@ -283,13 +283,6 @@ public void testValidateVectorDataType_whenBinaryFaissHNSW_thenValid() {
}

public void testValidateVectorDataType_whenBinaryNonFaiss_thenException() {
validateValidateVectorDataType(
KNNEngine.LUCENE,
KNNConstants.METHOD_HNSW,
VectorDataType.BINARY,
SpaceType.HAMMING,
"UnsupportedMethod"
);
validateValidateVectorDataType(
KNNEngine.NMSLIB,
KNNConstants.METHOD_HNSW,
@@ -1528,8 +1528,7 @@ public void testTypeParser_whenBinaryFaissHNSWWithInvalidSpaceType_thenException
}
}

public void testTypeParser_whenBinaryNonFaiss_thenException() throws IOException {
testTypeParserWithBinaryDataType(KNNEngine.LUCENE, SpaceType.HAMMING, METHOD_HNSW, 8, "is not supported for vector data type");
public void testTypeParser_whenBinaryNmslib_thenException() throws IOException {
testTypeParserWithBinaryDataType(KNNEngine.NMSLIB, SpaceType.HAMMING, METHOD_HNSW, 8, "is not supported for vector data type");
}
