Skip to content

Commit

Permalink
[SPARK-50081][SQL] Codegen Support for XPath*(by Invoke & RuntimeRe…
Browse files Browse the repository at this point in the history
…placeable)

### What changes were proposed in this pull request?
This PR adds `Codegen` support for the `xpath*` functions, including:
- `xpath_boolean`
- `xpath_short`
- `xpath_int`
- `xpath_long`
- `xpath_float`
- `xpath_double`
- `xpath_string`
- `xpath`

### Why are the changes needed?
- Improve codegen coverage.
- Simplify the code.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA and existing UTs (e.g., `XPathFunctionsSuite`, `XPathExpressionSuite`, `CollationSQLExpressionsSuite`#`*XPath*`, `CollationExpressionWalkerSuite`).

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48610 from panbingkun/xpath_codegen.

Lead-authored-by: panbingkun <panbingkun@baidu.com>
Co-authored-by: panbingkun <panbingkun@apache.org>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
2 people authored and MaxGekk committed Nov 23, 2024
1 parent 779a526 commit 656ece1
Show file tree
Hide file tree
Showing 12 changed files with 162 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

package org.apache.spark.sql.catalyst.expressions.xml

import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.xml.XmlInferSchema
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object XmlExpressionEvalUtils {
Expand All @@ -40,3 +41,82 @@ object XmlExpressionEvalUtils {
UTF8String.fromString(dataType.sql)
}
}

/**
 * Common contract for objects that apply a fixed XPath expression to XML input.
 *
 * Implementations supply the `path` and the type-specific [[doEvaluate]] logic; callers go
 * through [[evaluate]], which filters out missing input first.
 */
trait XPathEvaluator {

  /** The XPath expression applied to every XML input. */
  protected val path: UTF8String

  /** Helper that performs the XPath evaluation; created lazily and excluded from serialization. */
  @transient protected lazy val xpathUtil: UDFXPathUtil = new UDFXPathUtil

  /** True when the given string is null or has no characters. */
  private def isMissing(value: UTF8String): Boolean = value == null || value.toString.isEmpty

  /**
   * Evaluates `path` against `xml`.
   *
   * @param xml the XML document as a string
   * @return null when either `xml` or `path` is null or empty, otherwise the result of
   *         [[doEvaluate]]
   */
  final def evaluate(xml: UTF8String): Any = {
    if (isMissing(xml) || isMissing(path)) null else doEvaluate(xml)
  }

  /** Type-specific evaluation; only reached with a non-null, non-empty `xml` and `path`. */
  def doEvaluate(xml: UTF8String): Any
}

/** Evaluates `path` against an XML string by delegating to `UDFXPathUtil.evalBoolean`. */
case class XPathBooleanEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val document = xml.toString
    val expression = path.toString
    xpathUtil.evalBoolean(document, expression)
  }
}

/**
 * Evaluates `path` against an XML string as a number and narrows the result to a short.
 *
 * Returns null when `UDFXPathUtil.evalNumber` yields no value.
 */
case class XPathShortEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val ret = xpathUtil.evalNumber(xml.toString, path.toString)
    // Return a genuine null: casting null to a primitive type (`null.asInstanceOf[Short]`)
    // evaluates to the primitive default 0, which would report a missing value as zero.
    if (ret eq null) null else ret.shortValue()
  }
}

/**
 * Evaluates `path` against an XML string as a number and narrows the result to an int.
 *
 * Returns null when `UDFXPathUtil.evalNumber` yields no value.
 */
case class XPathIntEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val ret = xpathUtil.evalNumber(xml.toString, path.toString)
    // Return a genuine null: casting null to a primitive type (`null.asInstanceOf[Int]`)
    // evaluates to the primitive default 0, which would report a missing value as zero.
    if (ret eq null) null else ret.intValue()
  }
}

/**
 * Evaluates `path` against an XML string as a number and converts the result to a long.
 *
 * Returns null when `UDFXPathUtil.evalNumber` yields no value.
 */
case class XPathLongEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val ret = xpathUtil.evalNumber(xml.toString, path.toString)
    // Return a genuine null: casting null to a primitive type (`null.asInstanceOf[Long]`)
    // evaluates to the primitive default 0, which would report a missing value as zero.
    if (ret eq null) null else ret.longValue()
  }
}

