From 302f9496b4e30a36f0d1ef7d9fc8bcc9db2027bc Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Nov 2024 13:37:50 +0800 Subject: [PATCH] it Signed-off-by: zhichao-aws --- .../NeuralSparseTwoPhaseProcessorIT.java | 30 ++++---------- .../query/NeuralSparseQueryBuilderTests.java | 39 +++++++++++++++++++ 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java index 3e4ed8844..5f921809a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java @@ -47,15 +47,8 @@ public class NeuralSparseTwoPhaseProcessorIT extends BaseNeuralSearchIT { private final Map testRankFeaturesDoc = createRandomTokenWeightMap(TEST_TOKENS); private static final List TWO_PHASE_TEST_TOKEN = List.of("hello", "world"); - private static final Map testFixedQueryTokens = new HashMap<>(); + private static final Map testFixedQueryTokens = Map.of("hello", 5.0f, "world", 4.0f, "a", 3.0f, "b", 2.0f, "c", 1.0f); private static final Supplier> testFixedQueryTokenSupplier = () -> testFixedQueryTokens; - static { - testFixedQueryTokens.put("hello", 5.0f); - testFixedQueryTokens.put("world", 4.0f); - testFixedQueryTokens.put("a", 3.0f); - testFixedQueryTokens.put("b", 2.0f); - testFixedQueryTokens.put("c", 1.0f); - } @Before public void setUp() throws Exception { @@ -82,7 +75,6 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries_whenTwoPhaseEnabl NeuralSparseQueryBuilder sparseEncodingQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) .queryTokensSupplier(randomTokenWeightSupplier); NeuralSparseQueryBuilder sparseEncodingQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_2) - 
.queryText(TEST_QUERY_TEXT) .queryTokensSupplier(randomTokenWeightSupplier); boolQueryBuilder.should(sparseEncodingQueryBuilder1).should(sparseEncodingQueryBuilder2); @@ -116,7 +108,7 @@ private void setDefaultSearchPipelineForIndex(String indexName) { * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "TEST_QUERY_TEXT", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -127,13 +119,12 @@ private void setDefaultSearchPipelineForIndex(String indexName) { * } */ @SneakyThrows - public void testBasicQueryUsingQueryText_whenTwoPhaseEnabled_thenGetExpectedScore() { + public void testBasicQueryUsingQueryTokens_whenTwoPhaseEnabled_thenGetExpectedScore() { try { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); initializeTwoPhaseProcessor(); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); @@ -148,14 +139,13 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabled_thenGetExpectedScor } @SneakyThrows - public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetSameScore() { + public void testBasicQueryUsingQueryTokens_whenTwoPhaseEnabledAndDisabled_thenGetSameScore() { try { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); initializeTwoPhaseProcessor(); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); @@ -164,7 +154,6 @@ public void 
testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetS float scoreWithoutTwoPhase = objectToFloat(firstInnerHit.get("_score")); sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); @@ -190,7 +179,7 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetS * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "Hello world a b", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -209,7 +198,6 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); QueryBuilder queryBuilder = new MatchAllQueryBuilder(); @@ -232,7 +220,7 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "Hello world a b", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -240,7 +228,7 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "Hello world a b", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -316,7 +304,6 @@ public void testMultiNeuralSparseQuery_whenTwoPhaseAndFilter_thenGetExpectedScor setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new 
NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); boolQueryBuilder.should(sparseEncodingQueryBuilder); @@ -401,7 +388,6 @@ public void testNeuralSParseQuery_whenTwoPhaseAndNestedInConstantScoreQuery_then createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(1.0f); ConstantScoreQueryBuilder constantScoreQueryBuilder = new ConstantScoreQueryBuilder(sparseEncodingQueryBuilder); @@ -421,7 +407,6 @@ public void testNeuralSParseQuery_whenTwoPhaseAndNestedInDisjunctionMaxQuery_the createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(5.0f); DisMaxQueryBuilder disMaxQueryBuilder = new DisMaxQueryBuilder(); @@ -444,7 +429,6 @@ public void testNeuralSparseQuery_whenTwoPhaseAndNestedInFunctionScoreQuery_then createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(5.0f); FunctionScoreQueryBuilder functionScoreQueryBuilder = new FunctionScoreQueryBuilder(sparseEncodingQueryBuilder); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java 
b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
index 7509efd42..c4d50ad55 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
@@ -52,6 +52,7 @@
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
 import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
+import org.opensearch.neuralsearch.util.pruning.PruneType;
 import org.opensearch.test.OpenSearchTestCase;
 
 import lombok.SneakyThrows;
@@ -649,6 +650,44 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier()
         assertEquals(expectedMap, queryBuilder.queryTokensSupplier().get());
     }
 
+    @SneakyThrows
+    public void testRewrite_whenQueryTokensSupplierNull_andPruneSet_thenSuccessPrune() {
+        NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
+            .queryText(QUERY_TEXT) // only text is given; tokens are expected to come from the mocked accessor
+            .modelId(MODEL_ID)
+            .twoPhaseSharedQueryToken(Map.of())
+            .twoPhasePruneRatio(3.0f) // 3.0 lies between the mocked weights 1f and 5f
+            .twoPhasePruneType(PruneType.ABS_VALUE);
+        Map<String, Float> expectedMap = Map.of("1", 1f, "2", 5f);
+        MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
+        doAnswer(invocation -> {
+            ActionListener<List<Map<String, ?>>> listener = invocation.getArgument(2);
+            listener.onResponse(List.of(Map.of("response", List.of(expectedMap))));
+            return null;
+        }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any());
+        NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor);
+
+        final CountDownLatch inProgressLatch = new CountDownLatch(1);
+        QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class);
+        doAnswer(invocation -> {
+            BiConsumer<Client, ActionListener<?>> biConsumer = invocation.getArgument(0);
+            biConsumer.accept(
+                null,
+                ActionListener.wrap(
+                    response -> inProgressLatch.countDown(),
+                    err -> fail("Failed to set query tokens supplier: " + err.getMessage())
+                )
+            );
+            return null;
+        }).when(queryRewriteContext).registerAsyncAction(any());
+
+        NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext);
+        assertNotNull(queryBuilder.queryTokensSupplier());
+        assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); // the async action must have signalled completion
+        assertEquals(Map.of("2", 5f), queryBuilder.queryTokensSupplier().get()); // weight 5f stays in the query tokens
+        assertEquals(Map.of("1", 1f), queryBuilder.twoPhaseSharedQueryToken()); // weight 1f ends up in the shared two-phase tokens
+    }
+
     @SneakyThrows
     public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() {
         NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)