diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5c786bc5ddbfa..f49fd697492a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -412,6 +412,28 @@ case class StaticInvoke( }.$functionName(${arguments.mkString(", ")}))" } +object StaticInvoke { + def withNullIntolerant( + staticObject: Class[_], + dataType: DataType, + functionName: String, + arguments: Seq[Expression] = Nil, + inputTypes: Seq[AbstractDataType] = Nil, + propagateNull: Boolean = true, + returnNullable: Boolean = true, + isDeterministic: Boolean = true, + scalarFunction: Option[ScalarFunction[_]] = None): StaticInvoke = + new StaticInvoke( + staticObject, + dataType, + functionName, + arguments, + inputTypes, + propagateNull, + returnNullable, + isDeterministic, scalarFunction) with NullIntolerant +} + /** * Calls the specified function on an object, optionally passing arguments. If the `targetObject` * expression evaluates to null then null will be returned. @@ -555,6 +577,27 @@ case class Invoke( copy(targetObject = newChildren.head, arguments = newChildren.tail) } +object Invoke { + def withNullIntolerant( + targetObject: Expression, + functionName: String, + dataType: DataType, + arguments: Seq[Expression] = Nil, + methodInputTypes: Seq[AbstractDataType] = Nil, + propagateNull: Boolean = true, + returnNullable: Boolean = true, + isDeterministic: Boolean = true): Invoke = + new Invoke( + targetObject, + functionName, + dataType, + arguments, + methodInputTypes, + propagateNull, + returnNullable, + isDeterministic) with NullIntolerant +} + object NewInstance { def apply( cls: Class[_], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2452da5d69682..8e8d3a9574667 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -747,7 +747,8 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes with UnaryLike[Expression] with NullIntolerant { - override lazy val replacement: Expression = Invoke(input, "isValid", BooleanType) + override lazy val replacement: Expression = + Invoke.withNullIntolerant(input, "isValid", BooleanType) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation(supportsTrimCollation = true)) @@ -795,7 +796,8 @@ case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with Implic case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes with UnaryLike[Expression] with NullIntolerant { - override lazy val replacement: Expression = Invoke(input, "makeValid", input.dataType) + override lazy val replacement: Expression = + Invoke.withNullIntolerant(input, "makeValid", input.dataType) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation(supportsTrimCollation = true)) @@ -836,12 +838,13 @@ case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with Impl case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes with UnaryLike[Expression] with NullIntolerant { - override lazy val replacement: Expression = StaticInvoke( - classOf[ExpressionImplUtils], - input.dataType, - "validateUTF8String", - Seq(input), - inputTypes) + override lazy val replacement: Expression = + StaticInvoke.withNullIntolerant( + classOf[ExpressionImplUtils], + input.dataType, + "validateUTF8String", + Seq(input), + inputTypes) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation(supportsTrimCollation = true)) @@ -886,12 +889,13 @@ case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with Impli case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes with UnaryLike[Expression] with NullIntolerant { - override lazy val replacement: Expression = StaticInvoke( - classOf[ExpressionImplUtils], - input.dataType, - "tryValidateUTF8String", - Seq(input), - inputTypes) + override lazy val replacement: Expression = + StaticInvoke.withNullIntolerant( + classOf[ExpressionImplUtils], + input.dataType, + "tryValidateUTF8String", + Seq(input), + inputTypes) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation(supportsTrimCollation = true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 54d9fdbf8c231..2e91d60e4ba04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql import org.apache.spark.{SPARK_DOC_ROOT, SparkIllegalArgumentException, SparkRuntimeException} import org.apache.spark.sql.catalyst.expressions.Cast._ -import org.apache.spark.sql.execution.{FormattedMode, WholeStageCodegenExec} +import org.apache.spark.sql.catalyst.expressions.IsNotNull +import org.apache.spark.sql.execution.{FilterExec, FormattedMode, WholeStageCodegenExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -1424,4 +1425,31 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { } } } + + test("SPARK-50224: The replacement of validate utf8 functions should be NullIntolerant") { + def check(df: DataFrame, expected: Seq[Row]): Unit = { + val filter = df.queryExecution + .sparkPlan + .find(_.isInstanceOf[FilterExec]) + .get.asInstanceOf[FilterExec] + assert(filter.condition.find(_.isInstanceOf[IsNotNull]).nonEmpty) + checkAnswer(df, expected) + } + withTable("test_table") { + sql("CREATE TABLE test_table" + + " AS SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl'), ('mno', NULL) T(a, b)") + check( + sql("SELECT * FROM test_table WHERE is_valid_utf8(b)"), + Seq(Row("abc", "def"), Row("ghi", "jkl"))) + check( + sql("SELECT * FROM test_table WHERE make_valid_utf8(b) = 'def'"), + Seq(Row("abc", "def"))) + check( + sql("SELECT * FROM test_table WHERE validate_utf8(b) = 'jkl'"), + Seq(Row("ghi", "jkl"))) + check( + sql("SELECT * FROM test_table WHERE try_validate_utf8(b) = 'def'"), + Seq(Row("abc", "def"))) + } + } }