Skip to content

Commit

Permalink
UT
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <zhichaog@amazon.com>
  • Loading branch information
zhichao-aws committed Nov 20, 2024
1 parent 41d7cb5 commit 339222d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import lombok.Getter;
import lombok.Setter;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.collect.Tuple;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
Expand All @@ -23,11 +22,9 @@
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.search.rescore.RescorerBuilder;

import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
* A SearchRequestProcessor to generate two-phase NeuralSparseQueryBuilder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.util.pruning.PruneType;
import org.opensearch.neuralsearch.util.pruning.PruneUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -19,7 +21,6 @@

public class NeuralSparseTwoPhaseProcessorTests extends OpenSearchTestCase {
static final private String PARAMETER_KEY = "two_phase_parameter";
static final private String RATIO_KEY = "prune_ratio";
static final private String ENABLE_KEY = "enabled";
static final private String EXPANSION_KEY = "expansion_rate";
static final private String MAX_WINDOW_SIZE_KEY = "max_window_size";
Expand All @@ -30,6 +31,7 @@ public void testFactory_whenCreateDefaultPipeline_thenSuccess() throws Exception
assertEquals(0.3f, processor.getPruneRatio(), 1e-3);
assertEquals(4.0f, processor.getWindowExpansion(), 1e-3);
assertEquals(10000, processor.getMaxWindowSize());
assertEquals(PruneType.MAX_RATIO, processor.getPruneType());

NeuralSparseTwoPhaseProcessor defaultProcessor = factory.create(
Collections.emptyMap(),
Expand All @@ -42,11 +44,23 @@ public void testFactory_whenCreateDefaultPipeline_thenSuccess() throws Exception
assertEquals(0.4f, defaultProcessor.getPruneRatio(), 1e-3);
assertEquals(5.0f, defaultProcessor.getWindowExpansion(), 1e-3);
assertEquals(10000, defaultProcessor.getMaxWindowSize());
assertEquals(PruneType.MAX_RATIO, processor.getPruneType());
}

public void testFactory_whenCreatePipelineWithCustomPruneType_thenSuccess() throws Exception {
NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory();
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 5f, "top_k", true, 5f, 1000);
assertEquals(5f, processor.getPruneRatio(), 1e-6);
assertEquals(PruneType.TOP_K, processor.getPruneType());
}

public void testFactory_whenRatioOutOfRange_thenThrowException() {
NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory();
expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 1.1f, true, 5.0f, 10000));
expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 1.1f, "max_ratio", true, 5.0f, 10000));
expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 0f, "top_k", true, 5.0f, 10000));
expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 1.1f, "alpha_mass", true, 5.0f, 10000));
expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, -1f, "abs_value", true, 5.0f, 10000));
}

public void testFactory_whenWindowExpansionOutOfRange_thenThrowException() {
Expand All @@ -72,6 +86,19 @@ public void testProcessRequest_whenTwoPhaseEnabled_thenSuccess() throws Exceptio
assertNotNull(searchRequest.source().rescores());
}

public void testProcessRequest_whenUseCustomPruneType_thenSuccess() throws Exception {
NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory();
NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder();
SearchRequest searchRequest = new SearchRequest();
searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder));
NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, "alpha_mass", true, 4.0f, 10000);
processor.processRequest(searchRequest);
NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query();
assertEquals(queryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3);
assertEquals(queryBuilder.twoPhasePruneType(), PruneType.ALPHA_MASS);
assertNotNull(searchRequest.source().rescores());
}

public void testProcessRequest_whenTwoPhaseEnabledAndNestedBoolean_thenSuccess() throws Exception {
NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory();
NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder();
Expand Down Expand Up @@ -155,9 +182,28 @@ private NeuralSparseTwoPhaseProcessor createTestProcessor(
Map<String, Object> configMap = new HashMap<>();
configMap.put(ENABLE_KEY, enabled);
Map<String, Object> twoPhaseParaMap = new HashMap<>();
twoPhaseParaMap.put(RATIO_KEY, ratio);
twoPhaseParaMap.put(PruneUtils.PRUNE_RATIO_FIELD, ratio);
twoPhaseParaMap.put(EXPANSION_KEY, expand);
twoPhaseParaMap.put(MAX_WINDOW_SIZE_KEY, max_window);
configMap.put(PARAMETER_KEY, twoPhaseParaMap);
return factory.create(Collections.emptyMap(), null, null, false, configMap, null);
}

private NeuralSparseTwoPhaseProcessor createTestProcessor(
NeuralSparseTwoPhaseProcessor.Factory factory,
float ratio,
String type,
boolean enabled,
float expand,
int max_window
) throws Exception {
Map<String, Object> configMap = new HashMap<>();
configMap.put(ENABLE_KEY, enabled);
Map<String, Object> twoPhaseParaMap = new HashMap<>();
twoPhaseParaMap.put(PruneUtils.PRUNE_RATIO_FIELD, ratio);
twoPhaseParaMap.put(EXPANSION_KEY, expand);
twoPhaseParaMap.put(MAX_WINDOW_SIZE_KEY, max_window);
twoPhaseParaMap.put(PruneUtils.PRUNE_TYPE_FIELD, type);
configMap.put(PARAMETER_KEY, twoPhaseParaMap);
return factory.create(Collections.emptyMap(), null, null, false, configMap, null);
}
Expand All @@ -166,7 +212,7 @@ private NeuralSparseTwoPhaseProcessor createTestProcessor(NeuralSparseTwoPhasePr
Map<String, Object> configMap = new HashMap<>();
configMap.put(ENABLE_KEY, true);
Map<String, Object> twoPhaseParaMap = new HashMap<>();
twoPhaseParaMap.put(RATIO_KEY, 0.3f);
twoPhaseParaMap.put(PruneUtils.PRUNE_RATIO_FIELD, 0.3f);
twoPhaseParaMap.put(EXPANSION_KEY, 4.0f);
twoPhaseParaMap.put(MAX_WINDOW_SIZE_KEY, 10000);
configMap.put(PARAMETER_KEY, twoPhaseParaMap);
Expand Down

0 comments on commit 339222d

Please sign in to comment.