Skip to content

Commit

Permalink
[SPARK-50224][SQL] The replacements of IsValidUTF8|ValidateUTF8|TryVa…
Browse files Browse the repository at this point in the history
…lidateUTF8|MakeValidUTF8 shall be NullIntolerant

### What changes were proposed in this pull request?

This PR makes the replacement expressions of the IsValidUTF8|ValidateUTF8|TryValidateUTF8|MakeValidUTF8 functions NullIntolerant, inheriting that property from the original expressions, so that we can actually construct IsNotNull constraints for them.

This is also a common issue for other RuntimeReplaceable expressions; I will revisit them in groups under SPARK-50223.

### Why are the changes needed?

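Making these replacements NullIntolerant lets Spark construct IsNotNull constraints for them, which is a common strategy for performance improvement.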

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?

new tests

### Was this patch authored or co-authored using generative AI tooling?

no

Closes #48758 from yaooqinn/SPARK-50224.

Authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Kent Yao <yao@apache.org>
  • Loading branch information
yaooqinn committed Nov 5, 2024
1 parent 47063a6 commit 642a62b
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,28 @@ case class StaticInvoke(
}.$functionName(${arguments.mkString(", ")}))"
}

object StaticInvoke {
  /**
   * Builds a [[StaticInvoke]] that additionally mixes in the [[NullIntolerant]] trait.
   *
   * Every argument is forwarded, unchanged and in the same order, to the `StaticInvoke`
   * constructor; the only difference from constructing a `StaticInvoke` directly is that the
   * returned instance also extends `NullIntolerant`.
   *
   * NOTE(review): the trait is attached through an anonymous subclass, so an instance that is
   * re-created via the case class `copy` is presumably a plain `StaticInvoke` again -- confirm
   * that the property survives wherever it is relied upon.
   *
   * @param staticObject    forwarded as the 1st constructor argument
   * @param dataType        forwarded as the 2nd constructor argument
   * @param functionName    forwarded as the 3rd constructor argument
   * @param arguments       forwarded as the 4th constructor argument; defaults to `Nil`
   * @param inputTypes      forwarded as the 5th constructor argument; defaults to `Nil`
   * @param propagateNull   forwarded as the 6th constructor argument; defaults to `true`
   * @param returnNullable  forwarded as the 7th constructor argument; defaults to `true`
   * @param isDeterministic forwarded as the 8th constructor argument; defaults to `true`
   * @param scalarFunction  forwarded as the 9th constructor argument; defaults to `None`
   * @return a `StaticInvoke` instance that is also a `NullIntolerant`
   */
  def withNullIntolerant(
      staticObject: Class[_],
      dataType: DataType,
      functionName: String,
      arguments: Seq[Expression] = Nil,
      inputTypes: Seq[AbstractDataType] = Nil,
      propagateNull: Boolean = true,
      returnNullable: Boolean = true,
      isDeterministic: Boolean = true,
      scalarFunction: Option[ScalarFunction[_]] = None): StaticInvoke = {
    // Anonymous subclass: identical constructor arguments, plus the NullIntolerant mixin.
    new StaticInvoke(
      staticObject, dataType, functionName, arguments, inputTypes,
      propagateNull, returnNullable, isDeterministic, scalarFunction) with NullIntolerant
  }
}

/**
* Calls the specified function on an object, optionally passing arguments. If the `targetObject`
* expression evaluates to null then null will be returned.
Expand Down Expand Up @@ -555,6 +577,27 @@ case class Invoke(
copy(targetObject = newChildren.head, arguments = newChildren.tail)
}

object Invoke {
  /**
   * Builds an [[Invoke]] that additionally mixes in the [[NullIntolerant]] trait.
   *
   * Every argument is forwarded, unchanged and in the same order, to the `Invoke` constructor;
   * the only difference from constructing an `Invoke` directly is that the returned instance
   * also extends `NullIntolerant`.
   *
   * NOTE(review): the trait is attached through an anonymous subclass, so an instance that is
   * re-created via the case class `copy` is presumably a plain `Invoke` again -- confirm that
   * the property survives wherever it is relied upon.
   *
   * @param targetObject     forwarded as the 1st constructor argument
   * @param functionName     forwarded as the 2nd constructor argument
   * @param dataType         forwarded as the 3rd constructor argument
   * @param arguments        forwarded as the 4th constructor argument; defaults to `Nil`
   * @param methodInputTypes forwarded as the 5th constructor argument; defaults to `Nil`
   * @param propagateNull    forwarded as the 6th constructor argument; defaults to `true`
   * @param returnNullable   forwarded as the 7th constructor argument; defaults to `true`
   * @param isDeterministic  forwarded as the 8th constructor argument; defaults to `true`
   * @return an `Invoke` instance that is also a `NullIntolerant`
   */
  def withNullIntolerant(
      targetObject: Expression,
      functionName: String,
      dataType: DataType,
      arguments: Seq[Expression] = Nil,
      methodInputTypes: Seq[AbstractDataType] = Nil,
      propagateNull: Boolean = true,
      returnNullable: Boolean = true,
      isDeterministic: Boolean = true): Invoke = {
    // Anonymous subclass: identical constructor arguments, plus the NullIntolerant mixin.
    new Invoke(
      targetObject, functionName, dataType, arguments, methodInputTypes,
      propagateNull, returnNullable, isDeterministic) with NullIntolerant
  }
}

object NewInstance {
def apply(
cls: Class[_],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,8 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {

override lazy val replacement: Expression = Invoke(input, "isValid", BooleanType)
override lazy val replacement: Expression =
Invoke.withNullIntolerant(input, "isValid", BooleanType)

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
Expand Down Expand Up @@ -795,7 +796,8 @@ case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with Implic
case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {

override lazy val replacement: Expression = Invoke(input, "makeValid", input.dataType)
override lazy val replacement: Expression =
Invoke.withNullIntolerant(input, "makeValid", input.dataType)

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
Expand Down Expand Up @@ -836,12 +838,13 @@ case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with Impl
case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {

override lazy val replacement: Expression = StaticInvoke(
classOf[ExpressionImplUtils],
input.dataType,
"validateUTF8String",
Seq(input),
inputTypes)
override lazy val replacement: Expression =
StaticInvoke.withNullIntolerant(
classOf[ExpressionImplUtils],
input.dataType,
"validateUTF8String",
Seq(input),
inputTypes)

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
Expand Down Expand Up @@ -886,12 +889,13 @@ case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with Impli
case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {

override lazy val replacement: Expression = StaticInvoke(
classOf[ExpressionImplUtils],
input.dataType,
"tryValidateUTF8String",
Seq(input),
inputTypes)
override lazy val replacement: Expression =
StaticInvoke.withNullIntolerant(
classOf[ExpressionImplUtils],
input.dataType,
"tryValidateUTF8String",
Seq(input),
inputTypes)

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ package org.apache.spark.sql

import org.apache.spark.{SPARK_DOC_ROOT, SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.execution.{FormattedMode, WholeStageCodegenExec}
import org.apache.spark.sql.catalyst.expressions.IsNotNull
import org.apache.spark.sql.execution.{FilterExec, FormattedMode, WholeStageCodegenExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -1424,4 +1425,31 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
}
}
}

test("SPARK-50224: The replacement of validate utf8 functions should be NullIntolerant") {
  // Asserts that the physical plan of `df` contains a filter whose condition includes an
  // IsNotNull expression, and that `df` produces exactly the `expected` rows.
  def check(df: DataFrame, expected: Seq[Row]): Unit = {
    val plan = df.queryExecution.sparkPlan
    // Pattern match instead of isInstanceOf/asInstanceOf + Option.get, so that a plan without
    // a filter fails with a readable message rather than a bare NoSuchElementException.
    val filter = plan
      .collectFirst { case f: FilterExec => f }
      .getOrElse(fail(s"No FilterExec found in physical plan:\n$plan"))
    assert(
      filter.condition.collectFirst { case e: IsNotNull => e }.isDefined,
      s"Expected an IsNotNull expression in filter condition: ${filter.condition}")
    checkAnswer(df, expected)
  }
  withTable("test_table") {
    // The last row has a NULL in column `b`; none of the expected results below contain it.
    sql("CREATE TABLE test_table" +
      " AS SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl'), ('mno', NULL) T(a, b)")
    check(
      sql("SELECT * FROM test_table WHERE is_valid_utf8(b)"),
      Seq(Row("abc", "def"), Row("ghi", "jkl")))
    check(
      sql("SELECT * FROM test_table WHERE make_valid_utf8(b) = 'def'"),
      Seq(Row("abc", "def")))
    check(
      sql("SELECT * FROM test_table WHERE validate_utf8(b) = 'jkl'"),
      Seq(Row("ghi", "jkl")))
    check(
      sql("SELECT * FROM test_table WHERE try_validate_utf8(b) = 'def'"),
      Seq(Row("abc", "def")))
  }
}
}

0 comments on commit 642a62b

Please sign in to comment.