Skip to content

Commit

Permalink
[SPARK-45795][SQL] DS V2 supports push down Mode
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR will translate the aggregate function `MODE` for pushdown.

The constructor of the aggregate function `MODE` has a `deterministic` parameter. When multiple values share the same greatest frequency, any one of those values is returned if `deterministic` is false or undefined, whereas the lowest value is returned if `deterministic` is true.
If `deterministic` is true, the semantics of `deterministic` match the `MODE() WITHIN GROUP` syntax supported by some databases (e.g. H2, Postgres), shown below.
The syntax is:
`MODE() WITHIN GROUP (ORDER BY col)`.

Note: `MODE() WITHIN GROUP (ORDER BY col)` doesn't support `DISTINCT` keyword.

### Why are the changes needed?
DS V2 supports push down `Mode`

### Does this PR introduce _any_ user-facing change?
'No'.
New feature.

### How was this patch tested?
New test cases.

### Was this patch authored or co-authored using generative AI tooling?
'No'.

Closes #43661 from beliefer/SPARK-45795.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Jiaan Geng <beliefer@163.com>
  • Loading branch information
beliefer committed Dec 16, 2023
1 parent ac935f5 commit f69f791
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
* <li><pre>REGR_R2(input1, input2)</pre> Since 3.4.0</li>
* <li><pre>REGR_SLOPE(input1, input2)</pre> Since 3.4.0</li>
* <li><pre>REGR_SXY(input1, input2)</pre> Since 3.4.0</li>
* <li><pre>MODE(input1[, inverse])</pre> Since 4.0.0</li>
* </ol>
*
* @since 3.3.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
Some(new GeneralAggregateFunc("REGR_SLOPE", isDistinct, Array(left, right)))
case aggregate.RegrSXY(PushableExpression(left), PushableExpression(right)) =>
Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right)))
// Translate Mode if it is deterministic or reverse is defined.
case aggregate.Mode(PushableExpression(expr), _, _, Some(reverse)) =>
Some(new GeneralAggregateFunc("MODE", isDistinct,
Array(expr, LiteralValue(reverse, BooleanType))))
// TODO supports other aggregate functions
case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
val translatedExprs = children.flatMap(PushableExpression.unapply(_))
Expand Down
16 changes: 14 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ private[sql] object H2Dialect extends JdbcDialect {
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")

private val distinctUnsupportedAggregateFunctions =
Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")
Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY",
"MODE")

private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
"VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions
Expand Down Expand Up @@ -256,7 +257,18 @@ private[sql] object H2Dialect extends JdbcDialect {
throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
s"support aggregate function: $funcName with DISTINCT")
} else {
super.visitAggregateFunction(funcName, isDistinct, inputs)
funcName match {
case "MODE" =>
// Support Mode only if it is deterministic or reverse is defined.
assert(inputs.length == 2)
if (inputs.last == "true") {
s"MODE() WITHIN GROUP (ORDER BY ${inputs.head})"
} else {
s"MODE() WITHIN GROUP (ORDER BY ${inputs.head} DESC)"
}
case _ =>
super.visitAggregateFunction(funcName, isDistinct, inputs)
}
}

override def visitExtract(field: String, source: String): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,70 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df4, Seq(Row(1100.0, 1100.0), Row(1200.0, 1200.0), Row(1250.0, 1250.0)))
}

test("scan with aggregate push-down: MODE with filter and group by") {
  // Case 1: MODE(col, true) is the deterministic form, so the whole aggregate
  // is pushed down to H2 (PushedAggregates shows MODE(SALARY, true)).
  val deterministicDf = sql(
    """
      |SELECT
      | dept,
      | MODE(salary, true)
      |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
  checkFiltersRemoved(deterministicDf)
  checkAggregateRemoved(deterministicDf)
  checkPushedInfo(deterministicDf,
    """
      |PushedAggregates: [MODE(SALARY, true)],
      |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
      |PushedGroupByExpressions: [DEPT],
      |""".stripMargin.replaceAll("\n", " "))
  checkAnswer(deterministicDf, Seq(Row(1, 9000.00), Row(2, 10000.00), Row(6, 12000.00)))

  // Case 2: MODE(col) without the deterministic flag is not pushed down; only
  // the filter is pushed, and Spark evaluates the aggregate itself.
  val nonDeterministicDf = sql(
    """
      |SELECT
      | dept,
      | MODE(salary)
      |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
  checkFiltersRemoved(nonDeterministicDf)
  checkAggregateRemoved(nonDeterministicDf, false)
  checkPushedInfo(nonDeterministicDf,
    """
      |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
      |""".stripMargin.replaceAll("\n", " "))
  checkAnswer(nonDeterministicDf, Seq(Row(1, 10000.00), Row(2, 10000.00), Row(6, 12000.00)))

  // Case 3: the explicit WITHIN GROUP (ORDER BY col) syntax is translated to
  // the deterministic form and pushed down as MODE(SALARY, true).
  val withinGroupAscDf = sql(
    """
      |SELECT
      | dept,
      | MODE() WITHIN GROUP (ORDER BY salary)
      |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
  checkFiltersRemoved(withinGroupAscDf)
  checkAggregateRemoved(withinGroupAscDf)
  checkPushedInfo(withinGroupAscDf,
    """
      |PushedAggregates: [MODE(SALARY, true)],
      |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
      |PushedGroupByExpressions: [DEPT],
      |""".stripMargin.replaceAll("\n", " "))
  checkAnswer(withinGroupAscDf, Seq(Row(1, 9000.00), Row(2, 10000.00), Row(6, 12000.00)))

  // Case 4: a descending ORDER BY is translated to MODE(SALARY, false) and is
  // still pushed down; note the answers differ from the ascending case.
  val withinGroupDescDf = sql(
    """
      |SELECT
      | dept,
      | MODE() WITHIN GROUP (ORDER BY salary DESC)
      |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
  checkFiltersRemoved(withinGroupDescDf)
  checkAggregateRemoved(withinGroupDescDf)
  checkPushedInfo(withinGroupDescDf,
    """
      |PushedAggregates: [MODE(SALARY, false)],
      |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
      |PushedGroupByExpressions: [DEPT],
      |""".stripMargin.replaceAll("\n", " "))
  checkAnswer(withinGroupDescDf, Seq(Row(1, 10000.00), Row(2, 12000.00), Row(6, 12000.00)))
}

test("scan with aggregate push-down: aggregate over alias push down") {
val cols = Seq("a", "b", "c", "d", "e")
val df1 = sql("SELECT * FROM h2.test.employee").toDF(cols: _*)
Expand Down

0 comments on commit f69f791

Please sign in to comment.