@@ -53,7 +53,7 @@ index b386d135da1..46449e3f3f1 100644
<!--
This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
- index c595b50950b..6b60213e775 100644
+ index c595b50950b..3abb6cb9441 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -102,7 +102,7 @@ class SparkSession private(
@@ -79,7 +79,7 @@ index c595b50950b..6b60213e775 100644
}

+ private def loadCometExtension(sparkContext: SparkContext): Seq[String] = {
- + if (sparkContext.getConf.getBoolean("spark.comet.enabled", false)) {
+ + if (sparkContext.getConf.getBoolean("spark.comet.enabled", isCometEnabled)) {
+ Seq("org.apache.comet.CometSparkSessionExtensions")
+ } else {
+ Seq.empty
@@ -100,6 +100,19 @@ index c595b50950b..6b60213e775 100644
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
val extensionConf = extensionConfClass.getConstructor().newInstance()
+ @@ -1323,4 +1333,12 @@ object SparkSession extends Logging {
+ }
+ }
+ }
+ +
+ + /**
+ + * Whether Comet extension is enabled
+ + */
+ + def isCometEnabled: Boolean = {
+ + val v = System.getenv("ENABLE_COMET")
+ + v == null || v.toBoolean
+ + }
+ }
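
The hunk above makes Comet opt-out rather than opt-in: with ENABLE_COMET unset, isCometEnabled returns true, and loadCometExtension therefore injects the extension by default. A minimal standalone sketch of the same gate (the object name is hypothetical, not part of the patch):

    // Sketch of the ENABLE_COMET gate added to object SparkSession above.
    // Unset => enabled; "true"/"false" (case-insensitive) => that value;
    // any other value makes String.toBoolean throw IllegalArgumentException.
    object CometGateSketch {
      def isCometEnabled: Boolean = {
        val v = System.getenv("ENABLE_COMET")
        v == null || v.toBoolean
      }

      def main(args: Array[String]): Unit =
        println(s"Comet enabled: $isCometEnabled")
    }
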
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index db587dd9868..aac7295a53d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -957,6 +970,37 @@ index 525d97e4998..8a3e7457618 100644
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
}
+ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+ index 48ad10992c5..51d1ee65422 100644
+ --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+ +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+ @@ -221,6 +221,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
+ withSession(extensions) { session =>
+ session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true)
+ session.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
+ + // https://github.com/apache/datafusion-comet/issues/1197
+ + session.conf.set("spark.comet.enabled", false)
+ assert(session.sessionState.columnarRules.contains(
+ MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
+ import session.sqlContext.implicits._
+ @@ -279,6 +281,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
+ }
+ withSession(extensions) { session =>
+ session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE)
+ + // https://github.com/apache/datafusion-comet/issues/1197
+ + session.conf.set("spark.comet.enabled", false)
+ assert(session.sessionState.columnarRules.contains(
+ MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
+ import session.sqlContext.implicits._
+ @@ -317,6 +321,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
+ val session = SparkSession.builder()
+ .master("local[1]")
+ .config(COLUMN_BATCH_SIZE.key, 2)
+ + // https://github.com/apache/datafusion-comet/issues/1197
+ + .config("spark.comet.enabled", false)
+ .withExtensions { extensions =>
+ extensions.injectColumnar(session =>
+ MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) }
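
All three hunks above apply the same workaround: SparkSessionExtensionSuite installs its own columnar rules and asserts they are the ones that run, so Comet's rules must be kept out of the session (see the linked datafusion-comet issue #1197). A hedged sketch of the per-session kill switch, outside the suite:

    // Sketch: a local session with Comet explicitly disabled; the
    // "spark.comet.enabled" conf overrides the ENABLE_COMET default.
    import org.apache.spark.sql.SparkSession

    val session = SparkSession.builder()
      .master("local[1]")
      .config("spark.comet.enabled", "false")
      .getOrCreate()
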
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 75eabcb96f2..36e3318ad7e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -2746,7 +2790,7 @@ index abe606ad9c1..2d930b64cca 100644
val tblTargetName = "tbl_target"
val tblSourceQualified = s"default.$tblSourceName"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
- index dd55fcfe42c..aa9b0be8e68 100644
+ index dd55fcfe42c..2702f87c1f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -2770,17 +2814,14 @@ index dd55fcfe42c..aa9b0be8e68 100644
}
}

- @@ -242,6 +247,41 @@ private[sql] trait SQLTestUtilsBase
+ @@ -242,6 +247,38 @@ private[sql] trait SQLTestUtilsBase
protected override def _sqlContext: SQLContext = self.spark.sqlContext
}

+ /**
+ * Whether Comet extension is enabled
+ */
- + protected def isCometEnabled: Boolean = {
- + val v = System.getenv("ENABLE_COMET")
- + v != null && v.toBoolean
- + }
+ + protected def isCometEnabled: Boolean = SparkSession.isCometEnabled
+
+ /**
+ * Whether to enable ansi mode This is only effective when
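
Note the behavioral change in this hunk: the old inline helper defaulted to off (v != null && v.toBoolean), while delegating to SparkSession.isCometEnabled defaults to on and keeps a single source of truth for the flag. A sketch of how a suite mixing in SQLTestUtilsBase might use it (the helper below is hypothetical, not part of the patch):

    // Sketch: skip a test that only covers the Spark-only code path
    // whenever Comet is active, using the shared flag above.
    def testIgnoringComet(name: String)(body: => Unit): Unit =
      test(name) {
        assume(!isCometEnabled, "test exercises a Spark-only code path")
        body
      }
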
@@ -2812,7 +2853,7 @@ index dd55fcfe42c..aa9b0be8e68 100644
protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
SparkSession.setActiveSession(spark)
super.withSQLConf(pairs: _*)(f)
- @@ -434,6 +474,8 @@ private[sql] trait SQLTestUtilsBase
+ @@ -434,6 +471,8 @@ private[sql] trait SQLTestUtilsBase
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
@@ -2910,10 +2951,10 @@ index 1966e1e64fd..cde97a0aafe 100644
spark.sql(
"""
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
- index 07361cfdce9..6673c141c9a 100644
+ index 07361cfdce9..e40c59a4207 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
- @@ -55,25 +55,53 @@ object TestHive
+ @@ -55,25 +55,52 @@ object TestHive
new SparkContext(
System.getProperty("spark.sql.test.master", "local[1]"),
"TestSQLContext",
@@ -2955,8 +2996,7 @@ index 07361cfdce9..6673c141c9a 100644
+ // ConstantPropagation etc.
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+
- + val v = System.getenv("ENABLE_COMET")
- + if (v != null && v.toBoolean) {
+ + if (SparkSession.isCometEnabled) {
+ conf
+ .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
+ .set("spark.comet.enabled", "true")