[SPARK-50366][SQL] Isolate user-defined tags on thread level for SparkSession in Classic

### What changes were proposed in this pull request?

This PR changes the implementation of user-provided tags to be thread-local, so that tags added by two threads to the same SparkSession do not interfere with each other.

Overlaps (from the `SparkContext` perspective) are avoided by introducing a thread-local random UUID which is attached to all tags in the same thread.
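The resulting naming scheme looks roughly like this (a minimal sketch distilled from the diff below; the `TagScheme` wrapper is illustrative, not the actual API):

```scala
import java.util.UUID

// Each thread (and any child threads it spawns) shares one random UUID,
// which is baked into the "real" tag registered with the SparkContext.
class TagScheme(sessionUUID: String) {
  private val threadUuid = new InheritableThreadLocal[String] {
    // Child threads inherit the parent's UUID, so a child can still
    // address jobs tagged by its parent.
    override def childValue(parent: String): String = parent
    override def initialValue(): String = UUID.randomUUID().toString
  }

  // The same user-supplied tag yields different real tags on different threads.
  def realTag(tag: String): String =
    s"spark-session-$sessionUUID-thread-${threadUuid.get()}-$tag"
}
```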

### Why are the changes needed?

To make tags isolated per thread.

### Does this PR introduce _any_ user-facing change?

Yes. User-provided tags are now isolated per thread within a session: tags added in one thread are not visible to other threads of the same SparkSession, except child threads, which inherit their parent's tags.
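
For example (a sketch modeled on the new test added in this PR; `spark` is assumed to be a Classic `SparkSession`):

```scala
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}

val pool = Executors.newFixedThreadPool(2)
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)

// Each thread sees only the tags it added itself.
val f1 = Future { spark.addTag("tag1"); spark.getTags() } // -> Set("tag1")
val f2 = Future { spark.addTag("tag2"); spark.getTags() } // -> Set("tag2")
```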

### How was this patch tested?

Local tests, plus a new test case ("Tags are isolated in multithreaded environment") added to `SparkSessionJobTaggingAndCancellationSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48906 from xupefei/thread-isolated-tags.

Authored-by: Paddy Xu <xupaddy@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
xupefei authored and HyukjinKwon committed Nov 22, 2024
1 parent d8a6075 commit 6881ec0
Showing 3 changed files with 120 additions and 31 deletions.
47 changes: 33 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql
 import java.net.URI
 import java.nio.file.Paths
 import java.util.{ServiceLoader, UUID}
-import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicBoolean
 
+import scala.collection.mutable
 import scala.concurrent.duration.DurationInt
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
@@ -133,14 +133,34 @@ class SparkSession private(
   /** Tag to mark all jobs owned by this session. */
   private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"
 
+  /**
+   * A UUID that is unique on the thread level. Used by managedJobTags to make sure that the same
+   * tag from two threads does not overlap in the underlying SparkContext/SQLExecution.
+   */
+  private[sql] lazy val threadUuid = new InheritableThreadLocal[String] {
+    override def childValue(parent: String): String = parent
+
+    override def initialValue(): String = UUID.randomUUID().toString
+  }
+
   /**
    * A map to hold the mapping from user-defined tags to the real tags attached to Jobs.
-   * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`.
+   * Real tags have the current session ID and thread UUID attached:
+   * `tag1 -> spark-session-$sessionUUID-thread-$threadUuid-tag1`
+   *
    */
   @transient
-  private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = {
-    new ConcurrentHashMap(parentManagedJobTags.asJava)
-  }
+  private[sql] lazy val managedJobTags = new InheritableThreadLocal[mutable.Map[String, String]] {
+    override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] = {
+      // Note: make a clone such that changes in the parent tags aren't reflected in
+      // those of the children threads.
+      parent.clone()
+    }
+
+    override def initialValue(): mutable.Map[String, String] = {
+      mutable.Map(parentManagedJobTags.toSeq: _*)
+    }
+  }
 
   /** @inheritdoc */
   def version: String = SPARK_VERSION
@@ -243,10 +263,10 @@
       Some(sessionState),
       extensions,
       Map.empty,
-      managedJobTags.asScala.toMap)
+      managedJobTags.get().toMap)
     result.sessionState // force copy of SessionState
     result.sessionState.artifactManager // force copy of ArtifactManager and its resources
-    result.managedJobTags // force copy of userDefinedToRealTagsMap
+    result.managedJobTags // force copy of managedJobTags
     result
   }
 
@@ -550,17 +570,17 @@
   /** @inheritdoc */
   override def addTag(tag: String): Unit = {
     SparkContext.throwIfInvalidTag(tag)
-    managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag")
+    managedJobTags.get().put(tag, s"spark-session-$sessionUUID-thread-${threadUuid.get()}-$tag")
   }
 
   /** @inheritdoc */
-  override def removeTag(tag: String): Unit = managedJobTags.remove(tag)
+  override def removeTag(tag: String): Unit = managedJobTags.get().remove(tag)
 
   /** @inheritdoc */
-  override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet
+  override def getTags(): Set[String] = managedJobTags.get().keySet.toSet
 
   /** @inheritdoc */
-  override def clearTags(): Unit = managedJobTags.clear()
+  override def clearTags(): Unit = managedJobTags.get().clear()
 
   /**
    * Request to interrupt all currently running SQL operations of this session.
@@ -589,9 +609,8 @@
    * @since 4.0.0
    */
   override def interruptTag(tag: String): Seq[String] = {
-    val realTag = managedJobTags.get(tag)
-    if (realTag == null) return Seq.empty
-    doInterruptTag(realTag, s"part of cancelled job tags $tag")
+    val realTag = managedJobTags.get().get(tag)
+    realTag.map(doInterruptTag(_, s"part of cancelled job tags $tag")).getOrElse(Seq.empty)
   }
 
   private def doInterruptTag(tag: String, reason: String): Seq[String] = {
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -261,7 +261,7 @@ object SQLExecution extends Logging {
   }
 
   private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = {
-    val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag
+    val allTags = sparkSession.managedJobTags.get().values.toSet + sparkSession.sessionJobTag
     sparkSession.sparkContext.addJobTags(allTags)
 
     try {
102 changes: 86 additions & 16 deletions sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, Executors, Semaphore, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.concurrent.{ExecutionContext, Future}
@@ -100,13 +100,14 @@ class SparkSessionJobTaggingAndCancellationSuite
 
     assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
     val activeJobsFuture =
-      session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason")
+      session.sparkContext.cancelJobsWithTagWithFuture(
+        session.managedJobTags.get()("one"), "reason")
     val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head
     val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS)
       .split(SparkContext.SPARK_JOB_TAGS_SEP)
     assert(actualTags.toSet == Set(
       session.sessionJobTag,
-      s"${session.sessionJobTag}-one",
+      s"${session.sessionJobTag}-thread-${session.threadUuid.get()}-one",
       SQLExecution.executionIdJobTag(
         session,
         activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)))
@@ -118,12 +119,12 @@
     val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate()
     var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) =
       (null, null, null)
+    var (threadUuidA, threadUuidB, threadUuidC): (String, String, String) = (null, null, null)
 
     // global ExecutionContext has only 2 threads in Apache Spark CI
     // create own thread pool for the three Futures used in this test
-    val numThreads = 3
-    val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
-    val executionContext = ExecutionContext.fromExecutorService(fpool)
+    val threadPool = Executors.newFixedThreadPool(3)
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)
 
     try {
       // Add a listener to release the semaphore once jobs are launched.
@@ -143,28 +144,50 @@
         }
       })
 
+      var realTagOneForSessionA: String = null
+      var childThread: Thread = null
+      val childThreadLock = new Semaphore(0)
+
+      // Note: since tags are added in the Future threads, they don't need to be cleared in between.
       val jobA = Future {
         sessionA = globalSession.cloneSession()
         import globalSession.implicits._
 
+        threadUuidA = sessionA.threadUuid.get()
         assert(sessionA.getTags() == Set())
         sessionA.addTag("two")
        assert(sessionA.getTags() == Set("two"))
         sessionA.clearTags() // check that clearing all tags works
         assert(sessionA.getTags() == Set())
         sessionA.addTag("one")
+        realTagOneForSessionA = sessionA.managedJobTags.get()("one")
+        assert(realTagOneForSessionA ==
+          s"${sessionA.sessionJobTag}-thread-${sessionA.threadUuid.get()}-one")
         assert(sessionA.getTags() == Set("one"))
+
+        // Create a child thread which inherits thread-local variables and tries to interrupt
+        // the job started from the parent thread. The child thread is blocked until the main
+        // thread releases the lock.
+        childThread = new Thread {
+          override def run(): Unit = {
+            assert(childThreadLock.tryAcquire(1, 20, TimeUnit.SECONDS))
+            assert(sessionA.getTags() == Set("one"))
+            assert(sessionA.interruptTag("one").size == 1)
+          }
+        }
+        childThread.start()
         try {
           sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
         } finally {
+          childThread.interrupt()
           sessionA.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
       val jobB = Future {
         sessionB = globalSession.cloneSession()
         import globalSession.implicits._
 
+        threadUuidB = sessionB.threadUuid.get()
         assert(sessionB.getTags() == Set())
         sessionB.addTag("one")
         sessionB.addTag("two")
@@ -176,11 +199,12 @@
         } finally {
           sessionB.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
       val jobC = Future {
         sessionC = globalSession.cloneSession()
         import globalSession.implicits._
 
+        threadUuidC = sessionC.threadUuid.get()
         sessionC.addTag("foo")
         sessionC.removeTag("foo")
         assert(sessionC.getTags() == Set()) // check that remove works removing the last tag
@@ -190,12 +214,13 @@
         } finally {
           sessionC.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
 
       // Block until all three jobs have started.
       assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES))
 
       // Tags are applied
+      def realUserTag(s: String, t: String, ta: String): String = s"spark-session-$s-thread-$t-$ta"
       assert(jobProperties.size == 3)
       for (ss <- Seq(sessionA, sessionB, sessionC)) {
         val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS)
@@ -207,15 +232,17 @@
         val executionRootIdTag = SQLExecution.executionIdJobTag(
           ss,
           jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)
-        val userTagsPrefix = s"spark-session-${ss.sessionUUID}-"
 
         ss match {
           case s if s == sessionA => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one"))
+            s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidA, "one")))
           case s if s == sessionB => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two"))
+            s.sessionJobTag,
+            executionRootIdTag,
+            realUserTag(s.sessionUUID, threadUuidB, "one"),
+            realUserTag(s.sessionUUID, threadUuidB, "two")))
           case s if s == sessionC => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo"))
+            s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidC, "boo")))
         }
       }
 
@@ -239,8 +266,10 @@
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 1)
 
-      // Another job cancelled
-      assert(sessionA.interruptTag("one").size == 1)
+      // Another job cancelled. The next line cancels nothing because we're now in another thread.
+      // The real cancel is done through unblocking the child thread, which is waiting for a lock.
+      assert(sessionA.interruptTag("one").isEmpty)
+      childThreadLock.release()
       val eA = intercept[SparkException] {
         ThreadUtils.awaitResult(jobA, 1.minute)
       }.getCause
@@ -257,7 +286,48 @@
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 3)
     } finally {
-      fpool.shutdownNow()
+      threadPool.shutdownNow()
     }
   }
+
+  test("Tags are isolated in multithreaded environment") {
+    // Custom thread pool for multi-threaded testing
+    val threadPool = Executors.newFixedThreadPool(2)
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)
+
+    val session = SparkSession.builder().master("local").getOrCreate()
+    @volatile var output1: Set[String] = null
+    @volatile var output2: Set[String] = null
+
+    def tag1(): Unit = {
+      session.addTag("tag1")
+      output1 = session.getTags()
+    }
+
+    def tag2(): Unit = {
+      session.addTag("tag2")
+      output2 = session.getTags()
+    }
+
+    try {
+      // Run tasks in separate threads
+      val future1 = Future {
+        tag1()
+      }
+      val future2 = Future {
+        tag2()
+      }
+
+      // Wait for threads to complete
+      ThreadUtils.awaitResult(Future.sequence(Seq(future1, future2)), 1.minute)
+
+      // Assert outputs
+      assert(output1 != null)
+      assert(output1 == Set("tag1"))
+      assert(output2 != null)
+      assert(output2 == Set("tag2"))
+    } finally {
+      threadPool.shutdownNow()
+    }
+  }
 }
