[SPARK-50334][SQL] Extract common logic for reading the descriptor of PB file

### What changes were proposed in this pull request?
This PR aims to:
- extract the `common` logic for `reading the descriptor of a PB file` into one place;
- align the Spark error condition thrown when `the PB file is not found` or `the read fails`, so that the `from_protobuf` and `to_protobuf` functions behave consistently in `connect-client` and in `spark-sql` (or `spark-shell`). A short usage sketch follows this list.
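A minimal sketch of the consolidated flow, assuming hypothetical descriptor path, message name, and column name; `org.apache.spark.sql.util.ProtobufUtils` is the new shared helper that the built-in functions now call internally:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf
import org.apache.spark.sql.util.ProtobufUtils

def decodeEvents(df: DataFrame): DataFrame = {
  // The shared helper reads the descriptor set once; every entry point
  // (SQL expression, Scala/Java API, connect-client) now goes through it.
  val descriptorBytes: Array[Byte] =
    ProtobufUtils.readDescriptorFileContent("/tmp/basicmessage.desc") // hypothetical path
  // The resulting bytes are passed to from_protobuf exactly as before.
  df.select(from_protobuf(col("payload"), "BasicMessage", descriptorBytes).as("event"))
}
```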

### Why are the changes needed?
I found that the logic for `reading the descriptor of a PB file` is scattered across several places in the `spark code repository`, e.g.:
https://github.com/apache/spark/blob/a01856de20013e5551d385ee000772049a0e1bc0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala#L37-L48
https://github.com/apache/spark/blob/a01856de20013e5551d385ee000772049a0e1bc0/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala#L304-L315
https://github.com/apache/spark/blob/a01856de20013e5551d385ee000772049a0e1bc0/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala#L231-L241

- We should gather this logic in one place to reduce the maintenance cost.
- Aligning the `spark error-condition` improves consistency in the end-user experience (see the sketch after this list).
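As a hedged illustration of the error-condition alignment (the DataFrame, column name, and descriptor path below are hypothetical): the SQL expressions previously wrapped a missing descriptor file in a plain `RuntimeException`, while the `functions` API already threw a Spark error condition; after this change both paths go through the shared helper and fail the same way.

```scala
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

def checkMissingDescriptor(df: DataFrame): Unit = {
  try {
    // The descriptor path is read eagerly, so the failure surfaces here.
    df.select(from_protobuf(col("payload"), "BasicMessage", "/no/such/file.desc"))
  } catch {
    case e: AnalysisException =>
      // Expected: the same "cannot find descriptor file" error condition,
      // whether the call comes from spark-sql, spark-shell, or connect-client.
      println(e.getMessage)
  }
}
```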

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48874 from panbingkun/SPARK-50334.

Authored-by: panbingkun <panbingkun@baidu.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
panbingkun authored and MaxGekk committed Nov 21, 2024
1 parent 229b1b8 commit 136c722
Showing 7 changed files with 62 additions and 75 deletions.
@@ -17,19 +17,14 @@

package org.apache.spark.sql.protobuf.utils

import java.io.File
import java.io.FileNotFoundException
import java.nio.file.NoSuchFileException
import java.util.Locale

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException, Message}
import com.google.protobuf.DescriptorProtos.{FileDescriptorProto, FileDescriptorSet}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.TypeRegistry
import org.apache.commons.io.FileUtils

import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -228,18 +223,6 @@ private[sql] object ProtobufUtils extends Logging {
}
}

def readDescriptorFileContent(filePath: String): Array[Byte] = {
try {
FileUtils.readFileToByteArray(new File(filePath))
} catch {
case ex: FileNotFoundException =>
throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case ex: NoSuchFileException =>
throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case NonFatal(ex) => throw QueryCompilationErrors.descriptorParseError(ex)
}
}

private def parseFileDescriptorSet(bytes: Array[Byte]): List[Descriptors.FileDescriptor] = {
var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null
try {

@@ -29,6 +29,7 @@ import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters}
import org.apache.spark.sql.sources.{EqualTo, Not}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._

@@ -39,15 +40,15 @@ class ProtobufCatalystDataConversionSuite
with ProtobufTestBase {

private val testFileDescFile = protobufDescriptorFile("catalyst_types.desc")
private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.CatalystTypes$"

private def checkResultWithEval(
data: Literal,
descFilePath: String,
messageName: String,
expected: Any): Unit = {
val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath)
val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath)
withClue("(Eval check with Java class name)") {
val className = s"$javaClassNamePrefix$messageName"
checkEvaluation(
@@ -72,7 +73,7 @@ class ProtobufCatalystDataConversionSuite
actualSchema: String,
badSchema: String): Unit = {

val descBytes = ProtobufUtils.readDescriptorFileContent(descFilePath)
val descBytes = CommonProtobufUtils.readDescriptorFileContent(descFilePath)
val binary = CatalystDataToProtobuf(data, actualSchema, Some(descBytes))

intercept[Exception] {

@@ -33,18 +33,19 @@ import org.apache.spark.sql.protobuf.utils.ProtobufOptions
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}

class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with ProtobufTestBase
with Serializable {

import testImplicits._

val testFileDescFile = protobufDescriptorFile("functions_suite.desc")
private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"

val proto2FileDescFile = protobufDescriptorFile("proto2_messages.desc")
val proto2FileDesc = ProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
val proto2FileDesc = CommonProtobufUtils.readDescriptorFileContent(proto2FileDescFile)
private val proto2JavaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.Proto2Messages$"

private def emptyBinaryDF = Seq(Array[Byte]()).toDF("binary")
@@ -467,7 +468,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot

test("Handle extra fields : oldProducer -> newConsumer") {
val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile)
val descBytes = CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)

val oldProducer = ProtobufUtils.buildDescriptor(descBytes, "oldProducer")
val newConsumer = ProtobufUtils.buildDescriptor(descBytes, "newConsumer")
Expand Down Expand Up @@ -509,7 +510,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot

test("Handle extra fields : newProducer -> oldConsumer") {
val catalystTypesFile = protobufDescriptorFile("catalyst_types.desc")
val descBytes = ProtobufUtils.readDescriptorFileContent(catalystTypesFile)
val descBytes = CommonProtobufUtils.readDescriptorFileContent(catalystTypesFile)

val newProducer = ProtobufUtils.buildDescriptor(descBytes, "newProducer")
val oldConsumer = ProtobufUtils.buildDescriptor(descBytes, "oldConsumer")

@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.util.{ProtobufUtils => CommonProtobufUtils}

/**
* Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]] with a more specific focus on
@@ -37,12 +38,12 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
import ProtoSerdeSuite.MatchType._

private val testFileDescFile = protobufDescriptorFile("serde_suite.desc")
private val testFileDesc = ProtobufUtils.readDescriptorFileContent(testFileDescFile)
private val testFileDesc = CommonProtobufUtils.readDescriptorFileContent(testFileDescFile)

private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$"

private val proto2DescFile = protobufDescriptorFile("proto2_messages.desc")
private val proto2Desc = ProtobufUtils.readDescriptorFileContent(proto2DescFile)
private val proto2Desc = CommonProtobufUtils.readDescriptorFileContent(proto2DescFile)

test("Test basic conversion") {
withFieldMatchType { fieldMatch =>
@@ -215,7 +216,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {

val e1 = intercept[AnalysisException] {
ProtobufUtils.buildDescriptor(
ProtobufUtils.readDescriptorFileContent(fileDescFile),
CommonProtobufUtils.readDescriptorFileContent(fileDescFile),
"SerdeBasicMessage"
)
}
@@ -225,7 +226,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
condition = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR")

val basicMessageDescWithoutImports = descriptorSetWithoutImports(
ProtobufUtils.readDescriptorFileContent(
CommonProtobufUtils.readDescriptorFileContent(
protobufDescriptorFile("basicmessage.desc")
),
"BasicMessage"

@@ -16,16 +16,12 @@
*/
package org.apache.spark.sql.protobuf

import java.io.FileNotFoundException
import java.nio.file.{Files, NoSuchFileException, Paths}

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
import org.apache.spark.sql.errors.CompilationErrors
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.util.ProtobufUtils

// scalastyle:off: object.name
object functions {
@@ -51,7 +47,7 @@
messageName: String,
descFilePath: String,
options: java.util.Map[String, String]): Column = {
val descriptorFileContent = readDescriptorFileContent(descFilePath)
val descriptorFileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
from_protobuf(data, messageName, descriptorFileContent, options)
}

@@ -98,7 +94,7 @@ object functions {
*/
@Experimental
def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
val fileContent = readDescriptorFileContent(descFilePath)
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
from_protobuf(data, messageName, fileContent)
}

@@ -226,7 +222,7 @@ object functions {
messageName: String,
descFilePath: String,
options: java.util.Map[String, String]): Column = {
val fileContent = readDescriptorFileContent(descFilePath)
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
to_protobuf(data, messageName, fileContent, options)
}

@@ -299,18 +295,4 @@ object functions {
options: java.util.Map[String, String]): Column = {
Column.fnWithOptions("to_protobuf", options.asScala.iterator, data, lit(messageClassName))
}

// This method is copied from org.apache.spark.sql.protobuf.util.ProtobufUtils
private def readDescriptorFileContent(filePath: String): Array[Byte] = {
try {
Files.readAllBytes(Paths.get(filePath))
} catch {
case ex: FileNotFoundException =>
throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case ex: NoSuchFileException =>
throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case NonFatal(ex) =>
throw CompilationErrors.descriptorParseError(ex)
}
}
}
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.util

import java.io.{File, FileNotFoundException}
import java.nio.file.NoSuchFileException

import scala.util.control.NonFatal

import org.apache.commons.io.FileUtils

import org.apache.spark.sql.errors.CompilationErrors

object ProtobufUtils {
def readDescriptorFileContent(filePath: String): Array[Byte] = {
try {
FileUtils.readFileToByteArray(new File(filePath))
} catch {
case ex: FileNotFoundException =>
throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case ex: NoSuchFileException =>
throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex)
case NonFatal(ex) => throw CompilationErrors.descriptorParseError(ex)
}
}
}

@@ -17,37 +17,15 @@

package org.apache.spark.sql.catalyst.expressions

import java.io.File
import java.io.FileNotFoundException
import java.nio.file.NoSuchFileException

import scala.util.control.NonFatal

import org.apache.commons.io.FileUtils

import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{BinaryType, MapType, NullType, StringType}
import org.apache.spark.sql.util.ProtobufUtils
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

object ProtobufHelper {
def readDescriptorFileContent(filePath: String): Array[Byte] = {
try {
FileUtils.readFileToByteArray(new File(filePath))
} catch {
case ex: FileNotFoundException =>
throw new RuntimeException(s"Cannot find descriptor file at path: $filePath", ex)
case ex: NoSuchFileException =>
throw new RuntimeException(s"Cannot find descriptor file at path: $filePath", ex)
case NonFatal(ex) =>
throw new RuntimeException(s"Failed to read the descriptor file: $filePath", ex)
}
}
}

/**
* Converts a binary column of Protobuf format into its corresponding catalyst value.
* The Protobuf definition is provided through Protobuf <i>descriptor file</i>.
@@ -163,7 +141,7 @@ case class FromProtobuf(
}
val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
case s: UTF8String if s.toString.isEmpty => None
case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
case s: UTF8String => Some(ProtobufUtils.readDescriptorFileContent(s.toString))
case bytes: Array[Byte] if bytes.isEmpty => None
case bytes: Array[Byte] => Some(bytes)
case null => None
@@ -300,7 +278,7 @@ case class ToProtobuf(
s.toString
}
val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
case s: UTF8String => Some(ProtobufUtils.readDescriptorFileContent(s.toString))
case bytes: Array[Byte] => Some(bytes)
case null => None
}
