[Enhancement] Implement pruning for neural sparse search #988
base: main
Conversation
Signed-off-by: zhichao-aws <zhichaog@amazon.com>
Force-pushed 1e55b7c to 46b9d9a
This PR is ready for review now.
Could you provide an overview of how the overall API will look? I initially thought this change would only affect the query side, but it seems it will also modify the parameters for neural_sparse_two_phase_processor.
Additionally, the current implementation appears to be focused on two-phase processing with different strategies for splitting vectors, rather than a combination of pruning and two-phase processing?
src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java (outdated; resolved)
src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java (outdated; resolved)
Based on our benchmark results in #946, when searching, applying pruning to 2-phase search outperformed applying it to the neural sparse query body, on both precision and latency. Therefore, enhancing the existing 2-phase search pipeline makes more sense.
The existing two-phase processor uses the max_ratio prune criterion, and now we add support for other criteria as well.
Signed-off-by: zhichao-aws <zhichaog@amazon.com>
Codecov Report
Additional details and impacted files:
@@ Coverage Diff @@
## main #988 +/- ##
============================================
+ Coverage 78.46% 81.07% +2.61%
- Complexity 1027 1053 +26
============================================
Files 85 80 -5
Lines 3617 3519 -98
Branches 604 610 +6
============================================
+ Hits 2838 2853 +15
+ Misses 529 424 -105
+ Partials 250 242 -8
Signed-off-by: zhichao-aws <zhichaog@amazon.com>
LGTM. Thanks!
Apart from a minor comment: why is this PR trying to merge into main?
If this changes the API used to define the processor, it should be checked with application security; for that we need to merge to a feature branch in the main repo, and only after that's cleared, from the feature branch to main.
}
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException(
Can you please mark the block of code with curly braces and use String.format to form the error message?
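A minimal sketch of the requested shape, with a hypothetical helper class and an illustrative message (not the PR's actual wording):

```java
// Hypothetical helper illustrating the suggestion: wrap the throw in
// curly braces and build the message with String.format.
final class PruneRatioValidator {
    static void validate(final String pruneType, final float pruneRatio, final boolean isValid) {
        if (!isValid) {
            throw new IllegalArgumentException(
                String.format("Illegal prune_ratio %f for prune_type: %s", pruneRatio, pruneType)
            );
        }
    }
}
```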
@@ -49,17 +60,19 @@ public void doExecute(
    BiConsumer<IngestDocument, Exception> handler
) {
    mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
        setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
        List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps);
Suggested change:
- List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps);
+ List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps).stream().map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)).toList();
// if we have prune type, then prune ratio field must have value
// readDoubleProperty will throw exception if value is not present
pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue();
if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) throw new IllegalArgumentException(
Same comment as before: please put the throw-exception code into a block marked with "{ }", and use String.format to format the exception message.
);
} else {
    // if we don't have prune type, then prune ratio field must not have value
    if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) {
We can merge this if with the previous else and have a single else if block.
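A small sketch of the suggested control flow, using stand-in names and return values in place of the factory's real validation logic:

```java
import java.util.Map;

// Hypothetical sketch: the nested "if" inside the "else" branch is
// folded into a single "else if". The field name mirrors the PR's code.
final class ElseIfMergeSketch {
    static final String PRUNE_RATIO_FIELD = "prune_ratio";

    static String check(final String pruneType, final Map<String, Object> config) {
        if (pruneType != null) {
            return "validate prune_ratio";
        } else if (config.containsKey(PRUNE_RATIO_FIELD)) {
            // previously: else { if (config.containsKey(...)) { throw ... } }
            return "prune_ratio without prune_type is illegal";
        }
        return "no pruning configured";
    }
}
```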
        return type;
    }
}
throw new IllegalArgumentException("Unknown prune type: " + value);
Please use String.format here as well.
 * the second with low-scoring elements
 */
public static Tuple<Map<String, Float>, Map<String, Float>> splitSparseVector(
    PruneType pruneType,
Please use final for the arguments of public methods. Same for the other methods in this class.
We should introduce some static checker to enforce this :)
switch (pruneType) {
    case TOP_K:
        return pruneByTopK(sparseVector, (int) pruneRatio, true);
Is it possible to move this type cast inside the pruneByTopK method?
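A sketch of the suggestion, assuming pruneByTopK keeps the k highest-weight tokens; the float-to-int cast moves inside the method so call sites pass the ratio unchanged (signature and class name are illustrative):

```java
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical top-k pruning with the cast performed inside the method.
final class TopKPruneSketch {
    static Map<String, Float> pruneByTopK(final Map<String, Float> vector, final float pruneRatio) {
        final int k = (int) pruneRatio; // cast moved from the caller into the method
        return vector.entrySet().stream()
            .sorted(Map.Entry.<String, Float>comparingByValue().reversed())
            .limit(k)
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }
}
```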
    }
}
switch (pruneType) {
Can you think of modifying this into a map of <prune_type> -> <functional_interface>, so that instead of a switch structure we use map.get()?
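One possible shape of the suggested map-based dispatch; the strategy bodies below are illustrative stand-ins, not the PR's actual pruning logic:

```java
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

// Hypothetical map of <prune_type> -> <functional_interface> replacing the switch.
final class PruneDispatchSketch {
    enum PruneType { TOP_K, MAX_RATIO }

    // Each strategy takes (sparseVector, pruneRatio) and returns the kept entries.
    static final Map<PruneType, BiFunction<Map<String, Float>, Float, Map<String, Float>>> PRUNERS = Map.of(
        PruneType.TOP_K, (vector, ratio) -> vector.entrySet().stream()
            .sorted(Map.Entry.<String, Float>comparingByValue().reversed())
            .limit(ratio.intValue())
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)),
        PruneType.MAX_RATIO, (vector, ratio) -> {
            final float max = vector.values().stream().max(Float::compare).orElse(0f);
            return vector.entrySet().stream()
                .filter(e -> e.getValue() >= max * ratio)
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        }
    );

    static Map<String, Float> prune(final PruneType type, final Map<String, Float> vector, final float ratio) {
        return PRUNERS.get(type).apply(vector, ratio); // replaces the switch structure
    }
}
```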
switch (pruneType) {
    case TOP_K:
        return pruneRatio > 0 && pruneRatio == Math.floor(pruneRatio);
Suggested change:
- return pruneRatio > 0 && pruneRatio == Math.floor(pruneRatio);
+ return pruneRatio > 0 && pruneRatio == Math.rint(pruneRatio);
This is more reliable for float numbers; otherwise there is a chance of a false positive.
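For illustration, the whole-number check for the TOP_K ratio could look like this (class and method names are hypothetical):

```java
// Hypothetical check: for TOP_K the ratio is interpreted as an integer k,
// so it must be a positive whole number; Math.rint per the review suggestion.
final class TopKRatioCheck {
    static boolean isValidTopKRatio(final float pruneRatio) {
        return pruneRatio > 0 && pruneRatio == Math.rint(pruneRatio);
    }
}
```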
    }
}
switch (pruneType) {
Same as above, can we use a map instead of a switch?
@martin-gaievski Thanks for the comments. We didn't create a feature branch because no other contributors are working on this, and we regard the PR branch as the feature branch. I'm on PTO this week; I will follow up on the app sec issue and address the comments next week.
Description
Implement pruning for sparse vectors, to save disk space and accelerate search with a small loss in search relevance. #946
Related Issues
#946
Check List
Commits are signed per the DCO using --signoff.
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following the Developer Certificate of Origin and signing off your commits, please check here.