
Commit

rename
Signed-off-by: zhichao-aws <zhichaog@amazon.com>
zhichao-aws committed Nov 20, 2024
1 parent 302f949 commit 46b9d9a
Showing 12 changed files with 73 additions and 73 deletions.
@@ -13,8 +13,8 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
-import org.opensearch.neuralsearch.util.pruning.PruneType;
-import org.opensearch.neuralsearch.util.pruning.PruneUtils;
+import org.opensearch.neuralsearch.util.prune.PruneType;
+import org.opensearch.neuralsearch.util.prune.PruneUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
@@ -15,11 +15,11 @@
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
-import org.opensearch.neuralsearch.util.pruning.PruneType;
+import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

import lombok.extern.log4j.Log4j2;
-import org.opensearch.neuralsearch.util.pruning.PruneUtils;
+import org.opensearch.neuralsearch.util.prune.PruneUtils;

/**
* This processor performs text sparse encoding on user input data; model_id can be used to indicate which model to use,
@@ -62,7 +62,7 @@ public void doExecute(
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps);
sparseVectors = sparseVectors.stream()
-.map(vector -> PruneUtils.pruningSparseVector(pruneType, pruneRatio, vector, false).v1())
+.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector, false).v1())
.toList();
setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
handler.accept(ingestDocument, null);
@@ -74,7 +74,7 @@ public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps);
sparseVectors = sparseVectors.stream()
-.map(vector -> PruneUtils.pruningSparseVector(pruneType, pruneRatio, vector, false).v1())
+.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector, false).v1())
.toList();
handler.accept(sparseVectors);
}, onException));
@@ -21,8 +21,8 @@
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;

import lombok.extern.log4j.Log4j2;
-import org.opensearch.neuralsearch.util.pruning.PruneUtils;
-import org.opensearch.neuralsearch.util.pruning.PruneType;
+import org.opensearch.neuralsearch.util.prune.PruneUtils;
+import org.opensearch.neuralsearch.util.prune.PruneType;

/**
* Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
@@ -47,8 +47,8 @@
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;
-import org.opensearch.neuralsearch.util.pruning.PruneType;
-import org.opensearch.neuralsearch.util.pruning.PruneUtils;
+import org.opensearch.neuralsearch.util.prune.PruneType;
+import org.opensearch.neuralsearch.util.prune.PruneUtils;

/**
* SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model
@@ -146,7 +146,7 @@ public NeuralSparseQueryBuilder getCopyNeuralSparseQueryBuilderForTwoPhase(float
Map<String, Float> tokens = queryTokensSupplier.get();
// Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1,
// while those less than or equal to the threshold are stored in v2.
-Tuple<Map<String, Float>, Map<String, Float>> splitTokens = PruneUtils.pruningSparseVector(pruneType, pruneRatio, tokens, true);
+Tuple<Map<String, Float>, Map<String, Float>> splitTokens = PruneUtils.pruneSparseVector(pruneType, pruneRatio, tokens, true);
this.queryTokensSupplier(() -> splitTokens.v1());
copy.queryTokensSupplier(() -> splitTokens.v2());
} else {
@@ -348,7 +348,7 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
ActionListener.wrap(mapResultList -> {
Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
if (Objects.nonNull(twoPhaseSharedQueryToken)) {
-Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokens = PruneUtils.pruningSparseVector(
+Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokens = PruneUtils.pruneSparseVector(
twoPhasePruneType,
twoPhasePruneRatio,
queryTokens,
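Reader's note on the two hunks above: in the two-phase path, the renamed PruneUtils.pruneSparseVector is invoked with requiresPrunedEntries set to true, so the returned Tuple carries the kept tokens in v1 and the pruned tokens in v2. The following standalone sketch is not part of this commit; it assumes the plugin's prune package and org.opensearch.common.collect.Tuple are on the classpath, and the token weights are invented for illustration.

import java.util.Map;

import org.opensearch.common.collect.Tuple;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;

public class TwoPhaseSplitSketch {
    public static void main(String[] args) {
        // Hypothetical query tokens produced by a sparse encoding model.
        Map<String, Float> queryTokens = Map.of("hello", 1.0f, "world", 0.6f, "ok", 0.1f);

        // With requiresPrunedEntries = true, v1 holds the high-scoring tokens kept for the
        // main neural_sparse query and v2 holds the pruned tokens used by the two-phase rescore.
        Tuple<Map<String, Float>, Map<String, Float>> split = PruneUtils.pruneSparseVector(
            PruneType.MAX_RATIO,
            0.4f,          // assumed ratio relative to the maximum token weight
            queryTokens,
            true
        );

        System.out.println("main-phase tokens: " + split.v1());
        System.out.println("rescore-phase tokens: " + split.v2());
    }
}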
@@ -2,12 +2,12 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
-package org.opensearch.neuralsearch.util.pruning;
+package org.opensearch.neuralsearch.util.prune;

import org.apache.commons.lang.StringUtils;

/**
-* Enum representing different types of pruning methods for sparse vectors
+* Enum representing different types of prune methods for sparse vectors
*/
public enum PruneType {
NONE("none"),
@@ -29,9 +29,9 @@ public String getValue() {
/**
* Get PruneType from string value
*
-* @param value string representation of pruning type
+* @param value string representation of prune type
* @return corresponding PruneType enum
-* @throws IllegalArgumentException if value doesn't match any pruning type
+* @throws IllegalArgumentException if value doesn't match any prune type
*/
public static PruneType fromString(String value) {
if (StringUtils.isEmpty(value)) return NONE;
@@ -40,6 +40,6 @@ public static PruneType fromString(String value) {
return type;
}
}
throw new IllegalArgumentException("Unknown pruning type: " + value);
throw new IllegalArgumentException("Unknown prune type: " + value);
}
}
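For orientation, a small sketch of the fromString behavior shown above: empty input falls back to NONE, known values resolve to their enum constants, and anything else raises the reworded "Unknown prune type" error. This is illustrative only (not part of the commit) and assumes the renamed org.opensearch.neuralsearch.util.prune package is on the classpath.

import org.opensearch.neuralsearch.util.prune.PruneType;

public class PruneTypeSketch {
    public static void main(String[] args) {
        // Empty input falls back to NONE (see the StringUtils.isEmpty check above).
        System.out.println(PruneType.fromString(""));          // NONE

        // Known values resolve to enum constants ("abs_value" is exercised in the tests below).
        System.out.println(PruneType.fromString("abs_value")); // ABS_VALUE

        // Anything else triggers the renamed error message.
        try {
            PruneType.fromString("test_value");
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());                // Unknown prune type: test_value
        }
    }
}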
@@ -2,7 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
-package org.opensearch.neuralsearch.util.pruning;
+package org.opensearch.neuralsearch.util.prune;

import org.opensearch.common.collect.Tuple;

@@ -15,8 +15,8 @@
import java.util.PriorityQueue;

/**
-* Utility class providing methods for pruning sparse vectors using different strategies.
-* Pruning helps reduce the dimensionality of sparse vectors by removing less significant elements
+* Utility class providing methods for prune sparse vectors using different strategies.
+* Prune helps reduce the dimensionality of sparse vectors by removing less significant elements
* based on various criteria.
*/
public class PruneUtils {
@@ -31,7 +31,7 @@ public class PruneUtils {
* @param requiresPrunedEntries Whether to return pruned entries
* @return A tuple containing two maps: the first with top K elements, the second with remaining elements (or null)
*/
-private static Tuple<Map<String, Float>, Map<String, Float>> pruningByTopK(
+private static Tuple<Map<String, Float>, Map<String, Float>> pruneByTopK(
Map<String, Float> sparseVector,
int k,
boolean requiresPrunedEntries
@@ -71,7 +71,7 @@ private static Tuple<Map<String, Float>, Map<String, Float>> pruningByTopK(
* @return A tuple containing two maps: the first with elements meeting the ratio threshold,
* the second with elements below the threshold (or null)
*/
-private static Tuple<Map<String, Float>, Map<String, Float>> pruningByMaxRatio(
+private static Tuple<Map<String, Float>, Map<String, Float>> pruneByMaxRatio(
Map<String, Float> sparseVector,
float ratio,
boolean requiresPrunedEntries
@@ -101,7 +101,7 @@ private static Tuple<Map<String, Float>, Map<String, Float>> pruningByMaxRatio(
* @return A tuple containing two maps: the first with elements above the threshold,
* the second with elements below the threshold (or null)
*/
-private static Tuple<Map<String, Float>, Map<String, Float>> pruningByValue(
+private static Tuple<Map<String, Float>, Map<String, Float>> pruneByValue(
Map<String, Float> sparseVector,
float thresh,
boolean requiresPrunedEntries
@@ -130,7 +130,7 @@ private static Tuple<Map<String, Float>, Map<String, Float>> pruningByValue(
* @return A tuple containing two maps: the first with elements meeting the alpha mass threshold,
* the second with remaining elements (or null)
*/
-private static Tuple<Map<String, Float>, Map<String, Float>> pruningByAlphaMass(
+private static Tuple<Map<String, Float>, Map<String, Float>> pruneByAlphaMass(
Map<String, Float> sparseVector,
float alpha,
boolean requiresPrunedEntries
@@ -159,16 +159,16 @@ private static Tuple<Map<String, Float>, Map<String, Float>> pruningByAlphaMass(
}

/**
-* Prunes a sparse vector using the specified pruning type and ratio.
+* Prunes a sparse vector using the specified prune type and ratio.
*
-* @param pruneType The type of pruning strategy to use
-* @param pruneRatio The ratio or threshold for pruning
+* @param pruneType The type of prune strategy to use
+* @param pruneRatio The ratio or threshold for prune
* @param sparseVector The input sparse vector as a map of string keys to float values
* @param requiresPrunedEntries Whether to return pruned entries
* @return A tuple containing two maps: the first with high-scoring elements,
* the second with low-scoring elements (or null if requiresPrunedEntries is false)
*/
-public static Tuple<Map<String, Float>, Map<String, Float>> pruningSparseVector(
+public static Tuple<Map<String, Float>, Map<String, Float>> pruneSparseVector(
PruneType pruneType,
float pruneRatio,
Map<String, Float> sparseVector,
@@ -190,29 +190,29 @@ public static Tuple<Map<String, Float>, Map<String, Float>> pruningSparseVector(

switch (pruneType) {
case TOP_K:
-return pruningByTopK(sparseVector, (int) pruneRatio, requiresPrunedEntries);
+return pruneByTopK(sparseVector, (int) pruneRatio, requiresPrunedEntries);
case ALPHA_MASS:
-return pruningByAlphaMass(sparseVector, pruneRatio, requiresPrunedEntries);
+return pruneByAlphaMass(sparseVector, pruneRatio, requiresPrunedEntries);
case MAX_RATIO:
-return pruningByMaxRatio(sparseVector, pruneRatio, requiresPrunedEntries);
+return pruneByMaxRatio(sparseVector, pruneRatio, requiresPrunedEntries);
case ABS_VALUE:
-return pruningByValue(sparseVector, pruneRatio, requiresPrunedEntries);
+return pruneByValue(sparseVector, pruneRatio, requiresPrunedEntries);
default:
return new Tuple<>(new HashMap<>(sparseVector), requiresPrunedEntries ? new HashMap<>() : null);
}
}

/**
-* Validates whether a prune ratio is valid for a given pruning type.
+* Validates whether a prune ratio is valid for a given prune type.
*
-* @param pruneType The type of pruning strategy
+* @param pruneType The type of prune strategy
* @param pruneRatio The ratio or threshold to validate
-* @return true if the ratio is valid for the given pruning type, false otherwise
-* @throws IllegalArgumentException if pruning type is null
+* @return true if the ratio is valid for the given prune type, false otherwise
+* @throws IllegalArgumentException if prune type is null
*/
public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) {
if (pruneType == null) {
throw new IllegalArgumentException("Pruning type cannot be null");
throw new IllegalArgumentException("Prune type cannot be null");
}

switch (pruneType) {
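For orientation, the renamed public entry points above are pruneSparseVector (dispatching to pruneByTopK, pruneByAlphaMass, pruneByMaxRatio, or pruneByValue) and isValidPruneRatio. The sketch below shows a plausible call pattern; it is not part of this commit, the input weights are invented, and exact threshold and tie behavior follows the method bodies collapsed above.

import java.util.HashMap;
import java.util.Map;

import org.opensearch.common.collect.Tuple;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;

public class PruneUtilsSketch {
    public static void main(String[] args) {
        Map<String, Float> sparseVector = new HashMap<>();
        sparseVector.put("high", 2.0f);
        sparseVector.put("mid", 1.0f);
        sparseVector.put("low", 0.1f);

        // Guard the ratio first; passing a null prune type throws "Prune type cannot be null".
        // Here 2 is assumed to be a valid k for TOP_K.
        if (PruneUtils.isValidPruneRatio(PruneType.TOP_K, 2f)) {
            // TOP_K interprets the ratio as an integer k ((int) pruneRatio in the switch above).
            // With requiresPrunedEntries = false, only v1 is populated and v2 is null.
            Tuple<Map<String, Float>, Map<String, Float>> topK =
                PruneUtils.pruneSparseVector(PruneType.TOP_K, 2f, sparseVector, false);
            System.out.println("top_k kept: " + topK.v1());
        }

        // MAX_RATIO keeps tokens whose weight is large relative to the maximum weight.
        Tuple<Map<String, Float>, Map<String, Float>> byRatio =
            PruneUtils.pruneSparseVector(PruneType.MAX_RATIO, 0.5f, sparseVector, true);
        System.out.println("kept: " + byRatio.v1() + ", pruned: " + byRatio.v2());
    }
}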
@@ -9,8 +9,8 @@
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
-import org.opensearch.neuralsearch.util.pruning.PruneType;
-import org.opensearch.neuralsearch.util.pruning.PruneUtils;
+import org.opensearch.neuralsearch.util.prune.PruneType;
+import org.opensearch.neuralsearch.util.prune.PruneUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.rescore.QueryRescorerBuilder;
import org.opensearch.test.OpenSearchTestCase;
@@ -51,7 +51,7 @@
import com.google.common.collect.ImmutableMap;

import lombok.SneakyThrows;
-import org.opensearch.neuralsearch.util.pruning.PruneType;
+import org.opensearch.neuralsearch.util.prune.PruneType;

public class SparseEncodingProcessorTests extends InferenceProcessorTestCase {
@Mock
@@ -275,7 +275,7 @@ public void test_batchExecute_exception() {
}

@SuppressWarnings("unchecked")
-public void testExecute_withPruningConfig_successful() {
+public void testExecute_withPruneConfig_successful() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
sourceAndMetadata.put("key1", "value1");
@@ -317,7 +317,7 @@ public void testExecute_withPruningConfig_successful() {
assertEquals(0.4f, second.get("low"), 0.001f);
}

-public void test_batchExecute_withPruning_successful() {
+public void test_batchExecute_withPrune_successful() {
SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f);

List<Map<String, ?>> mockMLResponse = Collections.singletonList(
@@ -8,8 +8,8 @@
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD;
import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE;
-import static org.opensearch.neuralsearch.util.pruning.PruneUtils.PRUNE_TYPE_FIELD;
-import static org.opensearch.neuralsearch.util.pruning.PruneUtils.PRUNE_RATIO_FIELD;
+import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_TYPE_FIELD;
+import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_RATIO_FIELD;

import lombok.SneakyThrows;
import org.junit.Before;
@@ -18,7 +18,7 @@
import org.opensearch.env.Environment;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
-import org.opensearch.neuralsearch.util.pruning.PruneType;
+import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.test.OpenSearchTestCase;

import java.util.HashMap;
@@ -134,7 +134,7 @@ public void testCreateProcessor_whenInvalidPruneType_thenFail() {
IllegalArgumentException.class,
() -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config)
);
assertEquals("Unknown pruning type: invalid_prune_type", exception.getMessage());
assertEquals("Unknown prune type: invalid_prune_type", exception.getMessage());
}

@SneakyThrows
@@ -52,7 +52,7 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
-import org.opensearch.neuralsearch.util.pruning.PruneType;
+import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.test.OpenSearchTestCase;

import lombok.SneakyThrows;
@@ -2,7 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
-package org.opensearch.neuralsearch.util.pruning;
+package org.opensearch.neuralsearch.util.prune;

import org.opensearch.test.OpenSearchTestCase;

@@ -25,6 +25,6 @@ public void testFromString() {
assertEquals(PruneType.ABS_VALUE, PruneType.fromString("abs_value"));

IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneType.fromString("test_value"));
assertEquals("Unknown pruning type: test_value", exception.getMessage());
assertEquals("Unknown prune type: test_value", exception.getMessage());
}
}