Skip to content

Commit

Permalink
[SPARK-45851][CONNECT][SCALA] Support multiple policies in scala client
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Support multiple retry policies defined at the same time. Each policy determines which error types it can retry and how exactly those retries should be spaced out.

Scala parity for #43591

### Why are the changes needed?

Different error types should be treated differently. For instance, network connectivity errors and remote resources still being initialized should be handled separately.

### Does this PR introduce _any_ user-facing change?
No (as long as the user doesn't poke within client internals).

### How was this patch tested?
Unit tests, some hand testing.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #43757 from cdkrot/SPARK-45851-scala-multiple-policies.

Authored-by: Alice Sayutina <alice.sayutina@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
cdkrot authored and HyukjinKwon committed Nov 16, 2023
1 parent 1a65175 commit 182e2d2
Show file tree
Hide file tree
Showing 11 changed files with 311 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
private var server: Server = _
private var artifactManager: ArtifactManager = _
private var channel: ManagedChannel = _
private var retryPolicy: GrpcRetryHandler.RetryPolicy = _
private var bstub: CustomSparkConnectBlockingStub = _
private var stub: CustomSparkConnectStub = _
private var state: SparkConnectStubState = _
Expand All @@ -58,8 +57,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {

private def createArtifactManager(): Unit = {
channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build()
retryPolicy = GrpcRetryHandler.RetryPolicy()
state = new SparkConnectStubState(channel, retryPolicy)
state = new SparkConnectStubState(channel, RetryPolicy.defaultPolicies())
bstub = new CustomSparkConnectBlockingStub(channel, state)
stub = new CustomSparkConnectStub(channel, state)
artifactManager = new ArtifactManager(Configuration(), "", bstub, stub)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
client = SparkConnectClient
.builder()
.connectionString(s"sc://localhost:${server.getPort}/;use_ssl=true")
.retryPolicy(GrpcRetryHandler.RetryPolicy(maxRetries = 0))
.retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "TestPolicy"))
.build()

val request = AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()
Expand Down Expand Up @@ -311,7 +311,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
}
}

