Skip to content

Commit

Permalink
it
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <zhichaog@amazon.com>
  • Loading branch information
zhichao-aws committed Nov 20, 2024
1 parent 339222d commit 302f949
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,8 @@ public class NeuralSparseTwoPhaseProcessorIT extends BaseNeuralSearchIT {
private final Map<String, Float> testRankFeaturesDoc = createRandomTokenWeightMap(TEST_TOKENS);
private static final List<String> TWO_PHASE_TEST_TOKEN = List.of("hello", "world");

private static final Map<String, Float> testFixedQueryTokens = new HashMap<>();
private static final Map<String, Float> testFixedQueryTokens = Map.of("hello", 5.0f, "world", 4.0f, "a", 3.0f, "b", 2.0f, "c", 1.0f);
private static final Supplier<Map<String, Float>> testFixedQueryTokenSupplier = () -> testFixedQueryTokens;
static {
testFixedQueryTokens.put("hello", 5.0f);
testFixedQueryTokens.put("world", 4.0f);
testFixedQueryTokens.put("a", 3.0f);
testFixedQueryTokens.put("b", 2.0f);
testFixedQueryTokens.put("c", 1.0f);
}

@Before
public void setUp() throws Exception {
Expand All @@ -82,7 +75,6 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries_whenTwoPhaseEnabl
NeuralSparseQueryBuilder sparseEncodingQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
.queryTokensSupplier(randomTokenWeightSupplier);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_2)
.queryText(TEST_QUERY_TEXT)
.queryTokensSupplier(randomTokenWeightSupplier);
boolQueryBuilder.should(sparseEncodingQueryBuilder1).should(sparseEncodingQueryBuilder2);

Expand Down Expand Up @@ -116,7 +108,7 @@ private void setDefaultSearchPipelineForIndex(String indexName) {
* {
* "neural_sparse": {
* "field": "test-sparse-encoding-1",
* "query_text": "TEST_QUERY_TEXT",
* "query_tokens": testFixedQueryTokens,
* "model_id": "dcsdcasd",
* "boost": 2.0
* }
Expand All @@ -127,13 +119,12 @@ private void setDefaultSearchPipelineForIndex(String indexName) {
* }
*/
@SneakyThrows
public void testBasicQueryUsingQueryText_whenTwoPhaseEnabled_thenGetExpectedScore() {
public void testBasicQueryUsingQueryTokens_whenTwoPhaseEnabled_thenGetExpectedScore() {
try {
initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME);
initializeTwoPhaseProcessor();
setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
.queryText(TEST_QUERY_TEXT)
.queryTokensSupplier(testFixedQueryTokenSupplier)
.boost(2.0f);
Map<String, Object> searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1);
Expand All @@ -148,14 +139,13 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabled_thenGetExpectedScor
}

@SneakyThrows
public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetSameScore() {
public void testBasicQueryUsingQueryTokens_whenTwoPhaseEnabledAndDisabled_thenGetSameScore() {
try {
initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME);
initializeTwoPhaseProcessor();

setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
.queryText(TEST_QUERY_TEXT)
.queryTokensSupplier(testFixedQueryTokenSupplier)
.boost(2.0f);
Map<String, Object> searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1);
Expand All @@ -164,7 +154,6 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetS
float scoreWithoutTwoPhase = objectToFloat(firstInnerHit.get("_score"));

sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
.queryText(TEST_QUERY_TEXT)
.queryTokensSupplier(testFixedQueryTokenSupplier)
.boost(2.0f);
searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1);
Expand All @@ -190,7 +179,7 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetS
* {
* "neural_sparse": {
* "field": "test-sparse-encoding-1",
* "query_text": "Hello world a b",
* "query_tokens": testFixedQueryTokens,
* "model_id": "dcsdcasd",
* "boost": 2.0
* }
Expand All @@ -209,7 +198,6 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor

setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
.queryText(TEST_QUERY_TEXT)
.queryTokensSupplier(testFixedQueryTokenSupplier)
.boost(2.0f);
QueryBuilder queryBuilder = new MatchAllQueryBuilder();
Expand All @@ -232,15 +220,15 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor
* {
* "neural_sparse": {
* "field": "test-sparse-encoding-1",
* "query_text": "Hello world a b",
* "query_tokens": testFixedQueryTokens,
* "model_id": "dcsdcasd",
* "boost": 2.0
* }
* },
* {
* "neural_sparse": {
* "field": "test-sparse-encoding-1",
* "query_text": "Hello world a b",
* "query_tokens": testFixedQueryTokens,
* "model_id": "dcsdcasd",
* "boost": 2.0
* }
Expand Down Expand Up @@ -316,7 +304,6 @@ public void testMultiNeuralSparseQuery_whenTwoPhaseAndFilter_thenGetExpectedScor
setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME);
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
.queryText(TEST_QUERY_TEXT)
.queryTokensSupplier(testFixedQueryTokenSupplier)
.boost(2.0f);
boolQueryBuilder.should(sparseEncodingQueryBuilder);
Expand Down Expand Up @@ -401,7 +388,6 @@ public void testNeuralSParseQuery_whenTwoPhaseAndNestedInConstantScoreQuery_then
createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000);
setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
.queryText(TEST_QUERY_TEXT)
.queryTokensSupplier(testFixedQueryTokenSupplier)
.boost(1.0f);
ConstantScoreQueryBuilder constantScoreQueryBuilder = new ConstantScoreQueryBuilder(sparseEncodingQueryBuilder);
Expand All @@ -421,7 +407,6 @@ public void testNeuralSParseQuery_whenTwoPhaseAndNestedInDisjunctionMaxQuery_the
createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000);
setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
.queryText(TEST_QUERY_TEXT)
.queryTokensSupplier(testFixedQueryTokenSupplier)
.boost(5.0f);
DisMaxQueryBuilder disMaxQueryBuilder = new DisMaxQueryBuilder();
Expand All @@ -444,7 +429,6 @@ public void testNeuralSparseQuery_whenTwoPhaseAndNestedInFunctionScoreQuery_then
createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000);
setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME);
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
.queryText(TEST_QUERY_TEXT)
.queryTokensSupplier(testFixedQueryTokenSupplier)
.boost(5.0f);
FunctionScoreQueryBuilder functionScoreQueryBuilder = new FunctionScoreQueryBuilder(sparseEncodingQueryBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.neuralsearch.util.pruning.PruneType;
import org.opensearch.test.OpenSearchTestCase;

import lombok.SneakyThrows;
Expand Down Expand Up @@ -649,6 +650,44 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier()
assertEquals(expectedMap, queryBuilder.queryTokensSupplier().get());
}

