diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 9a47491b0cca4..9716d342bb6bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership import org.apache.spark.sql.errors.DataTypeErrors.toSQLType +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec @@ -43,6 +44,29 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { private val collationNonPreservingSources = Seq("orc", "csv", "json", "text") private val allFileBasedDataSources = collationPreservingSources ++ collationNonPreservingSources + @inline + private def isSortMergeForced: Boolean = { + SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD) == -1 + } + + private def checkRightTypeOfJoinUsed(queryPlan: SparkPlan): Unit = { + assert( + collectFirst(queryPlan) { + case _: SortMergeJoinExec => assert(isSortMergeForced) + case _: HashJoin => assert(!isSortMergeForced) + }.nonEmpty + ) + } + + private def checkCollationKeyInQueryPlan(queryPlan: SparkPlan, collationName: String): Unit = { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(collationName).supportsBinaryEquality) { + assert(queryPlan.toString().contains("collationkey")) + } else { + assert(!queryPlan.toString().contains("collationkey")) + } + } + test("collate returns proper type") { Seq( "utf8_binary", @@ -1419,7 +1443,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) { val collationSetup = if (collation.isEmpty) "" else " COLLATE " + collation val supportsBinaryEquality = collation.isEmpty || collation == "UNICODE" || - CollationFactory.fetchCollation(collation).isUtf8BinaryType + CollationFactory.fetchCollation(collation).supportsBinaryEquality test(s"Group by on map containing$collationSetup strings ($codeGen)") { val tableName = "t" @@ -1589,7 +1613,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("hash join should be used for collated strings") { + test("hash join should be used for collated strings if sort merge join is not forced") { val t1 = "T_1" val t2 = "T_2" @@ -1602,47 +1626,48 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", Seq(Row("aa", 1, "AA ", 2), Row("aa", 1, "aa", 2))) ) - - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES ('${t.data1}', 1)") - sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)") + sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES ('${t.data2}', 2), ('${t.data1}', 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(collectFirst(queryPlan) { - case b: HashJoin => b.leftKeys.head - }.head.isInstanceOf[CollationKey]) - } else { - assert(!collectFirst(queryPlan) { - case b: HashJoin => b.leftKeys.head - }.head.isInstanceOf[CollationKey]) + if (isSortMergeForced) { + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) + } + else { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + assert(collectFirst(queryPlan) { + case b: HashJoin => b.leftKeys.head + }.head.isInstanceOf[CollationKey]) + } else { + assert(!collectFirst(queryPlan) { + case b: HashJoin => b.leftKeys.head + }.head.isInstanceOf[CollationKey]) + } + } } } - }) + } } - test("hash join should be used for arrays of collated strings") { + test("hash join should be used for arrays of collated strings if sort merge join is not forced") { val t1 = "T_1" val t2 = "T_2" @@ -1660,47 +1685,50 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row(Seq("aa"), 1, Seq("AA "), 2), Row(Seq("aa"), 1, Seq("aa"), 2))) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x ARRAY, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (array('${t.data1}'), 1)") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x ARRAY, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (array('${t.data1}'), 1)") - sql(s"CREATE TABLE $t2 (y ARRAY, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (array('${t.data2}'), 2), (array('${t.data1}'), 2)") + sql(s"CREATE TABLE $t2 (y ARRAY, j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (array('${t.data2}'), 2), (array('${t.data1}'), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. - function.isInstanceOf[CollationKey]) - } else { - assert(!collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.isInstanceOf[ArrayTransform]) + if (isSortMergeForced) { + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) + } + else { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + assert(collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. + function.isInstanceOf[CollationKey]) + } else { + assert(!collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.isInstanceOf[ArrayTransform]) + } + } } } - }) + } } - test("hash join should be used for arrays of arrays of collated strings") { + test("hash join should be used for arrays of arrays of collated strings " + + "if sort merge join is not forced") { val t1 = "T_1" val t2 = "T_2" @@ -1718,51 +1746,53 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row(Seq(Seq("aa")), 1, Seq(Seq("AA ")), 2), Row(Seq(Seq("aa")), 1, Seq(Seq("aa")), 2))) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x ARRAY>, i int) USING " + - s"PARQUET") - sql(s"INSERT INTO $t1 VALUES (array(array('${t.data1}')), 1)") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x ARRAY>, i int) USING " + + s"PARQUET") + sql(s"INSERT INTO $t1 VALUES (array(array('${t.data1}')), 1)") - sql(s"CREATE TABLE $t2 (y ARRAY>, j int) USING " + - s"PARQUET") - sql(s"INSERT INTO $t2 VALUES (array(array('${t.data2}')), 2)," + - s" (array(array('${t.data1}')), 2)") + sql(s"CREATE TABLE $t2 (y ARRAY>, j int) USING " + + s"PARQUET") + sql(s"INSERT INTO $t2 VALUES (array(array('${t.data2}')), 2)," + + s" (array(array('${t.data1}')), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.asInstanceOf[ArrayTransform].function. - asInstanceOf[LambdaFunction].function.asInstanceOf[ArrayTransform].function. - asInstanceOf[LambdaFunction].function.isInstanceOf[CollationKey]) - } else { - assert(!collectFirst(queryPlan) { - case b: BroadcastHashJoinExec => b.leftKeys.head - }.head.isInstanceOf[ArrayTransform]) + if (isSortMergeForced) { + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) + } + else { + // Only if collation doesn't support binary equality, collation key should be injected. + if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + assert(collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.asInstanceOf[ArrayTransform].function. + asInstanceOf[LambdaFunction].function.asInstanceOf[ArrayTransform].function. + asInstanceOf[LambdaFunction].function.isInstanceOf[CollationKey]) + } else { + assert(!collectFirst(queryPlan) { + case b: BroadcastHashJoinExec => b.leftKeys.head + }.head.isInstanceOf[ArrayTransform]) + } + } } } - }) + } } - test("hash join should respect collation for struct of strings") { + test("hash and sort merge join should respect collation for struct of strings") { val t1 = "T_1" val t2 = "T_2" @@ -1779,43 +1809,36 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { HashJoinTestCase("UNICODE_CI_RTRIM", "aa", "AA ", Seq(Row(Row("aa"), 1, Row("AA "), 2), Row(Row("aa"), 1, Row("aa"), 2))) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x STRUCT, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (named_struct('f', '${t.data1}'), 1)") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x STRUCT, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (named_struct('f', '${t.data1}'), 1)") - sql(s"CREATE TABLE $t2 (y STRUCT, j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (named_struct('f', '${t.data2}'), 2)," + - s" (named_struct('f', '${t.data1}'), 2)") + sql(s"CREATE TABLE $t2 (y STRUCT, j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (named_struct('f', '${t.data2}'), 2)," + + s" (named_struct('f', '${t.data1}'), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // Confirm that hash join is used instead of sort merge join. - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) } } - }) + } } - test("hash join should respect collation for struct of array of struct of strings") { + test("hash and sort merge join should respect collation " + + "for struct of array of struct of strings") { val t1 = "T_1" val t2 = "T_2" @@ -1835,43 +1858,36 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("AA "))), 2), Row(Row(Seq(Row("aa"))), 1, Row(Seq(Row("aa"))), 2))) ) - testCases.foreach(t => { + + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x STRUCT>>, " + - s"i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (named_struct('f', array(named_struct('f', '${t.data1}'))), 1)" - ) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x STRUCT>>, " + + s"i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (named_struct('f', array(named_struct('f', " + + s"'${t.data1}'))), 1)") - sql(s"CREATE TABLE $t2 (y STRUCT>>, " + - s"j int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (named_struct('f', array(named_struct('f', '${t.data2}'))), 2)" - + s", (named_struct('f', array(named_struct('f', '${t.data1}'))), 2)") + sql(s"CREATE TABLE $t2 (y STRUCT>>, " + + s"j int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (named_struct('f', array(named_struct('f', " + + s"'${t.data2}'))), 2), (named_struct('f', array(named_struct('f', '${t.data1}'))), 2)") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: ShuffledJoin => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) - // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { - assert(queryPlan.toString().contains("collationkey")) - } else { - assert(!queryPlan.toString().contains("collationkey")) + // Confirm proper injection of collation key. + checkCollationKeyInQueryPlan(queryPlan, t.collation) } } - }) + } } test("rewrite with collationkey should be a non-excludable rule") { @@ -1931,31 +1947,27 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { "'a', 'a', 1", "'A', 'A ', 1", Row("a", "a", 1, "A", "A ", 1)) ) - testCases.foreach(t => { + for { + t <- testCases + broadcastJoinThreshold <- Seq(-1, SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) + } { withTable(t1, t2) { - sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") - sql(s"INSERT INTO $t1 VALUES (${t.data1})") - sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") - sql(s"INSERT INTO $t2 VALUES (${t.data2})") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastJoinThreshold.toString) { + sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") + sql(s"INSERT INTO $t1 VALUES (${t.data1})") + sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET") + sql(s"INSERT INTO $t2 VALUES (${t.data2})") - val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y") - checkAnswer(df, t.result) + val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y") + checkAnswer(df, t.result) - val queryPlan = df.queryExecution.executedPlan + val queryPlan = df.queryExecution.executedPlan - // confirm that hash join is used instead of sort merge join - assert( - collectFirst(queryPlan) { - case _: HashJoin => () - }.nonEmpty - ) - assert( - collectFirst(queryPlan) { - case _: SortMergeJoinExec => () - }.isEmpty - ) + // confirm that right kind of join is used. + checkRightTypeOfJoinUsed(queryPlan) + } } - }) + } } test("hll sketch aggregate should respect collation") {