private class DummyFn(val e: Throwable, numFails: Int = 3) {
private class DummyFn(e: => Throwable, numFails: Int = 3) {
var counter = 0
def fn(): Int = {
if (counter < numFails) {
Expand All @@ -333,9 +333,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
}

val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE), numFails = 100)
val retryHandler = new GrpcRetryHandler(GrpcRetryHandler.RetryPolicy(), sleep)
val retryHandler = new GrpcRetryHandler(RetryPolicy.defaultPolicies(), sleep)

assertThrows[StatusRuntimeException] {
assertThrows[RetriesExceeded] {
retryHandler.retry {
dummyFn.fn()
}
Expand All @@ -347,8 +347,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {

test("SPARK-44275: retry actually retries") {
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
val retryPolicy = GrpcRetryHandler.RetryPolicy()
val retryHandler = new GrpcRetryHandler(retryPolicy)
val retryPolicies = RetryPolicy.defaultPolicies()
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
val result = retryHandler.retry { dummyFn.fn() }

assert(result == 42)
Expand All @@ -357,8 +357,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {

test("SPARK-44275: default retryException retries only on UNAVAILABLE") {
val dummyFn = new DummyFn(new StatusRuntimeException(Status.ABORTED))
val retryPolicy = GrpcRetryHandler.RetryPolicy()
val retryHandler = new GrpcRetryHandler(retryPolicy)
val retryPolicies = RetryPolicy.defaultPolicies()
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})

assertThrows[StatusRuntimeException] {
retryHandler.retry { dummyFn.fn() }
Expand All @@ -368,7 +368,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {

test("SPARK-44275: retry uses canRetry to filter exceptions") {
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
val retryPolicy = GrpcRetryHandler.RetryPolicy(canRetry = _ => false)
val retryPolicy = RetryPolicy(canRetry = _ => false, name = "TestPolicy")
val retryHandler = new GrpcRetryHandler(retryPolicy)

assertThrows[StatusRuntimeException] {
Expand All @@ -379,15 +379,62 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {

test("SPARK-44275: retry does not exceed maxRetries") {
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
val retryPolicy = GrpcRetryHandler.RetryPolicy(canRetry = _ => true, maxRetries = 1)
val retryHandler = new GrpcRetryHandler(retryPolicy)
val retryPolicy = RetryPolicy(canRetry = _ => true, maxRetries = Some(1), name = "TestPolicy")
val retryHandler = new GrpcRetryHandler(retryPolicy, sleep = _ => {})

assertThrows[StatusRuntimeException] {
assertThrows[RetriesExceeded] {
retryHandler.retry { dummyFn.fn() }
}
assert(dummyFn.counter == 2)
}

/**
 * Builds a test [[RetryPolicy]] that only accepts gRPC errors carrying one specific status code.
 *
 * @param maxRetries
 *   value passed to the policy as `maxRetries = Some(maxRetries)`
 * @param status
 *   status whose code an error must carry to be considered retryable by this policy
 * @return
 *   a policy named `Policy for <code>` whose `canRetry` is true only for a
 *   `StatusRuntimeException` with the same status code as `status`
 */
def testPolicySpecificError(maxRetries: Int, status: Status): RetryPolicy = {
  val retryableCode = status.getCode
  RetryPolicy(
    name = s"Policy for $retryableCode",
    maxRetries = Some(maxRetries),
    canRetry = {
      // Only gRPC status errors with the exact code qualify; everything else is rejected.
      case grpcError: StatusRuntimeException => grpcError.getStatus.getCode == retryableCode
      case _ => false
    })
}

test("Test multiple policies") {
  // Two policies, each responsible for a different status code:
  // tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors.
  val unavailablePolicy = testPolicySpecificError(maxRetries = 2, status = Status.UNAVAILABLE)
  val internalPolicy = testPolicySpecificError(maxRetries = 4, status = Status.INTERNAL)

  // Failures to raise, one per attempt, in this exact order.
  val scriptedFailures =
    (List.fill(2)(Status.UNAVAILABLE) ++ List.fill(4)(Status.INTERNAL)).iterator

  val handler = new GrpcRetryHandler(List(unavailablePolicy, internalPolicy), sleep = _ => {})
  handler.retry {
    // Throw the next scripted failure; once the script is exhausted the call succeeds.
    scriptedFailures.nextOption().foreach(failure => throw failure.asRuntimeException())
  }

  // Every scripted failure must have been consumed before the call succeeded.
  assert(!scriptedFailures.hasNext)
}

test("Test multiple policies exceed") {
  // Both policies accept the same INTERNAL status, with different retry limits.
  val shorterPolicy = testPolicySpecificError(maxRetries = 2, status = Status.INTERNAL)
  val longerPolicy = testPolicySpecificError(maxRetries = 4, status = Status.INTERNAL)

  // More failures than the policies are expected to tolerate.
  val scriptedFailures = List.fill(10)(Status.INTERNAL).iterator
  var attempts = 0

  val handler = new GrpcRetryHandler(List(shorterPolicy, longerPolicy), sleep = _ => {})

  assertThrows[RetriesExceeded] {
    handler.retry {
      attempts += 1
      scriptedFailures.nextOption().foreach(failure => throw failure.asRuntimeException())
    }
  }

  // NOTE(review): 7 presumably equals the initial attempt plus 2 + 4 retries, i.e. the limits of
  // both matching policies are consumed before giving up - confirm against GrpcRetryHandler.
  assert(attempts == 7)
}

test("SPARK-45871: Client execute iterator.toSeq consumes the reattachable iterator") {
startDummyServer(0)
client = SparkConnectClient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.apache.spark.sql.test

import java.io.{File, IOException, OutputStream}
import java.lang.ProcessBuilder
import java.lang.ProcessBuilder.Redirect
import java.nio.file.Paths
import java.util.concurrent.TimeUnit
Expand All @@ -28,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkBuildInfo
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryPolicy
import org.apache.spark.sql.connect.client.RetryPolicy
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.test.IntegrationTestUtils._
Expand Down Expand Up @@ -189,7 +188,9 @@ object SparkConnectServerUtils {
.builder()
.userId("test")
.port(port)
.retryPolicy(RetryPolicy(maxRetries = 7, maxBackoff = FiniteDuration(10, "s")))
.retryPolicy(RetryPolicy
.defaultPolicy()
.copy(maxRetries = Some(7), maxBackoff = Some(FiniteDuration(10, "s"))))
.build())
.create()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ private[connect] class CustomSparkConnectBlockingStub(
request.getSessionId,
request.getUserContext,
request.getClientType,
// Don't use retryHandler - own retry handling is inside.
stubState.responseValidator.wrapIterator(
new ExecutePlanResponseReattachableIterator(request, channel, stubState.retryPolicy)))
// ExecutePlanResponseReattachableIterator does all retries by itself, don't wrap it here
new ExecutePlanResponseReattachableIterator(request, channel, stubState.retryHandler)))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryException

/**
* Retryable iterator of ExecutePlanResponses to an ExecutePlan call.
Expand All @@ -50,10 +51,15 @@ import org.apache.spark.internal.Logging
class ExecutePlanResponseReattachableIterator(
request: proto.ExecutePlanRequest,
channel: ManagedChannel,
retryPolicy: GrpcRetryHandler.RetryPolicy)
retryHandler: GrpcRetryHandler)
extends WrappedCloseableIterator[proto.ExecutePlanResponse]
with Logging {

/**
 * Runs the given function through the [[GrpcRetryHandler]] supplied to this iterator; whether
 * and when a failed call is retried is decided by that handler.
 */
private def retry[T](fn: => T): T = {
  retryHandler.retry { fn }
}

val operationId = if (request.hasOperationId) {
request.getOperationId
} else {
Expand Down Expand Up @@ -236,7 +242,7 @@ class ExecutePlanResponseReattachableIterator(
}
// Try a new ExecutePlan, and throw upstream for retry.
iter = Some(rawBlockingStub.executePlan(initialRequest))
val error = new GrpcRetryHandler.RetryException()
val error = new RetryException()
error.addSuppressed(ex)
throw error
case NonFatal(e) =>
Expand Down Expand Up @@ -319,12 +325,6 @@ class ExecutePlanResponseReattachableIterator(

release.build()
}

/**
* Retries the given function with exponential backoff according to the client's retryPolicy.
*/
private def retry[T](fn: => T): T =
GrpcRetryHandler.retry(retryPolicy)(fn)
}

private[connect] object ExecutePlanResponseReattachableIterator {
Expand Down
Loading

0 comments on commit 182e2d2

Please sign in to comment.