/**
 * Evaluates `path` against an XML string as a number and converts the result to a float.
 *
 * Returns null when `UDFXPathUtil.evalNumber` yields no value.
 */
case class XPathFloatEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val ret = xpathUtil.evalNumber(xml.toString, path.toString)
    // Return a genuine null: casting null to a primitive type (`null.asInstanceOf[Float]`)
    // evaluates to the primitive default 0.0, which would report a missing value as zero.
    if (ret eq null) null else ret.floatValue()
  }
}

/**
 * Evaluates `path` against an XML string as a number and returns it as a double.
 *
 * Returns null when `UDFXPathUtil.evalNumber` yields no value.
 */
case class XPathDoubleEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val ret = xpathUtil.evalNumber(xml.toString, path.toString)
    // Return a genuine null: casting null to a primitive type (`null.asInstanceOf[Double]`)
    // evaluates to the primitive default 0.0, which would report a missing value as zero.
    if (ret eq null) null else ret.doubleValue()
  }
}

/**
 * Evaluates `path` against an XML string via `UDFXPathUtil.evalString` and wraps the
 * resulting Java string as a [[UTF8String]].
 */
case class XPathStringEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any =
    UTF8String.fromString(xpathUtil.evalString(xml.toString, path.toString))
}

/**
 * Evaluates `path` against an XML string via `UDFXPathUtil.evalNodeList` and returns the
 * node values of the matched nodes as an array of [[UTF8String]].
 *
 * Returns null when the node list itself is null.
 */
