[SPARK-50370][SQL] Codegen Support for json_tuple #48908

Open · wants to merge 10 commits into base: master
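
For context: json_tuple extracts multiple top-level fields from a JSON string in a single parse, and before this change the expression mixed in CodegenFallback, so it was always interpreted. A minimal usage sketch (assuming a SparkSession named spark; the rendered output is illustrative):

    // json_tuple pulls several top-level fields out of a JSON string in one pass.
    // With this PR it participates in codegen instead of falling back to
    // interpreted evaluation. Output columns are named c0, c1, ...
    // (see elementSchema in the diff below).
    val df = spark.sql("""SELECT json_tuple('{"a": 1, "b": "x"}', 'a', 'b')""")
    df.show()
    // +---+---+
    // | c0| c1|
    // +---+---+
    // |  1|  x|
    // +---+---+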
@@ -16,12 +16,13 @@
*/
package org.apache.spark.sql.catalyst.expressions.json

import java.io.CharArrayWriter
import java.io.{ByteArrayOutputStream, CharArrayWriter}

import com.fasterxml.jackson.core.JsonFactory
import com.fasterxml.jackson.core._

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow, SharedFactory}
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonGenerator, JacksonParser, JsonInferSchema, JSONOptions}
import org.apache.spark.sql.catalyst.util.{ArrayData, FailFastMode, FailureSafeParser, MapData, PermissiveMode}
@@ -159,3 +160,90 @@ case class StructsToJsonEvaluator(
converter(value)
}
}

case class JsonTupleEvaluator(fieldsLength: Int) {

import SharedFactory._

// if processing fails this shared value will be returned
@transient private lazy val nullRow: Seq[InternalRow] =
new GenericInternalRow(Array.ofDim[Any](fieldsLength)) :: Nil

private def parseRow(parser: JsonParser, fieldNames: Seq[String]): Seq[InternalRow] = {
// only objects are supported
if (parser.nextToken() != JsonToken.START_OBJECT) return nullRow

val row = Array.ofDim[Any](fieldNames.length)

// start reading through the token stream, looking for any requested field names
while (parser.nextToken() != JsonToken.END_OBJECT) {
if (parser.getCurrentToken == JsonToken.FIELD_NAME) {
// check to see if this field is desired in the output
val jsonField = parser.currentName
var idx = fieldNames.indexOf(jsonField)
if (idx >= 0) {
// it is, copy the child tree to the correct location in the output row
val output = new ByteArrayOutputStream()

// write the output directly to UTF8 encoded byte array
if (parser.nextToken() != JsonToken.VALUE_NULL) {
Utils.tryWithResource(jsonFactory.createGenerator(output, JsonEncoding.UTF8)) {
generator => copyCurrentStructure(generator, parser)
}

val jsonValue = UTF8String.fromBytes(output.toByteArray)

// SPARK-21804: json_tuple returns null values within repeated columns
// except the first one, so we need to check the remaining fields.
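// e.g. json_tuple('{"a": 1}', 'a', 'a') must return ("1", "1"), not ("1", null).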
do {
row(idx) = jsonValue
idx = fieldNames.indexOf(jsonField, idx + 1)
} while (idx >= 0)
}
}
}

// always skip children, it's cheap enough to do even if copyCurrentStructure was called
parser.skipChildren()
}
new GenericInternalRow(row) :: Nil
}

private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = {
parser.getCurrentToken match {
// if the user requests a string field, it needs to be returned without enclosing
// quotes, which is accomplished via JsonGenerator.writeRaw instead of JsonGenerator.write
case JsonToken.VALUE_STRING if parser.hasTextCharacters =>
// slight optimization to avoid allocating a String instance, though the characters
// still have to be decoded... Jackson doesn't have a way to access the raw bytes
generator.writeRaw(parser.getTextCharacters, parser.getTextOffset, parser.getTextLength)

case JsonToken.VALUE_STRING =>
// the normal String case, pass it through to the output without enclosing quotes
generator.writeRaw(parser.getText)

case JsonToken.VALUE_NULL =>
// a special case that needs to be handled outside of this method.
// if a requested field is null, the result must be null. the easiest
// way to achieve this is just by ignoring null tokens entirely
throw SparkException.internalError("Do not attempt to copy a null field.")

case _ =>
// handle other types including objects, arrays, booleans and numbers
generator.copyCurrentStructure(parser)
}
}
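
// Net effect of the branches above (values illustrative):
//   json_tuple('{"s": "hi", "o": {"x": 1}}', 's', 'o') yields
//     c0 = hi        (VALUE_STRING written via writeRaw, so no surrounding quotes)
//     c1 = {"x":1}   (nested object copied via copyCurrentStructure as raw JSON)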

final def evaluate(json: UTF8String, fieldNames: Seq[String]): Seq[InternalRow] = {
if (json == null) return nullRow
try {
/* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson
detect character encoding which could fail for some malformed strings */
Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
parseRow(parser, fieldNames)
}
} catch {
case _: JsonProcessingException => nullRow
}
}
}
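
A hedged sketch of driving the evaluator directly (not a test from this PR; the inputs and expected results follow from evaluate and parseRow above):

    // Sketch: exercising JsonTupleEvaluator by hand (illustrative only).
    val evaluator = JsonTupleEvaluator(2)
    val rows = evaluator.evaluate(
      UTF8String.fromString("""{"a": 1, "b": [2, 3]}"""),
      Seq("a", "b"))
    // One row holding the raw JSON of each requested field:
    //   rows.head.getUTF8String(0)  => 1
    //   rows.head.getUTF8String(1)  => [2,3]
    // A null input, malformed JSON, or a non-object root yields a single all-null row.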
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.io._

import scala.collection.immutable.ArraySeq
import scala.util.parsing.combinator.RegexParsers

import com.fasterxml.jackson.core._
@@ -28,9 +29,9 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator, StructsToJsonEvaluator}
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator, JsonTupleEvaluator, StructsToJsonEvaluator}
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, RUNTIME_REPLACEABLE, TreePattern}
@@ -106,7 +107,7 @@ private[this] object JsonPathParser extends RegexParsers {
}
}

private[this] object SharedFactory {
private[expressions] object SharedFactory {
val jsonFactory = new JsonFactoryBuilder()
// The two options below enabled for Hive compatibility
.enable(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS)
@@ -446,20 +447,8 @@ class GetJsonObjectEvaluator(cachedPath: UTF8String) {
// scalastyle:on line.size.limit line.contains.tab
case class JsonTuple(children: Seq[Expression])
extends Generator
with CodegenFallback
with QueryErrorsBase {

import SharedFactory._

override def nullable: Boolean = {
// a row is always returned
false
}

// if processing fails this shared value will be returned
@transient private lazy val nullRow: Seq[InternalRow] =
new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) :: Nil

// the json body is the first child
@transient private lazy val jsonExpr: Expression = children.head

@@ -477,6 +466,11 @@ case class JsonTuple(children: Seq[Expression])
// and count the number of foldable fields, we'll use this later to optimize evaluation
@transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null)

override def nullable: Boolean = {
// a row is always returned
false
}

override def elementSchema: StructType = StructType(fieldExpressions.zipWithIndex.map {
case (_, idx) => StructField(s"c$idx", children.head.dataType, nullable = true)
})
@@ -499,29 +493,11 @@
}
}

@transient
private lazy val evaluator: JsonTupleEvaluator = JsonTupleEvaluator(fieldExpressions.length)

override def eval(input: InternalRow): IterableOnce[InternalRow] = {
val json = jsonExpr.eval(input).asInstanceOf[UTF8String]
if (json == null) {
return nullRow
}

try {
/* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson
detect character encoding which could fail for some malformed strings */
Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
parseRow(parser, input)
}
} catch {
case _: JsonProcessingException =>
nullRow
}
}

private def parseRow(parser: JsonParser, input: InternalRow): Seq[InternalRow] = {
// only objects are supported
if (parser.nextToken() != JsonToken.START_OBJECT) {
return nullRow
}

// evaluate the field names as String rather than UTF8String to
// optimize lookups from the json token, which is also a String
@@ -544,66 +520,95 @@
}
}

val row = Array.ofDim[Any](fieldNames.length)

// start reading through the token stream, looking for any requested field names
while (parser.nextToken() != JsonToken.END_OBJECT) {
if (parser.getCurrentToken == JsonToken.FIELD_NAME) {
// check to see if this field is desired in the output
val jsonField = parser.currentName
var idx = fieldNames.indexOf(jsonField)
if (idx >= 0) {
// it is, copy the child tree to the correct location in the output row
val output = new ByteArrayOutputStream()
evaluator.evaluate(json, fieldNames)
}

// write the output directly to UTF8 encoded byte array
if (parser.nextToken() != JsonToken.VALUE_NULL) {
Utils.tryWithResource(jsonFactory.createGenerator(output, JsonEncoding.UTF8)) {
generator => copyCurrentStructure(generator, parser)
}
private def genFieldNamesCode(
ctx: CodegenContext,
refFoldableFieldNames: String,
fieldNamesTerm: String): String = {

val jsonValue = UTF8String.fromBytes(output.toByteArray)
def genFoldableFieldNameCode(refIndexedSeq: String, i: Int): String = {
s"(String)((scala.Option<String>)$refIndexedSeq.apply($i)).get();"
}

// SPARK-21804: json_tuple returns null values within repeated columns
// except the first one, so we need to check the remaining fields.
do {
row(idx) = jsonValue
idx = fieldNames.indexOf(jsonField, idx + 1)
} while (idx >= 0)
// evaluate the field names as String rather than UTF8String to
// optimize lookups from the json token, which is also a String
val (fieldNamesEval, setFieldNames) = if (constantFields == fieldExpressions.length) {
// typically the user will provide the field names as foldable expressions
// so we can use the cached copy
val s = foldableFieldNames.zipWithIndex.map {
case (v, i) =>
if (v != null && v.isDefined) {
s"$fieldNamesTerm[$i] = ${genFoldableFieldNameCode(refFoldableFieldNames, i)};"
} else {
s"$fieldNamesTerm[$i] = null;"
}
}
}

// always skip children, it's cheap enough to do even if copyCurrentStructure was called
parser.skipChildren()
(Seq.empty[ExprCode], s)
} else if (constantFields == 0) {
// none are foldable so all field names need to be evaluated from the input row
val f = fieldExpressions.map(_.genCode(ctx))
val s = f.zipWithIndex.map {
case (exprCode, i) =>
s"""
|if (${exprCode.isNull}) {
| $fieldNamesTerm[$i] = null;
|} else {
| $fieldNamesTerm[$i] = ${exprCode.value}.toString();
|}
|""".stripMargin
}
(f, s)
} else {
// if there is a mix of constant and non-constant expressions
// prefer the cached copy when available
val codes = foldableFieldNames.zip(fieldExpressions).zipWithIndex.map {
case ((null, expr: Expression), i) =>
val f = expr.genCode(ctx)
val s =
s"""
|if (${f.isNull}) {
| $fieldNamesTerm[$i] = null;
|} else {
| $fieldNamesTerm[$i] = ${f.value}.toString();
|}
|""".stripMargin
(Some(f), s)
case ((v: Option[String], _), i) =>
val s = if (v.isDefined) {
s"$fieldNamesTerm[$i] = ${genFoldableFieldNameCode(refFoldableFieldNames, i)};"
} else {
s"$fieldNamesTerm[$i] = null;"
}
(None, s)
}
(codes.filter(c => c._1.isDefined).map(c => c._1.get), codes.map(c => c._2))
}

new GenericInternalRow(row) :: Nil
s"""
|String[] $fieldNamesTerm = new String[${fieldExpressions.length}];
|${fieldNamesEval.map(_.code).mkString("\n")}
|${setFieldNames.mkString("\n")}
|""".stripMargin
}
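
To make the three branches concrete, this is roughly the snippet the method emits for json_tuple(json, 'a', c2), i.e. one constant field and one non-constant field. All identifiers here (fieldNames_0, isNull_2, value_2, foldableFieldNames_ref) are illustrative stand-ins for what ctx.freshName, the field expression's ExprCode, and ctx.addReferenceObj actually produce:

    String[] fieldNames_0 = new String[2];
    // constant field 'a': fetched from the cached foldableFieldNames reference object
    fieldNames_0[0] = (String)((scala.Option<String>)foldableFieldNames_ref.apply(0)).get();
    // non-constant field c2: evaluated from the input row through its ExprCode
    if (isNull_2) {
      fieldNames_0[1] = null;
    } else {
      fieldNames_0[1] = value_2.toString();
    }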

private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = {
parser.getCurrentToken match {
// if the user requests a string field it needs to be returned without enclosing
// quotes which is accomplished via JsonGenerator.writeRaw instead of JsonGenerator.write
case JsonToken.VALUE_STRING if parser.hasTextCharacters =>
// slight optimization to avoid allocating a String instance, though the characters
// still have to be decoded... Jackson doesn't have a way to access the raw bytes
generator.writeRaw(parser.getTextCharacters, parser.getTextOffset, parser.getTextLength)

case JsonToken.VALUE_STRING =>
// the normal String case, pass it through to the output without enclosing quotes
generator.writeRaw(parser.getText)

case JsonToken.VALUE_NULL =>
// a special case that needs to be handled outside of this method.
// if a requested field is null, the result must be null. the easiest
// way to achieve this is just by ignoring null tokens entirely
throw SparkException.internalError("Do not attempt to copy a null field.")

case _ =>
// handle other types including objects, arrays, booleans and numbers
generator.copyCurrentStructure(parser)
}
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
val refFoldableFieldNames = ctx.addReferenceObj("foldableFieldNames", foldableFieldNames)
val wrapperClass = classOf[Seq[_]].getName
val jsonEval = jsonExpr.genCode(ctx)
val fieldNamesTerm = ctx.freshName("fieldNames")
val fieldNamesCode = genFieldNamesCode(ctx, refFoldableFieldNames, fieldNamesTerm)
val fieldNamesClz = classOf[ArraySeq[_]].getName
ev.copy(code =
code"""
|${jsonEval.code}
|$fieldNamesCode
|boolean ${ev.isNull} = false;
|$wrapperClass<InternalRow> ${ev.value} = $refEvaluator.evaluate(
| ${jsonEval.value}, new $fieldNamesClz.ofRef($fieldNamesTerm));
|""".stripMargin)
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): JsonTuple =
@@ -273,8 +273,9 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
}

test("json_tuple escaping") {
@panbingkun (Contributor, Author) commented on Nov 22, 2024:
test("stack") {
    GenerateUnsafeProjection.generate(
      Stack(Seq(2, 1, 2, 3).map(Literal(_))) :: Nil)
  }
  • which is also not supported and throws an exception:
19:59:59.933 ERROR org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator: Failed to compile the generated Java code.
org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 78, Column 30: Assignment conversion not possible from type "scala.collection.mutable.ArraySeq" to type "org.apache.spark.sql.catalyst.util.ArrayData"
	at org.codehaus.janino.UnitCompiler.compileError(UnitCompiler.java:13014)
	at org.codehaus.janino.UnitCompiler.assignmentConversion(UnitCompiler.java:11263)
	at org.codehaus.janino.UnitCompiler.access$3900(UnitCompiler.java:236)
	at org.codehaus.janino.UnitCompiler$7.visitRvalue(UnitCompiler.java:2764)
	at org.codehaus.janino.UnitCompiler$7.visitRvalue(UnitCompiler.java:2754)
	at org.codehaus.janino.Java$Rvalue.accept(Java.java:4498)
	at org.codehaus.janino.UnitCompiler.compile(UnitCompiler.java:2754)
19:59:59.940 ERROR org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator: 
/* 001 */ public java.lang.Object generate(Object[] references) {
/* 002 */   return new SpecificUnsafeProjection(references);
/* 003 */ }
/* 004 */
/* 005 */ class SpecificUnsafeProjection extends org.apache.spark.sql.catalyst.expressions.UnsafeProjection {
/* 006 */
/* 007 */   private Object[] references;
/* 008 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter[] mutableStateArray_2 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter[1];
/* 009 */   p...
  • therefore, the testing approach has been modified here with the same ultimate goal.

GenerateUnsafeProjection.generate(
JsonTuple(Literal("\"quote") :: Literal("\"quote") :: Nil) :: Nil)
checkJsonTuple(
JsonTuple(Literal("\"quote") :: Literal("\"quote") :: Nil),
InternalRow.fromSeq(Seq(null).map(UTF8String.fromString)))
}

test("json_tuple - hive key 1") {
@@ -1456,4 +1456,23 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
assert(plan.isInstanceOf[WholeStageCodegenExec])
checkAnswer(df, Row(null))
}

test("function json_tuple codegen - field name foldable optimize") {
withTempView("t") {
val df = Seq(("""{"a":1, "b":2}""", "a", "b")).toDF("json", "c1", "c2")
df.createOrReplaceTempView("t")

// all field names are non-foldable
val df1 = sql("SELECT json_tuple(json, c1, c2) from t")
checkAnswer(df1, Row("1", "2"))

// some foldable, some non-foldable
val df2 = sql("SELECT json_tuple(json, 'a', c2) from t")
checkAnswer(df2, Row("1", "2"))

// all field names are foldable
val df3 = sql("SELECT json_tuple(json, 'a', 'b') from t")
checkAnswer(df3, Row("1", "2"))
}
}
}
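
A possible follow-up assertion for this test, sketched after the plan-inspection pattern already used earlier in this suite (df1, df2 and df3 are the DataFrames defined in the test above):

    // Sketch: additionally assert that the queries compile to whole-stage codegen.
    Seq(df1, df2, df3).foreach { df =>
      val plan = df.queryExecution.executedPlan
      assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]))
    }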