Support launching Map Pandas UDF on empty partitions (#9557)
fixes #9480

This PR adds support for launching a Map Pandas UDF on empty partitions, aligning with Spark's behavior.

So far I have not seen other types of Pandas UDF being called for empty partitions.

The test is copied from the example in the linked issue.
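
For context, here is a minimal CPU-only PySpark sketch (not part of this change; the column and function names are illustrative) of the Spark behavior being aligned with: mapInPandas still invokes the UDF for an empty partition, handing it an iterator that yields no batches.

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

spark = SparkSession.builder.master("local[2]").getOrCreate()

def count_batches(batches):
    # Called once per partition; on an empty partition the loop body never runs,
    # but the function still has to yield valid (possibly empty) output.
    n = 0
    for batch in batches:
        n += 1
    yield pd.DataFrame({"ret": [n]})

df = spark.range(10).withColumn("const", f.lit(1))
# Hash-repartitioning on a constant column leaves 4 of the 5 partitions empty.
print(df.repartition(5, "const").mapInPandas(count_batches, "ret int").collect())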

---------

Signed-off-by: Firestarman <firestarmanllc@gmail.com>
firestarman authored Oct 30, 2023
1 parent ddb8f6b commit 2e282f9
Showing 5 changed files with 90 additions and 28 deletions.
10 changes: 10 additions & 0 deletions integration_tests/src/main/python/udf_test.py
@@ -400,3 +400,13 @@ def filter_func(iterator):
            .mapInArrow(filter_func, schema=f"a {data_type}, b {data_type}"),
        "PythonMapInArrowExec",
        conf=conf)


def test_map_pandas_udf_with_empty_partitions():
    def test_func(spark):
        df = spark.range(10).withColumn("const", f.lit(1))
        # The repartition will produce 4 empty partitions.
        return df.repartition(5, "const").mapInPandas(
            lambda data: [pd.DataFrame([len(list(data))])], schema="ret:integer")

    assert_gpu_and_cpu_are_equal_collect(test_func, conf=arrow_udf_conf)
@@ -27,13 +27,17 @@ import ai.rapids.cudf._
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python._
import org.apache.spark.rapids.shims.api.python.ShimBasePythonRunner
import org.apache.spark.sql.execution.python.PythonUDFRunner
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.shims.ArrowUtilsShim
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils

@@ -241,6 +245,21 @@ abstract class GpuArrowPythonRunnerBase(
}

    protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
      if (inputIterator.nonEmpty) {
        writeNonEmptyIteratorOnGpu(dataOut)
      } else { // Partition is empty.
        // In this case, CPU Spark still sends the schema to the Python workers by
        // calling the "start" API of the Java Arrow writer, but the GPU side would
        // send nothing, leading to an IPC error. It is not easy to mimic on the GPU
        // what Spark does, because the C++ Arrow writer used by the GPU only sends
        // the schema when there is some data, and it does not expose a "start" API.
        // So here we leverage the Java Arrow writer to do the same thing as Spark.
        // This is fine because sending the schema has nothing to do with the GPU.
        writeEmptyIteratorOnCpu(dataOut)
      }
    }

    private def writeNonEmptyIteratorOnGpu(dataOut: DataOutputStream): Unit = {
      val writer = {
        val builder = ArrowIPCWriterOptions.builder()
        builder.withMaxChunkSize(batchSize)
@@ -250,11 +269,11 @@
        })
        // Flatten the names of nested struct columns, required by cudf arrow IPC writer.
        GpuArrowPythonRunner.flattenNames(pythonInSchema).foreach { case (name, nullable) =>
          if (nullable) {
            builder.withColumnNames(name)
          } else {
            builder.withNotNullableColumnNames(name)
          }
        }
        Table.writeArrowIPCChunked(builder.build(), new BufferToStreamWriter(dataOut))
      }
@@ -277,6 +296,28 @@
        if (onDataWriteFinished != null) onDataWriteFinished()
      }
    }

    private def writeEmptyIteratorOnCpu(dataOut: DataOutputStream): Unit = {
      // Most of this code is copied from Spark.
      val arrowSchema = ArrowUtilsShim.toArrowSchema(pythonInSchema, timeZoneId)
      val allocator = ArrowUtils.rootAllocator.newChildAllocator(
        s"stdout writer for empty partition", 0, Long.MaxValue)
      val root = VectorSchemaRoot.create(arrowSchema, allocator)

      Utils.tryWithSafeFinally {
        val writer = new ArrowStreamWriter(root, null, dataOut)
        writer.start()
        // No data to write
        writer.end()
        // The iterator can grab the semaphore even on an empty batch, so release it here.
        GpuSemaphore.releaseIfNecessary(TaskContext.get())
      } {
        root.close()
        allocator.close()
        if (onDataWriteFinished != null) onDataWriteFinished()
      }
    }

}
}
}
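The comment in writeIteratorToStream above comes down to an Arrow IPC detail: a stream consisting of only the schema message followed by an end-of-stream marker is still valid, which is exactly what the Java writer produces for an empty partition. Below is a minimal pyarrow sketch of that detail (illustration only, not part of this patch).

import io

import pyarrow as pa

schema = pa.schema([("a", pa.int64()), ("b", pa.string())])
sink = io.BytesIO()

# Open and close a stream writer without writing any record batches: the
# "empty partition" case. The writer still emits the schema and an EOS marker.
with pa.ipc.new_stream(sink, schema) as writer:
    pass

# The reading side (the role of the Python worker here) sees a valid, empty
# stream instead of failing with an IPC error.
reader = pa.ipc.open_stream(sink.getvalue())
print(reader.schema)
print(reader.read_all().num_rows)  # prints 0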
@@ -87,31 +87,25 @@ trait GpuMapInBatchExec extends ShimUnaryExecNode with GpuPythonExecBase {
}
}
}

-    if (pyInputIterator.hasNext) {
-      val pyRunner = new GpuArrowPythonRunnerBase(
-        chainedFunc,
-        pythonEvalType,
-        argOffsets,
-        pyInputSchema,
-        sessionLocalTimeZone,
-        pythonRunnerConf,
-        batchSize) {
-        override def toBatch(table: Table): ColumnarBatch = {
-          BatchGroupedIterator.extractChildren(table, localOutput)
-        }
-      }
-
-      pyRunner.compute(pyInputIterator, context.partitionId(), context)
-        .map { cb =>
-          numOutputBatches += 1
-          numOutputRows += cb.numRows
-          cb
-        }
-    } else {
-      // Empty partition, return it directly
-      inputIter
-    }
+    val pyRunner = new GpuArrowPythonRunnerBase(
+      chainedFunc,
+      pythonEvalType,
+      argOffsets,
+      pyInputSchema,
+      sessionLocalTimeZone,
+      pythonRunnerConf,
+      batchSize) {
+      override def toBatch(table: Table): ColumnarBatch = {
+        BatchGroupedIterator.extractChildren(table, localOutput)
+      }
+    }
+
+    pyRunner.compute(pyInputIterator, context.partitionId(), context)
+      .map { cb =>
+        numOutputBatches += 1
+        numOutputRows += cb.numRows
+        cb
+      }
} // end of mapPartitionsInternal
} // end of internalDoExecuteColumnar

@@ -39,10 +39,18 @@
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

object ArrowUtilsShim {
  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] =
    ArrowUtils.getPythonRunnerConfMap(conf)

  // errorOnDuplicatedFieldNames and largeVarTypes are not forwarded on this shim version.
  def toArrowSchema(schema: StructType, timeZoneId: String,
      errorOnDuplicatedFieldNames: Boolean = true, largeVarTypes: Boolean = false): Schema = {
    ArrowUtils.toArrowSchema(schema, timeZoneId)
  }
}
@@ -19,10 +19,19 @@
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.execution.python.ArrowPythonRunner
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

object ArrowUtilsShim {
  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] =
    ArrowPythonRunner.getPythonRunnerConfMap(conf)

  def toArrowSchema(schema: StructType, timeZoneId: String,
      errorOnDuplicatedFieldNames: Boolean = true, largeVarTypes: Boolean = false): Schema = {
    ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
  }
}