case class XPathListEvaluator(path: UTF8String) extends XPathEvaluator {
  override def doEvaluate(xml: UTF8String): Any = {
    val nodes = xpathUtil.evalNodeList(xml.toString, path.toString)
    if (nodes eq null) {
      null
    } else {
      val count = nodes.getLength
      val values = new Array[AnyRef](count)
      var idx = 0
      // Copy each matched node's value, in document order as returned by the node list.
      while (idx < count) {
        values(idx) = UTF8String.fromString(nodes.item(idx).getNodeValue)
        idx += 1
      }
      new GenericArrayData(values)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
Expand All @@ -34,10 +33,9 @@ import org.apache.spark.unsafe.types.UTF8String
* This is not the world's most efficient implementation due to type conversion, but works.
*/
abstract class XPathExtract
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
extends BinaryExpression with RuntimeReplaceable with ExpectsInputTypes {
override def left: Expression = xml
override def right: Expression = path
override def nullIntolerant: Boolean = true

/** XPath expressions are always nullable, e.g. if the xml string is empty. */
override def nullable: Boolean = true
Expand All @@ -60,12 +58,20 @@ abstract class XPathExtract
}
}

@transient protected lazy val xpathUtil = new UDFXPathUtil
@transient protected lazy val pathString: String = path.eval().asInstanceOf[UTF8String].toString

/** Concrete implementations need to override the following three methods. */
def xml: Expression
def path: Expression

@transient protected lazy val pathUTF8String: UTF8String = path.eval().asInstanceOf[UTF8String]

protected def evaluator: XPathEvaluator

override def replacement: Expression = Invoke(
Literal.create(evaluator, ObjectType(classOf[XPathEvaluator])),
"evaluate",
dataType,
Seq(xml),
Seq(xml.dataType))
}

// scalastyle:off line.size.limit
Expand All @@ -81,11 +87,9 @@ abstract class XPathExtract
// scalastyle:on line.size.limit
case class XPathBoolean(xml: Expression, path: Expression) extends XPathExtract with Predicate {

override def prettyName: String = "xpath_boolean"
@transient override lazy val evaluator: XPathEvaluator = XPathBooleanEvaluator(pathUTF8String)

override def nullSafeEval(xml: Any, path: Any): Any = {
xpathUtil.evalBoolean(xml.asInstanceOf[UTF8String].toString, pathString)
}
override def prettyName: String = "xpath_boolean"

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): XPathBoolean = copy(xml = newLeft, path = newRight)
Expand All @@ -103,14 +107,12 @@ case class XPathBoolean(xml: Expression, path: Expression) extends XPathExtract
group = "xml_funcs")
// scalastyle:on line.size.limit
case class XPathShort(xml: Expression, path: Expression) extends XPathExtract {

@transient override lazy val evaluator: XPathEvaluator = XPathShortEvaluator(pathUTF8String)

override def prettyName: String = "xpath_short"
override def dataType: DataType = ShortType

override def nullSafeEval(xml: Any, path: Any): Any = {
val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
if (ret eq null) null else ret.shortValue()
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): XPathShort = copy(xml = newLeft, path = newRight)
}
Expand All @@ -127,14 +129,12 @@ case class XPathShort(xml: Expression, path: Expression) extends XPathExtract {
group = "xml_funcs")
// scalastyle:on line.size.limit
case class XPathInt(xml: Expression, path: Expression) extends XPathExtract {

@transient override lazy val evaluator: XPathEvaluator = XPathIntEvaluator(pathUTF8String)

override def prettyName: String = "xpath_int"
override def dataType: DataType = IntegerType

override def nullSafeEval(xml: Any, path: Any): Any = {
val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
if (ret eq null) null else ret.intValue()
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Expression = copy(xml = newLeft, path = newRight)
}
Expand All @@ -151,14 +151,12 @@ case class XPathInt(xml: Expression, path: Expression) extends XPathExtract {
group = "xml_funcs")
// scalastyle:on line.size.limit
case class XPathLong(xml: Expression, path: Expression) extends XPathExtract {

@transient override lazy val evaluator: XPathEvaluator = XPathLongEvaluator(pathUTF8String)

override def prettyName: String = "xpath_long"
override def dataType: DataType = LongType

override def nullSafeEval(xml: Any, path: Any): Any = {
val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
if (ret eq null) null else ret.longValue()
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): XPathLong = copy(xml = newLeft, path = newRight)
}
Expand All @@ -175,14 +173,12 @@ case class XPathLong(xml: Expression, path: Expression) extends XPathExtract {
group = "xml_funcs")
// scalastyle:on line.size.limit
case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {

@transient override lazy val evaluator: XPathEvaluator = XPathFloatEvaluator(pathUTF8String)

override def prettyName: String = "xpath_float"
override def dataType: DataType = FloatType

override def nullSafeEval(xml: Any, path: Any): Any = {
val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
if (ret eq null) null else ret.floatValue()
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): XPathFloat = copy(xml = newLeft, path = newRight)
}
Expand All @@ -199,15 +195,13 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {
group = "xml_funcs")
// scalastyle:on line.size.limit
case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {

@transient override lazy val evaluator: XPathEvaluator = XPathDoubleEvaluator(pathUTF8String)

override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("xpath_double")
override def dataType: DataType = DoubleType

override def nullSafeEval(xml: Any, path: Any): Any = {
val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
if (ret eq null) null else ret.doubleValue()
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): XPathDouble = copy(xml = newLeft, path = newRight)
}
Expand All @@ -224,14 +218,12 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
group = "xml_funcs")
// scalastyle:on line.size.limit
case class XPathString(xml: Expression, path: Expression) extends XPathExtract {

@transient override lazy val evaluator: XPathEvaluator = XPathStringEvaluator(pathUTF8String)

override def prettyName: String = "xpath_string"
override def dataType: DataType = SQLConf.get.defaultStringType

override def nullSafeEval(xml: Any, path: Any): Any = {
val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString)
UTF8String.fromString(ret)
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Expression = copy(xml = newLeft, path = newRight)
}
Expand All @@ -250,24 +242,12 @@ case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
group = "xml_funcs")
// scalastyle:on line.size.limit
case class XPathList(xml: Expression, path: Expression) extends XPathExtract {

@transient override lazy val evaluator: XPathEvaluator = XPathListEvaluator(pathUTF8String)

override def prettyName: String = "xpath"
override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType)