@SneakyThrows
public void testRewrite_whenQueryTokensSupplierNull_andPruneSet_thenSuceessPrune() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.twoPhaseSharedQueryToken(Map.of())
.twoPhasePruneRatio(3.0f)
.twoPhasePruneType(PruneType.ABS_VALUE);
Map<String, Float> expectedMap = Map.of("1", 1f, "2", 5f);
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
doAnswer(invocation -> {
ActionListener<List<Map<String, ?>>> listener = invocation.getArgument(2);
listener.onResponse(List.of(Map.of("response", List.of(expectedMap))));
return null;
}).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any());
NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor);

final CountDownLatch inProgressLatch = new CountDownLatch(1);
QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class);
doAnswer(invocation -> {
BiConsumer<Client, ActionListener<?>> biConsumer = invocation.getArgument(0);
biConsumer.accept(
null,
ActionListener.wrap(
response -> inProgressLatch.countDown(),
err -> fail("Failed to set query tokens supplier: " + err.getMessage())
)
);
return null;
}).when(queryRewriteContext).registerAsyncAction(any());

NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext);
assertNotNull(queryBuilder.queryTokensSupplier());
assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS));
assertEquals(Map.of("2", 5f), queryBuilder.queryTokensSupplier().get());
assertEquals(Map.of("1", 1f), queryBuilder.twoPhaseSharedQueryToken());
}

@SneakyThrows
public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
Expand Down

0 comments on commit 302f949

Please sign in to comment.