From f69f791ef8787bdcd6bcd2e5cd8e33cde36be5de Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sat, 16 Dec 2023 18:50:54 +0800 Subject: [PATCH] [SPARK-45795][SQL] DS V2 supports push down Mode ### What changes were proposed in this pull request? This PR will translate the aggregate function `MODE` for pushdown. The constructor of the aggregate function `MODE` has a `deterministic` parameter. When multiple values have the same greatest frequency, then any of those values is returned if `deterministic` is false or not defined, or the lowest value is returned if `deterministic` is true. If `deterministic` is true, the semantics are the same as those of the syntax supported by some databases (e.g. H2, Postgres) shown below. The syntax is: `MODE() WITHIN GROUP (ORDER BY col)`. Note: `MODE() WITHIN GROUP (ORDER BY col)` doesn't support the `DISTINCT` keyword. ### Why are the changes needed? DS V2 supports push down `Mode`. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New test cases. ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #43661 from beliefer/SPARK-45795. 
Authored-by: Jiaan Geng Signed-off-by: Jiaan Geng --- .../aggregate/GeneralAggregateFunc.java | 1 + .../catalyst/util/V2ExpressionBuilder.scala | 4 ++ .../org/apache/spark/sql/jdbc/H2Dialect.scala | 16 ++++- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 64 +++++++++++++++++++ 4 files changed, 83 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java index 4ef5b7f97e926..4d787eaf9644a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -41,6 +41,7 @@ *
  • REGR_R2(input1, input2)
    Since 3.4.0
  • *
  • REGR_SLOPE(input1, input2)
    Since 3.4.0
  • *
  • REGR_SXY(input1, input2)
    Since 3.4.0
  • + *
  • MODE(input1, reverse)
    Since 4.0.0
  • * * * @since 3.3.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 4a8965a6413fc..2766bbaa88805 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -345,6 +345,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { Some(new GeneralAggregateFunc("REGR_SLOPE", isDistinct, Array(left, right))) case aggregate.RegrSXY(PushableExpression(left), PushableExpression(right)) => Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right))) + // Translate Mode if it is deterministic or reverse is defined. + case aggregate.Mode(PushableExpression(expr), _, _, Some(reverse)) => + Some(new GeneralAggregateFunc("MODE", isDistinct, + Array(expr, LiteralValue(reverse, BooleanType)))) // TODO supports other aggregate functions case aggregate.V2Aggregator(aggrFunc, children, _, _) => val translatedExprs = children.flatMap(PushableExpression.unapply(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index a42fe989b15c7..d275f9c9cb1b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -42,7 +42,8 @@ private[sql] object H2Dialect extends JdbcDialect { url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2") private val distinctUnsupportedAggregateFunctions = - Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY") + Set("COVAR_POP", "COVAR_SAMP", "CORR", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY", + "MODE") private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG", "VAR_POP", "VAR_SAMP", "STDDEV_POP", 
"STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions @@ -256,7 +257,18 @@ private[sql] object H2Dialect extends JdbcDialect { throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " + s"support aggregate function: $funcName with DISTINCT") } else { - super.visitAggregateFunction(funcName, isDistinct, inputs) + funcName match { + case "MODE" => + // Support Mode only if it is deterministic or reverse is defined. + assert(inputs.length == 2) + if (inputs.last == "true") { + s"MODE() WITHIN GROUP (ORDER BY ${inputs.head})" + } else { + s"MODE() WITHIN GROUP (ORDER BY ${inputs.head} DESC)" + } + case _ => + super.visitAggregateFunction(funcName, isDistinct, inputs) + } } override def visitExtract(field: String, source: String): String = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index a81501127a484..0a66680edd639 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -2424,6 +2424,70 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df4, Seq(Row(1100.0, 1100.0), Row(1200.0, 1200.0), Row(1250.0, 1250.0))) } + test("scan with aggregate push-down: MODE with filter and group by") { + val df1 = sql( + """ + |SELECT + | dept, + | MODE(salary, true) + |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin) + checkFiltersRemoved(df1) + checkAggregateRemoved(df1) + checkPushedInfo(df1, + """ + |PushedAggregates: [MODE(SALARY, true)], + |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], + |PushedGroupByExpressions: [DEPT], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df1, Seq(Row(1, 9000.00), Row(2, 10000.00), Row(6, 12000.00))) + + val df2 = sql( + """ + |SELECT + | dept, + | MODE(salary) + |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin) + 
checkFiltersRemoved(df2) + checkAggregateRemoved(df2, false) + checkPushedInfo(df2, + """ + |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df2, Seq(Row(1, 10000.00), Row(2, 10000.00), Row(6, 12000.00))) + + val df3 = sql( + """ + |SELECT + | dept, + | MODE() WITHIN GROUP (ORDER BY salary) + |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin) + checkFiltersRemoved(df3) + checkAggregateRemoved(df3) + checkPushedInfo(df3, + """ + |PushedAggregates: [MODE(SALARY, true)], + |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], + |PushedGroupByExpressions: [DEPT], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df3, Seq(Row(1, 9000.00), Row(2, 10000.00), Row(6, 12000.00))) + + val df4 = sql( + """ + |SELECT + | dept, + | MODE() WITHIN GROUP (ORDER BY salary DESC) + |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin) + checkFiltersRemoved(df4) + checkAggregateRemoved(df4) + checkPushedInfo(df4, + """ + |PushedAggregates: [MODE(SALARY, false)], + |PushedFilters: [DEPT IS NOT NULL, DEPT > 0], + |PushedGroupByExpressions: [DEPT], + |""".stripMargin.replaceAll("\n", " ")) + checkAnswer(df4, Seq(Row(1, 10000.00), Row(2, 12000.00), Row(6, 12000.00))) + } + test("scan with aggregate push-down: aggregate over alias push down") { val cols = Seq("a", "b", "c", "d", "e") val df1 = sql("SELECT * FROM h2.test.employee").toDF(cols: _*)