override def nullSafeEval(xml: Any, path: Any): Any = {
val nodeList = xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString)
if (nodeList ne null) {
val ret = new Array[AnyRef](nodeList.getLength)
var i = 0
while (i < nodeList.getLength) {
ret(i) = UTF8String.fromString(nodeList.item(i).getNodeValue)
i += 1
}
new GenericArrayData(ret)
} else {
null
}
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): XPathList = copy(xml = newLeft, path = newRight)
}
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [xpath(s#0, a/b/text()) AS xpath(s, a/b/text())#0]
Project [invoke(XPathListEvaluator(a/b/text()).evaluate(s#0)) AS xpath(s, a/b/text())#0]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [xpath_boolean(s#0, a/b) AS xpath_boolean(s, a/b)#0]
Project [invoke(XPathBooleanEvaluator(a/b).evaluate(s#0)) AS xpath_boolean(s, a/b)#0]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [xpath_double(s#0, a/b) AS xpath_double(s, a/b)#0]
Project [invoke(XPathDoubleEvaluator(a/b).evaluate(s#0)) AS xpath_double(s, a/b)#0]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [xpath_float(s#0, a/b) AS xpath_float(s, a/b)#0]
Project [invoke(XPathFloatEvaluator(a/b).evaluate(s#0)) AS xpath_float(s, a/b)#0]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [xpath_int(s#0, a/b) AS xpath_int(s, a/b)#0]
Project [invoke(XPathIntEvaluator(a/b).evaluate(s#0)) AS xpath_int(s, a/b)#0]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [xpath_long(s#0, a/b) AS xpath_long(s, a/b)#0L]
Project [invoke(XPathLongEvaluator(a/b).evaluate(s#0)) AS xpath_long(s, a/b)#0L]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [xpath_number(s#0, a/b) AS xpath_number(s, a/b)#0]
Project [invoke(XPathDoubleEvaluator(a/b).evaluate(s#0)) AS xpath_number(s, a/b)#0]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [xpath_short(s#0, a/b) AS xpath_short(s, a/b)#0]
Project [invoke(XPathShortEvaluator(a/b).evaluate(s#0)) AS xpath_short(s, a/b)#0]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [xpath_string(s#0, a/b) AS xpath_string(s, a/b)#0]
Project [invoke(XPathStringEvaluator(a/b).evaluate(s#0)) AS xpath_string(s, a/b)#0]
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.IsNotNull
import org.apache.spark.sql.execution.FilterExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession

Expand Down Expand Up @@ -76,4 +78,38 @@ class XPathFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.select(xpath(col("xml"), lit("a/*/text()"))),
Row(Seq("b1", "b2", "b3", "c1", "c2")))
}

test("The replacement of `xpath*` functions should be NullIntolerant") {
  // Asserts that the physical plan contains a Filter whose condition includes an IsNotNull
  // check, then verifies the query result.
  def check(df: DataFrame, expected: Seq[Row]): Unit = {
    val filterExec = df.queryExecution.sparkPlan
      .collectFirst { case f: FilterExec => f }
      .get
    assert(filterExec.condition.find(_.isInstanceOf[IsNotNull]).isDefined)
    checkAnswer(df, expected)
  }

  // Scalar xpath_* functions: one matching row plus a NULL row that must be filtered out.
  withTable("t") {
    sql("CREATE TABLE t AS SELECT * FROM VALUES ('<a><b>1</b></a>'), (NULL) T(xml)")
    val expectedRows = Seq(Row("<a><b>1</b></a>"))
    val queries = Seq(
      "SELECT * FROM t WHERE xpath_boolean(xml, 'a/b') = true",
      "SELECT * FROM t WHERE xpath_short(xml, 'a/b') = 1",
      "SELECT * FROM t WHERE xpath_int(xml, 'a/b') = 1",
      "SELECT * FROM t WHERE xpath_long(xml, 'a/b') = 1",
      "SELECT * FROM t WHERE xpath_float(xml, 'a/b') = 1",
      "SELECT * FROM t WHERE xpath_double(xml, 'a/b') = 1",
      "SELECT * FROM t WHERE xpath_string(xml, 'a/b') = '1'")
    queries.foreach { query =>
      check(sql(query), expectedRows)
    }
  }

  // Array-returning xpath function.
  withTable("t") {
    sql("CREATE TABLE t AS SELECT * FROM VALUES " +
      "('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>'), (NULL) T(xml)")
    val listQuery = "SELECT * FROM t WHERE xpath(xml, 'a/b/text()') = array('b1', 'b2', 'b3')"
    check(sql(listQuery),
      Seq(Row("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>")))
  }
}
}

0 comments on commit 656ece1

Please sign in to comment.