diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index afc0a2d7df604..a7f85db12b214 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql
 import java.net.URI
 import java.nio.file.Paths
 import java.util.{ServiceLoader, UUID}
-import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicBoolean
 
+import scala.collection.mutable
 import scala.concurrent.duration.DurationInt
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
@@ -133,14 +133,34 @@ class SparkSession private(
   /** Tag to mark all jobs owned by this session. */
   private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"
 
+  /**
+   * A UUID that is unique per thread. Used by managedJobTags to make sure that the same tag
+   * added from two different threads does not collide in the underlying
+   * SparkContext/SQLExecution.
+   */
+  private[sql] lazy val threadUuid = new InheritableThreadLocal[String] {
+    override def childValue(parent: String): String = parent
+
+    override def initialValue(): String = UUID.randomUUID().toString
+  }
+
   /**
    * A map to hold the mapping from user-defined tags to the real tags attached to Jobs.
-   * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`.
+   * Real tags have the current session and thread IDs attached:
+   * `"tag1" -> s"spark-session-$sessionUUID-thread-$threadUuid-tag1"`.
    */
   @transient
-  private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = {
-    new ConcurrentHashMap(parentManagedJobTags.asJava)
-  }
+  private[sql] lazy val managedJobTags = new InheritableThreadLocal[mutable.Map[String, String]] {
+    override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] = {
+      // Note: make a clone such that changes in the parent tags aren't reflected in
+      // those of the child threads.
+      parent.clone()
+    }
+
+    override def initialValue(): mutable.Map[String, String] = {
+      mutable.Map(parentManagedJobTags.toSeq: _*)
+    }
+  }
 
   /** @inheritdoc */
   def version: String = SPARK_VERSION
@@ -243,10 +263,10 @@ class SparkSession private(
       Some(sessionState),
       extensions,
       Map.empty,
-      managedJobTags.asScala.toMap)
+      managedJobTags.get().toMap)
     result.sessionState // force copy of SessionState
     result.sessionState.artifactManager // force copy of ArtifactManager and its resources
-    result.managedJobTags // force copy of userDefinedToRealTagsMap
+    result.managedJobTags // force copy of managedJobTags
     result
   }
 
@@ -550,17 +570,17 @@ class SparkSession private(
   /** @inheritdoc */
   override def addTag(tag: String): Unit = {
     SparkContext.throwIfInvalidTag(tag)
-    managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag")
+    managedJobTags.get().put(tag, s"spark-session-$sessionUUID-thread-${threadUuid.get()}-$tag")
   }
 
   /** @inheritdoc */
-  override def removeTag(tag: String): Unit = managedJobTags.remove(tag)
+  override def removeTag(tag: String): Unit = managedJobTags.get().remove(tag)
 
   /** @inheritdoc */
-  override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet
+  override def getTags(): Set[String] = managedJobTags.get().keySet.toSet
 
   /** @inheritdoc */
-  override def clearTags(): Unit = managedJobTags.clear()
+  override def clearTags(): Unit = managedJobTags.get().clear()
 
   /**
    * Request to interrupt all currently running SQL operations of this session.
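
Aside on the mechanism above: InheritableThreadLocal gives every thread its own tag map, and a thread constructed from it starts with a clone of the parent's map rather than a shared reference. Below is a minimal, self-contained sketch of that cloning behaviour; it is illustrative only (names are invented, plain JDK threads, not part of the patch):

import scala.collection.mutable

object ThreadLocalCloneSketch {
  private val tags = new InheritableThreadLocal[mutable.Map[String, String]] {
    // childValue runs in the parent thread while the child Thread is being
    // constructed, so the child starts from a snapshot, not a shared reference.
    override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] =
      parent.clone()

    override def initialValue(): mutable.Map[String, String] = mutable.Map.empty
  }

  def main(args: Array[String]): Unit = {
    tags.get().put("one", "real-tag-one")
    val child = new Thread {
      override def run(): Unit = {
        assert(tags.get()("one") == "real-tag-one") // inherited snapshot from the parent
        tags.get().put("two", "real-tag-two") // visible only to this child thread
      }
    }
    child.start()
    child.join()
    assert(!tags.get().contains("two")) // the child's mutation did not leak back
  }
}
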
@@ -589,9 +609,8 @@ class SparkSession private(
    * @since 4.0.0
    */
  override def interruptTag(tag: String): Seq[String] = {
-    val realTag = managedJobTags.get(tag)
-    if (realTag == null) return Seq.empty
-    doInterruptTag(realTag, s"part of cancelled job tags $tag")
+    val realTag = managedJobTags.get().get(tag)
+    realTag.map(doInterruptTag(_, s"part of cancelled job tags $tag")).getOrElse(Seq.empty)
   }
 
   private def doInterruptTag(tag: String, reason: String): Seq[String] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index e805aabe013cf..242149010ceef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -261,7 +261,7 @@ object SQLExecution extends Logging {
   }
 
   private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = {
-    val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag
+    val allTags = sparkSession.managedJobTags.get().values.toSet + sparkSession.sessionJobTag
     sparkSession.sparkContext.addJobTags(allTags)
 
     try {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
index 1ac51b408301a..89500fe51f3ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, Executors, Semaphore, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.concurrent.{ExecutionContext, Future}
@@ -100,13 +100,14 @@ class SparkSessionJobTaggingAndCancellationSuite
     assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
 
     val activeJobsFuture =
-      session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason")
+      session.sparkContext.cancelJobsWithTagWithFuture(
+        session.managedJobTags.get()("one"), "reason")
     val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head
     val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS)
       .split(SparkContext.SPARK_JOB_TAGS_SEP)
     assert(actualTags.toSet == Set(
       session.sessionJobTag,
-      s"${session.sessionJobTag}-one",
+      s"${session.sessionJobTag}-thread-${session.threadUuid.get()}-one",
       SQLExecution.executionIdJobTag(
         session,
         activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)))
@@ -118,12 +119,12 @@ class SparkSessionJobTaggingAndCancellationSuite
     val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate()
     var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) =
       (null, null, null)
+    var (threadUuidA, threadUuidB, threadUuidC): (String, String, String) = (null, null, null)
 
     // global ExecutionContext has only 2 threads in Apache Spark CI
     // create own thread pool for four Futures used in this test
-    val numThreads = 3
-    val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
-    val executionContext = ExecutionContext.fromExecutorService(fpool)
+    val threadPool = Executors.newFixedThreadPool(3)
+    implicit val ec: ExecutionContext =
+      ExecutionContext.fromExecutor(threadPool)
 
     try {
       // Add a listener to release the semaphore once jobs are launched.
@@ -143,28 +144,50 @@ class SparkSessionJobTaggingAndCancellationSuite
         }
       })
 
+      var realTagOneForSessionA: String = null
+      var childThread: Thread = null
+      val childThreadLock = new Semaphore(0)
+
       // Note: since tags are added in the Future threads, they don't need to be cleared in between.
       val jobA = Future {
         sessionA = globalSession.cloneSession()
         import globalSession.implicits._
+        threadUuidA = sessionA.threadUuid.get()
         assert(sessionA.getTags() == Set())
         sessionA.addTag("two")
         assert(sessionA.getTags() == Set("two"))
         sessionA.clearTags() // check that clearing all tags works
         assert(sessionA.getTags() == Set())
         sessionA.addTag("one")
+        realTagOneForSessionA = sessionA.managedJobTags.get()("one")
+        assert(realTagOneForSessionA ==
+          s"${sessionA.sessionJobTag}-thread-${sessionA.threadUuid.get()}-one")
         assert(sessionA.getTags() == Set("one"))
+
+        // Create a child thread which inherits thread-local variables and tries to interrupt
+        // the job started from the parent thread. The child thread is blocked until the main
+        // thread releases the lock.
+        childThread = new Thread {
+          override def run(): Unit = {
+            assert(childThreadLock.tryAcquire(1, 20, TimeUnit.SECONDS))
+            assert(sessionA.getTags() == Set("one"))
+            assert(sessionA.interruptTag("one").size == 1)
+          }
+        }
+        childThread.start()
         try {
           sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
         } finally {
+          childThread.interrupt()
           sessionA.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
 
       val jobB = Future {
         sessionB = globalSession.cloneSession()
         import globalSession.implicits._
+        threadUuidB = sessionB.threadUuid.get()
         assert(sessionB.getTags() == Set())
         sessionB.addTag("one")
         sessionB.addTag("two")
@@ -176,11 +199,12 @@ class SparkSessionJobTaggingAndCancellationSuite
         } finally {
           sessionB.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
 
       val jobC = Future {
         sessionC = globalSession.cloneSession()
         import globalSession.implicits._
+        threadUuidC = sessionC.threadUuid.get()
         sessionC.addTag("foo")
         sessionC.removeTag("foo")
         assert(sessionC.getTags() == Set()) // check that remove works removing the last tag
@@ -190,12 +214,13 @@ class SparkSessionJobTaggingAndCancellationSuite
         } finally {
           sessionC.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
 
       // Block until four jobs have started.
       assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES))
 
       // Tags are applied
+      def realUserTag(s: String, t: String, ta: String): String = s"spark-session-$s-thread-$t-$ta"
       assert(jobProperties.size == 3)
       for (ss <- Seq(sessionA, sessionB, sessionC)) {
         val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS)
@@ -207,15 +232,17 @@ class SparkSessionJobTaggingAndCancellationSuite
         val executionRootIdTag = SQLExecution.executionIdJobTag(
           ss,
           jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)
-        val userTagsPrefix = s"spark-session-${ss.sessionUUID}-"
 
         ss match {
           case s if s == sessionA => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one"))
+            s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidA, "one")))
           case s if s == sessionB => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two"))
+            s.sessionJobTag,
+            executionRootIdTag,
+            realUserTag(s.sessionUUID, threadUuidB, "one"),
+            realUserTag(s.sessionUUID, threadUuidB, "two")))
           case s if s == sessionC => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo"))
+            s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidC, "boo")))
         }
       }
 
@@ -239,8 +266,10 @@ class SparkSessionJobTaggingAndCancellationSuite
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 1)
 
-      // Another job cancelled
-      assert(sessionA.interruptTag("one").size == 1)
+      // Another job cancelled. The next call cancels nothing because it runs in a different
+      // thread than the one that added the tag; the real cancellation is done by the child
+      // thread, which is unblocked here.
+      assert(sessionA.interruptTag("one").isEmpty)
+      childThreadLock.release()
       val eA = intercept[SparkException] {
         ThreadUtils.awaitResult(jobA, 1.minute)
       }.getCause
@@ -257,7 +286,48 @@ class SparkSessionJobTaggingAndCancellationSuite
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 3)
     } finally {
-      fpool.shutdownNow()
+      threadPool.shutdownNow()
+    }
+  }
+
+  test("Tags are isolated in multithreaded environment") {
+    // Custom thread pool for multi-threaded testing
+    val threadPool = Executors.newFixedThreadPool(2)
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)
+
+    val session = SparkSession.builder().master("local").getOrCreate()
+    @volatile var output1: Set[String] = null
+    @volatile var output2: Set[String] = null
+
+    def tag1(): Unit = {
+      session.addTag("tag1")
+      output1 = session.getTags()
+    }
+
+    def tag2(): Unit = {
+      session.addTag("tag2")
+      output2 = session.getTags()
+    }
+
+    try {
+      // Run tasks in separate threads
+      val future1 = Future {
+        tag1()
+      }
+      val future2 = Future {
+        tag2()
+      }
+
+      // Wait for threads to complete
+      ThreadUtils.awaitResult(Future.sequence(Seq(future1, future2)), 1.minute)
+
+      // Assert outputs
+      assert(output1 != null)
+      assert(output1 == Set("tag1"))
+      assert(output2 != null)
+      assert(output2 == Set("tag2"))
+    } finally {
+      threadPool.shutdownNow()
     }
   }
 }
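
Taken together, the patch makes tags thread-scoped: the same user-facing tag added to one SparkSession from two different threads resolves to two distinct real tags, so getTags and interruptTag act on the calling thread's view (plus any child threads that inherited it). Below is a hedged end-to-end sketch of that contract, assuming a build with this patch applied; pool size and tag names are illustrative. A fixed pool of two is enough to guarantee distinct threads, because a ThreadPoolExecutor spawns a new thread per submitted task until its core size is reached:

import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

import org.apache.spark.sql.SparkSession

object TagIsolationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()
    val pool = Executors.newFixedThreadPool(2)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)
    try {
      // Same session, two threads: each thread gets its own managed tag map.
      val f1 = Future { spark.addTag("etl"); spark.getTags() }
      val f2 = Future { spark.addTag("adhoc"); spark.getTags() }

      // Neither thread observes the tag added by the other.
      assert(Await.result(f1, 1.minute) == Set("etl"))
      assert(Await.result(f2, 1.minute) == Set("adhoc"))
      // interruptTag("etl") would likewise only affect jobs tagged from the
      // thread (or inheriting child threads) that called addTag("etl").
    } finally {
      pool.shutdown()
      spark.stop()
    }
  }
}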