Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50243][SQL][Connect] Cached classloader for ArtifactManager #49007

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ class ArtifactManager(session: SparkSession) extends Logging {
*/
protected val sessionArtifactAdded = new AtomicBoolean(false)

// Per-thread cache of the session-isolated ClassLoader. It is invalidated with
// `remove()` whenever a jar/class artifact is added or the session is cleaned up.
// NOTE(review): since this is a ThreadLocal, adding an artifact on one thread only
// clears that thread's cached loader; other threads may keep serving a stale
// loader until they add an artifact themselves — confirm whether a single shared
// (e.g. volatile/AtomicReference) cache is intended instead.
// `new ThreadLocal` already returns null from its default `initialValue`, so the
// explicit override is unnecessary.
protected val cachedClassLoader = new ThreadLocal[ClassLoader]

Comment on lines +91 to +94
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand why this is a ThreadLocal? Artifacts can be added from multiple threads so wouldn't the addition on one thread cause all the others cached classloaders to be invalidated?

private def withClassLoaderIfNeeded[T](f: => T): T = {
val log = s" classloader for session ${session.sessionUUID} because " +
s"alwaysApplyClassLoader=$alwaysApplyClassLoader, " +
Expand Down Expand Up @@ -203,6 +207,7 @@ class ArtifactManager(session: SparkSession) extends Logging {
allowOverwrite = true,
deleteSource = deleteStagedFile)
sessionArtifactAdded.set(true)
cachedClassLoader.remove()
} else {
val target = ArtifactUtils.concatenatePaths(artifactPath, normalizedRemoteRelativePath)
// Disallow overwriting with modified version
Expand All @@ -227,6 +232,7 @@ class ArtifactManager(session: SparkSession) extends Logging {
(SparkContextResourceType.JAR, normalizedRemoteRelativePath, fragment))
jarsList.add(normalizedRemoteRelativePath)
sessionArtifactAdded.set(true)
cachedClassLoader.remove()
} else if (normalizedRemoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
session.sparkContext.addFile(uri)
sparkContextRelativePaths.add(
Expand Down Expand Up @@ -282,10 +288,20 @@ class ArtifactManager(session: SparkSession) extends Logging {
}
}

/**
 * Returns the (per-thread) cached session [[ClassLoader]], building and caching a
 * fresh one on first access after invalidation.
 */
def classloader: ClassLoader = {
  val cached = cachedClassLoader.get()
  if (cached == null) {
    val freshLoader = buildClassLoader
    cachedClassLoader.set(freshLoader)
    freshLoader
  } else {
    cached
  }
}

/**
* Returns a [[ClassLoader]] for session-specific jar/class file resources.
*/
def classloader: ClassLoader = {
private def buildClassLoader: ClassLoader = {
val urls = (getAddedJars :+ classDir.toUri.toURL).toArray
val prefixes = SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES)
val userClasspathFirst = SparkEnv.get.conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
Expand Down Expand Up @@ -395,6 +411,9 @@ class ArtifactManager(session: SparkSession) extends Logging {
pythonIncludeList.clear()
cachedBlockIdList.clear()
sparkContextRelativePaths.clear()

// Remove the cached classloader
cachedClassLoader.remove()
}

def uploadArtifactToFs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,18 @@ class DatasetOptimizationSuite extends QueryTest with SharedSparkSession {
assert(count3 == count2)
}

withClue("array type") {
checkCodegenCache(() => Seq(Seq("abc")).toDS())
}
withSQLConf(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER.key -> "true") {
withClue("array type") {
checkCodegenCache(() => Seq(Seq("abc")).toDS())
}

withClue("map type") {
checkCodegenCache(() => Seq(Map("abc" -> 1)).toDS())
}
withClue("map type") {
checkCodegenCache(() => Seq(Map("abc" -> 1)).toDS())
}

withClue("array of map") {
checkCodegenCache(() => Seq(Seq(Map("abc" -> 1))).toDS())
withClue("array of map") {
checkCodegenCache(() => Seq(Seq(Map("abc" -> 1))).toDS())
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.nio.file.{Files, Path, Paths}
import org.apache.commons.io.FileUtils

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -447,4 +448,58 @@ class ArtifactManagerSuite extends SharedSparkSession {
assert(msg == "Hello Talon! Nice to meet you!")
}
}

test("Codegen cache should be invalid when artifacts are added - class artifact") {
  withTempDir { dir =>
    runCodegenTest("class artifact") {
      // Writing any .class artifact should invalidate the cached classloader.
      val classFilePath = dir.toPath.resolve("random.class")
      Files.write(classFilePath, "test".getBytes(StandardCharsets.UTF_8))
      spark.addArtifact(classFilePath.toString)
    }
  }
}

test("Codegen cache should be invalid when artifacts are added - JAR artifact") {
  withTempDir { dir =>
    runCodegenTest("JAR artifact") {
      // Writing any .jar artifact should invalidate the cached classloader.
      val jarFilePath = dir.toPath.resolve("random.jar")
      Files.write(jarFilePath, "test".getBytes(StandardCharsets.UTF_8))
      spark.addArtifact(jarFilePath.toString)
    }
  }
}

private def getCodegenCount: Long = CodegenMetrics.METRIC_COMPILATION_TIME.getCount

/**
 * Runs a codegen-cache invalidation scenario: trigger codegen once, add a single
 * artifact via `addOneArtifact`, then verify that compilation happens exactly one
 * more time (cache invalidated) and not on a further identical query (cache valid).
 */
private def runCodegenTest(msg: String)(addOneArtifact: => Unit): Unit = {
  withSQLConf(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER.key -> "true") {
    val sessionForImplicits = spark
    import sessionForImplicits.implicits._

    // Runs a Dataset query whose serializer requires codegen.
    def triggerCodegen(): Unit = Seq(Seq("abc")).toDS().collect()

    val baseline = getCodegenCount
    triggerCodegen()
    val afterFirstRun = getCodegenCount
    // The very first execution must compile generated code.
    assert(afterFirstRun > baseline, s"$msg: codegen should happen at the first time")

    // Adding an artifact should invalidate the codegen cache.
    addOneArtifact

    triggerCodegen()
    val afterArtifact = getCodegenCount
    // Same-shaped query recompiles because the classloader changed.
    assert(afterArtifact > afterFirstRun,
      s"$msg: codegen should happen again after adding artifact")

    triggerCodegen()
    val afterSecondRun = getCodegenCount
    // No artifact change since the last run, so the cache must be reused.
    assert(afterSecondRun == afterArtifact,
      s"$msg: codegen should not happen again as classloader is not changed")
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@ class AdaptiveQueryExecSuite
// so retry several times here to avoid unit test failure.
eventually(timeout(15.seconds), interval(500.milliseconds)) {
withSQLConf(
SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER.key -> "true",
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") {
// `testData` is small enough to be broadcast but has empty partition ratio over the config.
Expand Down