From ec425e4402ee9fb82dbeaf33559bf8b68e996de0 Mon Sep 17 00:00:00 2001 From: kul Date: Fri, 10 May 2024 15:30:02 +0530 Subject: [PATCH] Introduce paritionByRequest and bypassLimitByPredicate functions --- .../jakarta/ServletLimiterBuilder.java | 20 +++++++++ .../ConcurrencyLimitServletFilterTest.java | 19 ++++++++ .../limits/GroupServletLimiterTest.java | 45 +++++++++++++++++++ 3 files changed, 84 insertions(+) diff --git a/concurrency-limits-servlet-jakarta/src/main/java/com/netflix/concurrency/limits/servlet/jakarta/ServletLimiterBuilder.java b/concurrency-limits-servlet-jakarta/src/main/java/com/netflix/concurrency/limits/servlet/jakarta/ServletLimiterBuilder.java index f400bb51..8300f16a 100644 --- a/concurrency-limits-servlet-jakarta/src/main/java/com/netflix/concurrency/limits/servlet/jakarta/ServletLimiterBuilder.java +++ b/concurrency-limits-servlet-jakarta/src/main/java/com/netflix/concurrency/limits/servlet/jakarta/ServletLimiterBuilder.java @@ -64,6 +64,16 @@ public ServletLimiterBuilder partitionByParameter(String name) { return partitionResolver(request -> Optional.ofNullable(request.getParameter(name)).orElse(null)); } + /** + * Partition the limit by the request instance. Percentages of the limit are partitioned to named + * groups. Group membership is derived from the provided mapping function. + * @param requestToGroup Mapping function from request to a named group. + * @return Chainable builder + */ + public ServletLimiterBuilder partitionByRequest(Function requestToGroup) { + return partitionResolver(request -> Optional.ofNullable(request).map(requestToGroup).orElse(null)); + } + /** * Partition the limit by the full path. Percentages of the limit are partitioned to named * groups. Group membership is derived from the provided mapping function. @@ -142,6 +152,16 @@ public ServletLimiterBuilder bypassLimitByMethod(String method) { return bypassLimitResolver((context) -> method.equals(context.getMethod())); } + /** + * Bypass limit if the predicate function returns true. + * @param predicate The predicate function to which {@link HttpServletRequest } instance is passed. + * If the predicate return true, the limit will be bypassed. + * @return Chainable builder + */ + public ServletLimiterBuilder bypassLimitByPredicate(Function predicate) { + return bypassLimitResolver((context) -> predicate.apply(context)); + } + @Override protected ServletLimiterBuilder self() { return this; diff --git a/concurrency-limits-servlet-jakarta/src/test/java/com/netflix/concurrency/limits/ConcurrencyLimitServletFilterTest.java b/concurrency-limits-servlet-jakarta/src/test/java/com/netflix/concurrency/limits/ConcurrencyLimitServletFilterTest.java index 701a9a9e..d2f2e9a5 100644 --- a/concurrency-limits-servlet-jakarta/src/test/java/com/netflix/concurrency/limits/ConcurrencyLimitServletFilterTest.java +++ b/concurrency-limits-servlet-jakarta/src/test/java/com/netflix/concurrency/limits/ConcurrencyLimitServletFilterTest.java @@ -40,6 +40,7 @@ public void beforeEachTest() { limiter = Mockito.spy(new ServletLimiterBuilder() .bypassLimitByMethod("GET") .bypassLimitByPathInfo("/admin/health") + .bypassLimitByPredicate(ctx -> ctx.getMethod().equals("PATCH")) .named(testName.getMethodName()) .metricRegistry(spectatorMetricRegistry) .build()); @@ -130,6 +131,24 @@ public void testDoFilterBypassCheckPassedForPath() throws ServletException, IOEx verifyCounts(0, 0, 0, 0, 1); } + @Test + public void testDoFilterBypassCheckPassedForPredicate() throws ServletException, IOException { + + ConcurrencyLimitServletFilter filter = new ConcurrencyLimitServletFilter(limiter); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("PATCH"); + request.setPathInfo("/admin/patch"); + MockHttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain filterChain = new MockFilterChain(); + + filter.doFilter(request, response, filterChain); + + assertEquals("Request should be passed to the downstream chain", request, filterChain.getRequest()); + assertEquals("Response should be passed to the downstream chain", response, filterChain.getResponse()); + verifyCounts(0, 0, 0, 0, 1); + } + @Test public void testDoFilterBypassCheckFailed() throws ServletException, IOException { diff --git a/concurrency-limits-servlet-jakarta/src/test/java/com/netflix/concurrency/limits/GroupServletLimiterTest.java b/concurrency-limits-servlet-jakarta/src/test/java/com/netflix/concurrency/limits/GroupServletLimiterTest.java index 24ed04b2..fc670001 100644 --- a/concurrency-limits-servlet-jakarta/src/test/java/com/netflix/concurrency/limits/GroupServletLimiterTest.java +++ b/concurrency-limits-servlet-jakarta/src/test/java/com/netflix/concurrency/limits/GroupServletLimiterTest.java @@ -154,6 +154,44 @@ public void nullPathDoesNotMatchesGroup() { Mockito.verify(pathToGroup, Mockito.times(0)).get(Mockito.any()); } + @Test + public void requestMatchesGroup() { + Map requestMethodToGroup = Mockito.spy(new HashMap<>()); + requestMethodToGroup.put("PATCH", "live"); + + Limiter limiter = new ServletLimiterBuilder() + .limit(VegasLimit.newDefault()) + .partitionByRequest(request -> requestMethodToGroup.get(request.getMethod())) + .partition("live", 0.8) + .partition("batch", 0.2) + .build(); + + HttpServletRequest request = createMockRequestWithType("PATCH"); + Optional listener = limiter.acquire(request); + + Assert.assertTrue(listener.isPresent()); + Mockito.verify(requestMethodToGroup, Mockito.times(1)).get("PATCH"); + } + + @Test + public void requestDoesNotMatchesGroup() { + Map requestMethodToGroup = Mockito.spy(new HashMap<>()); + requestMethodToGroup.put("PATCH", "live"); + + Limiter limiter = new ServletLimiterBuilder() + .limit(VegasLimit.newDefault()) + .partitionByRequest(request -> requestMethodToGroup.get(request.getMethod())) + .partition("live", 0.8) + .partition("batch", 0.2) + .build(); + + HttpServletRequest request = createMockRequestWithType("PUT"); + Optional listener = limiter.acquire(request); + + Assert.assertTrue(listener.isPresent()); + Mockito.verify(requestMethodToGroup, Mockito.times(1)).get("PUT"); + } + private HttpServletRequest createMockRequestWithPrincipal(String name) { HttpServletRequest request = Mockito.mock(HttpServletRequest.class); Principal principal = Mockito.mock(Principal.class); @@ -169,4 +207,11 @@ private HttpServletRequest createMockRequestWithPathInfo(String name) { Mockito.when(request.getPathInfo()).thenReturn(name); return request; } + + private HttpServletRequest createMockRequestWithType(String type) { + HttpServletRequest request = Mockito.mock(HttpServletRequest.class); + + Mockito.when(request.getMethod()).thenReturn(type); + return request; + } }