diff --git a/core/src/test/scala/org/apache/spark/CheckErrorHelper.scala b/core/src/test/scala/org/apache/spark/CheckErrorHelper.scala new file mode 100644 index 0000000000000..d01600bb439f1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/CheckErrorHelper.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ + +import org.scalatest.Suite + +trait CheckErrorHelper { self: Suite => + + case class ExpectedContext( + contextType: QueryContextType, + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String, + callSitePattern: String + ) + + object ExpectedContext { + def apply(fragment: String, start: Int, stop: Int): ExpectedContext = { + ExpectedContext("", "", start, stop, fragment) + } + + // Check the fragment only. 
This is only used when the fragment is distinguished within + // the query text + def apply(fragment: String): ExpectedContext = { + ExpectedContext("", "", -1, -1, fragment) + } + + def apply( + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String): ExpectedContext = { + new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex, + fragment, "") + } + + def apply(fragment: String, callSitePattern: String): ExpectedContext = { + new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, fragment, callSitePattern) + } + } + + /** + * Parameter keys that are omitted from comparison when absent from the expected map. + * For each error condition, the set lists keys that are removed from the actual + * exception parameters before comparison with the expected map. + * Test suites may override this to add or change ignorable parameters per condition. + */ + protected def checkErrorIgnorableParameters: Map[String, Set[String]] = Map( + "TABLE_OR_VIEW_NOT_FOUND" -> Set("searchPath") + ) + + /** + * Checks an exception with an error condition against expected results. + * @param exception The exception to check + * @param condition The expected error condition identifying the error + * @param sqlState Optional the expected SQLSTATE, not verified if not supplied + * @param parameters A map of parameter names and values. The names are as defined + * in the error-classes file. + * @param matchPVals Optionally treat the parameters value as regular expression pattern. + * false if not supplied. 
+ */ + protected def checkError( + exception: SparkThrowable, + condition: String, + sqlState: Option[String] = None, + parameters: Map[String, String] = Map.empty, + matchPVals: Boolean = false, + queryContext: Array[ExpectedContext] = Array.empty): Unit = { + val mismatches = new ListBuffer[String] + + if (exception.getCondition != condition) { + mismatches += s"condition: expected '$condition' but got '${exception.getCondition}'" + } + sqlState.foreach { state => + if (exception.getSqlState != state) { + mismatches += s"sqlState: expected '$state' but got '${exception.getSqlState}'" + } + } + + val actualParameters = exception.getMessageParameters.asScala + val ignorable = checkErrorIgnorableParameters.getOrElse(condition, Set.empty[String]) + val actualParametersToCompare = actualParameters.filter { case (k, _) => + !ignorable.contains(k) || parameters.contains(k) + } + if (matchPVals) { + if (actualParametersToCompare.size != parameters.size) { + mismatches += s"parameters size: expected ${parameters.size} but got" + + s" ${actualParametersToCompare.size}" + } + actualParametersToCompare.foreach { case (key, actualVal) => + parameters.get(key) match { + case None => + mismatches += s"parameters: unexpected key '$key' with value '$actualVal'" + case Some(pattern) if !actualVal.matches(pattern) => + mismatches += s"parameters['$key']: value '$actualVal' does not match pattern" + + s" '$pattern'" + case _ => + } + } + parameters.keys.filterNot(actualParametersToCompare.contains).foreach { key => + mismatches += s"parameters: missing expected key '$key'" + } + } else if (actualParametersToCompare != parameters) { + mismatches += s"parameters: expected $parameters but got $actualParametersToCompare" + } + + val actualQueryContext = exception.getQueryContext() + if (actualQueryContext.length != queryContext.length) { + mismatches += s"queryContext.length: expected ${queryContext.length}" + + s" but got ${actualQueryContext.length}" + } + 
actualQueryContext.zip(queryContext).zipWithIndex.foreach { + case ((actual, expected), idx) => + if (actual.contextType() != expected.contextType) { + mismatches += s"queryContext[$idx].contextType: expected ${expected.contextType}" + + s" but got ${actual.contextType()}" + } + if (actual.contextType() == QueryContextType.SQL) { + if (actual.objectType() != expected.objectType) { + mismatches += s"queryContext[$idx].objectType: expected '${expected.objectType}'" + + s" but got '${actual.objectType()}'" + } + if (actual.objectName() != expected.objectName) { + mismatches += s"queryContext[$idx].objectName: expected '${expected.objectName}'" + + s" but got '${actual.objectName()}'" + } + // If startIndex and stopIndex are -1, it means we simply want to check the + // fragment of the query context. This should be the case when the fragment is + // distinguished within the query text. + if (expected.startIndex != -1 && actual.startIndex() != expected.startIndex) { + mismatches += s"queryContext[$idx].startIndex: expected ${expected.startIndex}" + + s" but got ${actual.startIndex()}" + } + if (expected.stopIndex != -1 && actual.stopIndex() != expected.stopIndex) { + mismatches += s"queryContext[$idx].stopIndex: expected ${expected.stopIndex}" + + s" but got ${actual.stopIndex()}" + } + if (actual.fragment() != expected.fragment) { + mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" + + s" but got '${actual.fragment()}'" + } + } else if (actual.contextType() == QueryContextType.DataFrame) { + if (actual.fragment() != expected.fragment) { + mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" + + s" but got '${actual.fragment()}'" + } + if (expected.callSitePattern.nonEmpty && + !actual.callSite().matches(expected.callSitePattern)) { + mismatches += s"queryContext[$idx].callSite: '${actual.callSite()}'" + + s" does not match pattern '${expected.callSitePattern}'" + } + } + } + + if (mismatches.nonEmpty) { + val sb = 
new StringBuilder + sb.append(s"checkError found ${mismatches.size} mismatch(es).\n\n") + sb.append("=== Actual Exception State ===\n") + sb.append(s" condition: ${exception.getCondition}\n") + sb.append(s" sqlState: ${exception.getSqlState}\n") + sb.append(s" parameters:\n") + if (actualParameters.isEmpty) { + sb.append(" (empty)\n") + } else { + actualParameters.foreach { case (k, v) => sb.append(s" $k -> $v\n") } + } + actualQueryContext.zipWithIndex.foreach { case (ctx, idx) => + sb.append(s" queryContext[$idx] (${ctx.contextType()}):\n") + if (ctx.contextType() == QueryContextType.SQL) { + sb.append(s" objectType: ${ctx.objectType()}\n") + sb.append(s" objectName: ${ctx.objectName()}\n") + sb.append(s" startIndex: ${ctx.startIndex()}\n") + sb.append(s" stopIndex: ${ctx.stopIndex()}\n") + sb.append(s" fragment: ${ctx.fragment()}\n") + } else if (ctx.contextType() == QueryContextType.DataFrame) { + sb.append(s" fragment: ${ctx.fragment()}\n") + sb.append(s" callSite: ${ctx.callSite()}\n") + } + } + sb.append("\n=== Mismatches ===\n") + mismatches.foreach(m => sb.append(s" $m\n")) + fail(sb.toString()) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/SparkTestSuite.scala b/core/src/test/scala/org/apache/spark/SparkTestSuite.scala index 10504684be9fd..0fd595bf3fdf3 100644 --- a/core/src/test/scala/org/apache/spark/SparkTestSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkTestSuite.scala @@ -22,8 +22,7 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Path} import java.util.{Locale, TimeZone} -import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.jdk.CollectionConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.logging.log4j._ import org.apache.logging.log4j.core.{LogEvent, Logger, LoggerContext} @@ -70,6 +69,7 @@ trait SparkTestSuite with BeforeAndAfterEach with ThreadAudit with TimeLimits + with CheckErrorHelper with Logging { // scalastyle:on @@ -274,150 +274,6 @@ trait 
SparkTestSuite } } - /** - * Parameter keys that are omitted from comparison when absent from the expected map. - * For each error condition, the set lists keys that are removed from the actual - * exception parameters before comparison with the expected map. - * Test suites may override this to add or change ignorable parameters per condition. - */ - protected def checkErrorIgnorableParameters: Map[String, Set[String]] = Map( - "TABLE_OR_VIEW_NOT_FOUND" -> Set("searchPath") - ) - - /** - * Checks an exception with an error condition against expected results. - * @param exception The exception to check - * @param condition The expected error condition identifying the error - * @param sqlState Optional the expected SQLSTATE, not verified if not supplied - * @param parameters A map of parameter names and values. The names are as defined - * in the error-classes file. - * @param matchPVals Optionally treat the parameters value as regular expression pattern. - * false if not supplied. - */ - protected def checkError( - exception: SparkThrowable, - condition: String, - sqlState: Option[String] = None, - parameters: Map[String, String] = Map.empty, - matchPVals: Boolean = false, - queryContext: Array[ExpectedContext] = Array.empty): Unit = { - val mismatches = new ListBuffer[String] - - if (exception.getCondition != condition) { - mismatches += s"condition: expected '$condition' but got '${exception.getCondition}'" - } - sqlState.foreach { state => - if (exception.getSqlState != state) { - mismatches += s"sqlState: expected '$state' but got '${exception.getSqlState}'" - } - } - - val actualParameters = exception.getMessageParameters.asScala - val ignorable = checkErrorIgnorableParameters.getOrElse(condition, Set.empty[String]) - val actualParametersToCompare = actualParameters.filter { case (k, _) => - !ignorable.contains(k) || parameters.contains(k) - } - if (matchPVals) { - if (actualParametersToCompare.size != parameters.size) { - mismatches += s"parameters size: 
expected ${parameters.size} but got" + - s" ${actualParametersToCompare.size}" - } - actualParametersToCompare.foreach { case (key, actualVal) => - parameters.get(key) match { - case None => - mismatches += s"parameters: unexpected key '$key' with value '$actualVal'" - case Some(pattern) if !actualVal.matches(pattern) => - mismatches += s"parameters['$key']: value '$actualVal' does not match pattern" + - s" '$pattern'" - case _ => - } - } - parameters.keys.filterNot(actualParametersToCompare.contains).foreach { key => - mismatches += s"parameters: missing expected key '$key'" - } - } else if (actualParametersToCompare != parameters) { - mismatches += s"parameters: expected $parameters but got $actualParametersToCompare" - } - - val actualQueryContext = exception.getQueryContext() - if (actualQueryContext.length != queryContext.length) { - mismatches += s"queryContext.length: expected ${queryContext.length}" + - s" but got ${actualQueryContext.length}" - } - actualQueryContext.zip(queryContext).zipWithIndex.foreach { - case ((actual, expected), idx) => - if (actual.contextType() != expected.contextType) { - mismatches += s"queryContext[$idx].contextType: expected ${expected.contextType}" + - s" but got ${actual.contextType()}" - } - if (actual.contextType() == QueryContextType.SQL) { - if (actual.objectType() != expected.objectType) { - mismatches += s"queryContext[$idx].objectType: expected '${expected.objectType}'" + - s" but got '${actual.objectType()}'" - } - if (actual.objectName() != expected.objectName) { - mismatches += s"queryContext[$idx].objectName: expected '${expected.objectName}'" + - s" but got '${actual.objectName()}'" - } - // If startIndex and stopIndex are -1, it means we simply want to check the - // fragment of the query context. This should be the case when the fragment is - // distinguished within the query text. 
- if (expected.startIndex != -1 && actual.startIndex() != expected.startIndex) { - mismatches += s"queryContext[$idx].startIndex: expected ${expected.startIndex}" + - s" but got ${actual.startIndex()}" - } - if (expected.stopIndex != -1 && actual.stopIndex() != expected.stopIndex) { - mismatches += s"queryContext[$idx].stopIndex: expected ${expected.stopIndex}" + - s" but got ${actual.stopIndex()}" - } - if (actual.fragment() != expected.fragment) { - mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" + - s" but got '${actual.fragment()}'" - } - } else if (actual.contextType() == QueryContextType.DataFrame) { - if (actual.fragment() != expected.fragment) { - mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" + - s" but got '${actual.fragment()}'" - } - if (expected.callSitePattern.nonEmpty && - !actual.callSite().matches(expected.callSitePattern)) { - mismatches += s"queryContext[$idx].callSite: '${actual.callSite()}'" + - s" does not match pattern '${expected.callSitePattern}'" - } - } - } - - if (mismatches.nonEmpty) { - val sb = new StringBuilder - sb.append(s"checkError found ${mismatches.size} mismatch(es).\n\n") - sb.append("=== Actual Exception State ===\n") - sb.append(s" condition: ${exception.getCondition}\n") - sb.append(s" sqlState: ${exception.getSqlState}\n") - sb.append(s" parameters:\n") - if (actualParameters.isEmpty) { - sb.append(" (empty)\n") - } else { - actualParameters.foreach { case (k, v) => sb.append(s" $k -> $v\n") } - } - actualQueryContext.zipWithIndex.foreach { case (ctx, idx) => - sb.append(s" queryContext[$idx] (${ctx.contextType()}):\n") - if (ctx.contextType() == QueryContextType.SQL) { - sb.append(s" objectType: ${ctx.objectType()}\n") - sb.append(s" objectName: ${ctx.objectName()}\n") - sb.append(s" startIndex: ${ctx.startIndex()}\n") - sb.append(s" stopIndex: ${ctx.stopIndex()}\n") - sb.append(s" fragment: ${ctx.fragment()}\n") - } else if (ctx.contextType() == 
QueryContextType.DataFrame) { - sb.append(s" fragment: ${ctx.fragment()}\n") - sb.append(s" callSite: ${ctx.callSite()}\n") - } - } - sb.append("\n=== Mismatches ===\n") - mismatches.foreach(m => sb.append(s" $m\n")) - fail(sb.toString()) - } - } - protected def checkError( exception: SparkThrowable, condition: String, @@ -524,42 +380,6 @@ trait SparkTestSuite condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> tableName)) - case class ExpectedContext( - contextType: QueryContextType, - objectType: String, - objectName: String, - startIndex: Int, - stopIndex: Int, - fragment: String, - callSitePattern: String - ) - - object ExpectedContext { - def apply(fragment: String, start: Int, stop: Int): ExpectedContext = { - ExpectedContext("", "", start, stop, fragment) - } - - // Check the fragment only. This is only used when the fragment is distinguished within - // the query text - def apply(fragment: String): ExpectedContext = { - ExpectedContext("", "", -1, -1, fragment) - } - - def apply( - objectType: String, - objectName: String, - startIndex: Int, - stopIndex: Int, - fragment: String): ExpectedContext = { - new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex, - fragment, "") - } - - def apply(fragment: String, callSitePattern: String): ExpectedContext = { - new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, fragment, callSitePattern) - } - } - class LogAppender(msg: String = "", maxEvents: Int = 1000) extends AbstractAppender("logAppender", null, null, true, Property.EMPTY_ARRAY) { private val _loggingEvents = new ArrayBuffer[LogEvent]() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala index c27a83b79b89f..14124b048e4f8 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala +++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala @@ -241,6 +241,29 @@ class Dataset[T] private[sql] ( // scalastyle:on println } + private[connect] def explainString(mode: String): String = { + val protoMode = mode.trim.toLowerCase(util.Locale.ROOT) match { + case "simple" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE + case "extended" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_EXTENDED + case "codegen" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_CODEGEN + case "cost" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_COST + case "formatted" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_FORMATTED + case _ => throw new IllegalArgumentException("Unsupported explain mode: " + mode) + } + sparkSession + .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN, Some(protoMode)) + .getExplain + .getExplainString + } + + private[connect] def explainString(extended: Boolean): String = if (extended) { + explainString("extended") + } else { + explainString("simple") + } + + private[connect] def explainString(): String = explainString("simple") + /** @inheritdoc */ def isLocal: Boolean = sparkSession .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/DataSourceV2DataFrameConnectSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/DataSourceV2DataFrameConnectSuite.scala index 1a31e5f8ac1a3..d3294f8af3b4a 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/DataSourceV2DataFrameConnectSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/DataSourceV2DataFrameConnectSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.connect import scala.reflect.ClassTag import org.apache.spark.SparkConf -import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession} +import org.apache.spark.sql.{classic, 
connect, SparkSession} +import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} import org.apache.spark.sql.connector.{DSv2CacheTableReadTests, DSv2IncrementallyConstructedQueryTests, DSv2RepeatedTableAccessTests, DSv2TempViewWithStoredPlanTests} import org.apache.spark.sql.connector.catalog.{CachingInMemoryTableCatalog, InMemoryTableCatalog, NullTableIdAndNullColumnIdInMemoryTableCatalog, NullTableIdInMemoryTableCatalog, TableCatalog} @@ -34,7 +35,7 @@ import org.apache.spark.sql.connector.catalog.{CachingInMemoryTableCatalog, InMe * this class only provides the Connect-specific session, catalog access, and result comparison. */ class DataSourceV2DataFrameConnectSuite - extends SparkConnectServerTest + extends SessionQueryTest with DSv2TempViewWithStoredPlanTests with DSv2RepeatedTableAccessTests with DSv2IncrementallyConstructedQueryTests @@ -53,45 +54,26 @@ class DataSourceV2DataFrameConnectSuite .set("spark.sql.catalog.nullbothidscat.copyOnLoad", "true") override protected def testPrefix: String = "[connect] " - override protected def isConnect: Boolean = true - override protected def withTestSession(fn: SparkSession => Unit): Unit = - withSession(fn) - - // Cannot use QueryTest.checkAnswer directly because it accesses df.logicalPlan, - // df.queryExecution, and df.materializedRdd, which are not available on Connect *client* - // DataFrames (they throw ConnectClientUnsupportedErrors). Note: checkAnswer IS usable from - // Connect server tests that operate on classic server-side DataFrames, but in this suite - // `df` is a Connect client DataFrame returned by session.table() / session.sql(). - // Instead, collect the rows and delegate to QueryTest.sameRows, which is the same - // value-based, order-agnostic comparison that checkAnswer uses internally. 
- override protected def checkRows(df: => DataFrame, expected: Seq[Row]): Unit = - QueryTest.sameRows(expected, df.collect().toSeq).foreach(msg => fail(msg)) + protected def getServerSession(clientSession: SparkSession): classic.SparkSession = { + val connectSession = clientSession.asInstanceOf[connect.SparkSession] + val userId = connectSession.client.userId + val sessionId = connectSession.sessionId + val key = SessionKey(userId, sessionId) + SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(key) + .get + .session + } override protected def getTableCatalog[C <: TableCatalog: ClassTag]( session: SparkSession, catalogName: String): C = { - val serverSession = getServerSession(session) - val catalog = serverSession.sessionState.catalogManager.catalog(catalogName) + val catalog = getServerSession(session).sessionState.catalogManager.catalog(catalogName) val ct = implicitly[ClassTag[C]] require( ct.runtimeClass.isInstance(catalog), s"Expected ${ct.runtimeClass.getName} but got ${catalog.getClass.getName}") catalog.asInstanceOf[C] } - - // No explicit clearCache() for cachingcat is needed here, unlike the classic suite. - // Each withSession call creates a freshly isolated SparkSession on the server side - // (via SparkConnectSessionManager.newIsolatedSession), and afterEach invalidates all - // sessions, so the CachingInMemoryTableCatalog instance is per-test. 
- override protected def withTestTableAndViews( - session: SparkSession, - table: String, - views: Seq[String] = Seq.empty)(fn: => Unit): Unit = { - try { fn } - finally { - views.foreach(v => session.sql(s"DROP VIEW IF EXISTS $v").collect()) - session.sql(s"DROP TABLE IF EXISTS $table").collect() - } - } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ExampleSessionAgnosticConnectSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ExampleSessionAgnosticConnectSuite.scala new file mode 100644 index 0000000000000..ac28afd3b2955 --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ExampleSessionAgnosticConnectSuite.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect + +import org.apache.spark.sql + +class ExampleSessionAgnosticConnectSuite + extends sql.ExampleSessionAgnosticSuite + with SessionQueryTest diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala new file mode 100644 index 0000000000000..76c2201756600 --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect + +import org.apache.spark.sql.QueryTestSuite + +/** + * Runs [[QueryTestSuite]] tests through a Connect session. + * + * This validates the `FooSuite with connect.SparkSessionBinder` pattern: the existing + * [[QueryTestSuite]] tests are inherited unchanged, but execute against a + * [[SparkSession connect.SparkSession]] instead of a classic one. 
+ */ +class QueryTestWithConnectSuite + extends QueryTestSuite + with SessionQueryTest diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SessionQueryTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SessionQueryTest.scala new file mode 100644 index 0000000000000..11a7369cfcaa1 --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SessionQueryTest.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect + +import org.apache.spark.sql + +/** + * Overrides test utils to implement 'connect variants' of suites declared in sql/core: + * {{{ + * // in sql/core + * FooSuite extends SessionQueryTest { test("") { ... } } + * + * // in sql/connect + * FooConnectSuite extends FooSuite with connect.SessionQueryTest + * }}} + * + * This trait overrides [[spark]] to use a [[SparkSession connect.SparkSession]], which executes + * via the gRPC API using an in-process connect server. + */ +trait SessionQueryTest extends sql.SessionQueryTest with SparkSessionBinder { + + /** + * Approximates [[sql.SessionQueryTest.isDfSorted]] by inspecting the explain string. 
+ */ + override def isDfSorted(df: sql.DataFrame): Boolean = df match { + case df: DataFrame => df.explainString(extended = true).contains("Sort") + case df => super.isDfSorted(df) + } + + override def sessionType: String = "connect" +} diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala new file mode 100644 index 0000000000000..24b33163c2328 --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect + +import java.util.UUID + +import org.apache.spark.{SparkEnv, SparkFunSuite} +import org.apache.spark.sql +import org.apache.spark.sql.classic +import org.apache.spark.sql.connect.client.SparkConnectClient +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.service.SparkConnectService + +/** + * Provides a [[SparkSession connect.SparkSession]] backed by an in-process gRPC server. 
+ * Extends [[sql.SparkSessionBinder sql.SparkSessionBinder]] (which creates a + * [[classic.SparkSession classic.SparkSession]] and SparkContext), then layers a Connect client + * session on top by starting the gRPC service in-process. + */ +trait SparkSessionBinder extends sql.SparkSessionBinder { self: SparkFunSuite => + + private var _connectSpark: SparkSession = _ + + protected override def spark: SparkSession = _connectSpark + + /** The underlying classic session used by the in-process server. */ + private def classicSpark: classic.SparkSession = super.spark.asInstanceOf[classic.SparkSession] + + override protected def beforeAll(): Unit = { + super.beforeAll() + // Other suites using mocks leave a mess in the global executionManager, + // shut it down so that it's cleared before starting server. + SparkConnectService.executionManager.shutdown() + val prevPort = SparkEnv.get.conf.get(Connect.CONNECT_GRPC_BINDING_PORT) + try { + // set GRPC_BINDING_PORT to 0 so that the server picks a random, freely available port. 
+ SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, 0) + SparkConnectService.start(classicSpark.sparkContext) + } finally { + SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, prevPort) + } + } + + override def beforeEach(): Unit = { + val client = SparkConnectClient + .builder() + .port(SparkConnectService.localPort) + .sessionId(UUID.randomUUID().toString) + .userId("test") + .build() + _connectSpark = SparkSession + .builder() + .client(client) + .create() + super.beforeEach() + } + + override def afterEach(): Unit = { + super.afterEach() + if (_connectSpark != null) { + _connectSpark.close() + _connectSpark = null + } + } + + override def afterAll(): Unit = { + SparkConnectService.stop() + super.afterAll() + } +} diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala new file mode 100644 index 0000000000000..d9e456c0fd706 --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Common trait for test suites and utilities that need a Connect [[SparkSession]].
 *
 * Redeclares `spark` with this package's `SparkSession` type and leaves it abstract;
 * use it together with e.g. [[SparkSessionBinder]], which supplies the session.
 */
trait SparkSessionProvider extends sql.SparkSessionProvider {
  override protected def spark: SparkSession
}
/**
 * Provides the [[checkAnswer]] helper for SQL and DataFrame API tests.
 *
 * TODO: should be moved to sql/api together with SessionQueryTestBase
 */
@Experimental
trait CheckAnswerHelper extends Assertions {

  /**
   * Runs the plan and makes sure the answer matches the expected result.
   *
   * `df` is passed by name so that an analysis failure raised while building it is caught
   * here and reported as a test failure together with the offending plan.
   *
   * @param df the DataFrame to be executed
   * @param expectedAnswer the expected result in a Seq of Rows.
   */
  protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
    val analyzedDF = try df catch {
      case ae: ExtendedAnalysisException =>
        ae.plan match {
          case Some(plan) =>
            fail(
              s"""
                 |Failed to analyze query: $ae
                 |$plan
                 |
                 |${SparkErrorUtils.stackTraceToString(ae)}
                 |""".stripMargin)
          // Without a plan there is nothing to add: propagate the original exception.
          case None => throw ae
        }
    }

    getErrorMessageInCheckAnswer(analyzedDF, expectedAnswer).foreach { errorMessage =>
      fail(errorMessage)
    }
  }

  /*
   * Note: when moving this to sql/api, implementation should stay in sql/core
   * (i.e. only have abstract decl in sql/api)
   */
  /**
   * Returns true when the logical plan of `df` contains a `Sort` node, in which case
   * row order is significant for the comparison.
   *
   * Throws a RuntimeException for anything that is not a classic DataFrame.
   */
  protected def isDfSorted(df: DataFrame): Boolean = {
    df match {
      case df: classic.DataFrame =>
        df.logicalPlan.collectFirst { case s: logical.Sort => s }.nonEmpty
      case _ => throw new RuntimeException(s"Cannot determine whether df is sorted: $df")
    }
  }

  /**
   * Renders `df` for failure messages: the query execution of a classic DataFrame,
   * `toString` for any other implementation.
   */
  private def describeQuery(df: DataFrame): String = df match {
    case classicDf: classic.DataFrame => classicDf.queryExecution.toString
    case other => other.toString
  }

  /**
   * Runs the plan and makes sure the answer matches the expected result.
   * If an exception is thrown during execution, or the contents of the DataFrame do not
   * match the expected result, an error message is returned. Otherwise None is returned.
   *
   * @param df the DataFrame to be executed
   * @param expectedAnswer the expected result in a Seq of Rows.
   */
  private def getErrorMessageInCheckAnswer(
      df: DataFrame,
      expectedAnswer: Seq[Row]): Option[String] = {
    // Only the collect() is guarded; failures of isDfSorted/sameRows below still propagate.
    val collected: Either[String, Seq[Row]] =
      try {
        Right(df.collect().toSeq)
      } catch {
        case e: Exception =>
          Left(
            s"""
               |Exception thrown while executing query:
               |${describeQuery(df)}
               |== Exception ==
               |$e
               |${SparkErrorUtils.stackTraceToString(e)}
          """.stripMargin)
      }

    collected match {
      case Left(errorMessage) => Some(errorMessage)
      case Right(sparkAnswer) =>
        sameRows(expectedAnswer, sparkAnswer, isDfSorted(df)).map { results =>
          s"""
             |Results do not match for query:
             |Timezone: ${TimeZone.getDefault}
             |Timezone Env: ${sys.env.getOrElse("TZ", "")}
             |
             |${describeQuery(df)}
             |== Results ==
             |$results
       """.stripMargin
        }
    }
  }

  /**
   * Normalizes rows for comparison and, for unsorted queries, orders them by their
   * string form so that row order does not matter.
   */
  private def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
    // Converts data to types that we can do equality comparison using Scala collections.
    // For BigDecimal type, the Scala type has a better definition of equality test (similar to
    // Java's java.math.BigDecimal.compareTo).
    // For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
    // the equality test.
    val converted: Seq[Row] = answer.map(prepareRow)
    if (isSorted) converted else converted.sortBy(_.toString())
  }

  // We need to call prepareRow recursively to handle schemas with struct types.
  private def prepareRow(row: Row): Row = {
    Row.fromSeq(row.toSeq.map {
      case null => null
      case bd: java.math.BigDecimal => BigDecimal(bd)
      // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+
      case seq: Seq[_] => seq.map {
        case b: java.lang.Byte => b.byteValue
        case s: java.lang.Short => s.shortValue
        case i: java.lang.Integer => i.intValue
        case l: java.lang.Long => l.longValue
        case f: java.lang.Float => f.floatValue
        case d: java.lang.Double => d.doubleValue
        case x => x
      }
      // Convert array to Seq for easy equality check.
      case b: Array[_] => b.toSeq
      case r: Row => prepareRow(r)
      // SPARK-51349: "null" and null had the same precedence in sorting
      case "null" => "__null_string__"
      case o => o
    })
  }

  /**
   * Builds the side-by-side "expected vs. actual" report used when results differ.
   */
  private def genError(
      expectedAnswer: Seq[Row],
      sparkAnswer: Seq[Row],
      isSorted: Boolean = false): String = {
    // Schema line of the first row; "struct<>" when there is no row or the row has no schema.
    def rowType(row: Option[Row]): String =
      row.flatMap(r => Option(r.schema)).map(_.catalogString).getOrElse("struct<>")

    val expectedColumn =
      s"== Correct Answer - ${expectedAnswer.size} ==" +:
        rowType(expectedAnswer.headOption) +:
        prepareAnswer(expectedAnswer, isSorted).map(_.toString())
    val actualColumn =
      s"== Spark Answer - ${sparkAnswer.size} ==" +:
        rowType(sparkAnswer.headOption) +:
        prepareAnswer(sparkAnswer, isSorted).map(_.toString())
    val report = SparkStringUtils.sideBySide(expectedColumn, actualColumn).mkString("\n")

    s"""
       |== Results ==
       |$report
    """.stripMargin
  }

  /**
   * Structural equality used for answers: arrays, maps, iterables, products and rows are
   * compared element-wise; floating point values are compared bitwise (with all NaNs equal).
   * The case order is significant: the first matching pair of types decides.
   */
  private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
    case (null, null) => true
    case (null, _) => false
    case (_, null) => false
    case (a: Array[_], b: Array[_]) =>
      a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r) }
    case (a: Map[_, _], b: Map[_, _]) =>
      // Keys are matched with compare() rather than ==, so lookups go through b's key set.
      a.size == b.size && a.keys.forall { aKey =>
        b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey)))
      }
    case (a: Iterable[_], b: Iterable[_]) =>
      a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r) }
    case (a: Product, b: Product) =>
      compare(a.productIterator.toSeq, b.productIterator.toSeq)
    case (a: Row, b: Row) =>
      compare(a.toSeq, b.toSeq)
    // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0.
    // in some hardware NaN can be represented with different bits, so first check for it
    case (a: Double, b: Double) =>
      a.isNaN && b.isNaN ||
        java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b)
    case (a: Float, b: Float) =>
      a.isNaN && b.isNaN ||
        java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b)
    case (a, b) => a == b
  }

  /**
   * Returns None when both answers hold the same rows (ignoring order unless `isSorted`),
   * otherwise the formatted difference report.
   */
  private def sameRows(
      expectedAnswer: Seq[Row],
      sparkAnswer: Seq[Row],
      isSorted: Boolean = false): Option[String] = {
    val matches =
      compare(prepareAnswer(expectedAnswer, isSorted), prepareAnswer(sparkAnswer, isSorted))
    if (matches) None else Some(genError(expectedAnswer, sparkAnswer, isSorted))
  }
}
/**
 * Example of a suite written against [[SessionQueryTest]]: the tests only use the
 * session-agnostic helpers (`withTable`, `withConf`, `checkAnswer`) and branch on
 * `sessionType` where the expected behaviour differs between session types.
 */
class ExampleSessionAgnosticSuite extends SessionQueryTest {

  // Route all tables of this suite to an in-memory test catalog.
  override protected def sparkConf: SparkConf = {
    val catalogClass = classOf[InMemoryPartitionTableCatalog].getName
    super.sparkConf
      .set("spark.sql.catalog.testcat", catalogClass)
      .set("spark.sql.defaultCatalog", "testcat")
  }

  test("Example classic/connect-agnostic testcase") {
    withTable("t") {
      spark.sql("CREATE TABLE t (id INT, salary INT) USING foo").collect()
      spark.sql("INSERT INTO t VALUES (1, 100)").collect()

      // Created before the schema change below.
      val beforeAlter = spark.table("t")

      spark.sql("ALTER TABLE t ADD COLUMN new_column INT").collect()
      spark.sql("INSERT INTO t VALUES (2, 200, -1)").collect()

      // Created after the schema change.
      val afterAlter = spark.table("t")
      val selfJoin = beforeAlter.join(afterAlter, beforeAlter("id") === afterAlter("id"))

      sessionType match {
        case "connect" =>
          // Connect re-resolves df1 with the new 3-column schema (id, salary, new_column).
          assert(selfJoin.columns.length == 6,
            s"Expected 6 columns (3 + 3) but got: ${selfJoin.columns.mkString(", ")}")
          checkAnswer(
            selfJoin,
            Seq(Row(1, 100, null, 1, 100, null), Row(2, 200, -1, 2, 200, -1)))
        case _ =>
          // Classic: df1 keeps its original 2-column schema (id, salary).
          assert(selfJoin.columns.length == 5,
            s"Expected 5 columns (2 + 3) but got: ${selfJoin.columns.mkString(", ")}")
          checkAnswer(
            selfJoin,
            Seq(Row(1, 100, 1, 100, null), Row(2, 200, 2, 200, -1)))
      }
    }
  }

  test("testcase that uses withConf") {
    withConf("spark.sql.charAsVarchar" -> "true") {
      withTable("t") {
        spark.sql("CREATE TABLE t(col CHAR(5)) USING foo")
        val dataTypes = spark.sql("desc t").selectExpr("data_type")
        checkAnswer(dataTypes, Seq(Row("varchar(5)")))
      }
    }
  }
}
/**
 * Provides [[withTable]], [[withView]], and [[withUserDefinedFunction]]: each runs a test
 * body and afterwards drops the named objects through `spark.sql`.
 *
 * All three route cleanup through `SparkErrorUtils.tryWithSafeFinally`, so the cleanup
 * still runs when the body throws.
 */
@Experimental
trait QueryCleanupHelper extends SparkSessionProvider with Assertions {

  /**
   * Drops the tables `tableNames` after calling `f`.
   *
   * @param tableNames names of the tables to drop (`DROP TABLE IF EXISTS`)
   * @param f the test body
   */
  protected def withTable(tableNames: String*)(f: => Unit): Unit = {
    SparkErrorUtils.tryWithSafeFinally(f) {
      tableNames.foreach { name =>
        spark.sql(s"DROP TABLE IF EXISTS $name")
      }
    }
  }

  /**
   * Drops the views `viewNames` after calling `f`.
   *
   * @param viewNames names of the views to drop (`DROP VIEW IF EXISTS`)
   * @param f the test body
   */
  protected def withView(viewNames: String*)(f: => Unit): Unit = {
    SparkErrorUtils.tryWithSafeFinally(f) {
      viewNames.foreach { name =>
        spark.sql(s"DROP VIEW IF EXISTS $name")
      }
    }
  }

  /**
   * Drops functions after calling `f`. A function is represented by
   * (functionName, isTemporary).
   *
   * After each drop, asserts through `spark.catalog.functionExists` that the function is
   * really gone.
   *
   * @param functions the (functionName, isTemporary) pairs to drop
   * @param f the test body
   */
  protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = {
    // Same cleanup wrapper as withTable/withView, replacing a plain try/finally whose
    // cleanup failure (e.g. the assertion below) would mask the failure thrown by `f`.
    SparkErrorUtils.tryWithSafeFinally(f) {
      functions.foreach { case (functionName, isTemporary) =>
        val withTemporary = if (isTemporary) "TEMPORARY" else ""
        spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName")
        assert(
          !spark.catalog.functionExists(functionName),
          s"Function $functionName should have been dropped. But, it still exists.")
      }
    }
  }
}
*/ - protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + override protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { val analyzedDF = try df catch { case ae: ExtendedAnalysisException => if (ae.plan.isDefined) { @@ -172,9 +173,15 @@ trait QueryTestBase } } - assertEmptyMissingInput(analyzedDF) + if (analyzedDF.isInstanceOf[classic.DataFrame]) { + assertEmptyMissingInput(analyzedDF) - QueryTest.checkAnswer(analyzedDF, expectedAnswer) + SQLExecution.withSQLConfPropagated(analyzedDF.sparkSession) { + analyzedDF.materializedRdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] + } + } + + super.checkAnswer(analyzedDF, expectedAnswer) } protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { @@ -202,6 +209,7 @@ trait QueryTestBase * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. * @param absTol the absolute tolerance between actual and expected answers. */ + @deprecated("rarely used", since = "4.2.0") protected def checkAggregatesWithTol(dataFrame: DataFrame, expectedAnswer: Seq[Row], absTol: Double): Unit = { @@ -216,6 +224,7 @@ trait QueryTestBase } } + @deprecated("rarely used", since = "4.2.0") protected def checkAggregatesWithTol(dataFrame: DataFrame, expectedAnswer: Row, absTol: Double): Unit = { @@ -322,25 +331,6 @@ trait QueryTestBase } } - /** - * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). - */ - protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { - try { - f - } catch { - case cause: Throwable => throw cause - } finally { - functions.foreach { case (functionName, isTemporary) => - val withTemporary = if (isTemporary) "TEMPORARY" else "" - spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") - assert( - !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), - s"Function $functionName should have been dropped. 
But, it still exists.") - } - } - } - /** * Drops temporary view `viewNames` after calling `f`. */ @@ -367,28 +357,6 @@ trait QueryTestBase } } - /** - * Drops table `tableName` after calling `f`. - */ - protected def withTable(tableNames: String*)(f: => Unit): Unit = { - Utils.tryWithSafeFinally(f) { - tableNames.foreach { name => - spark.sql(s"DROP TABLE IF EXISTS $name") - } - } - } - - /** - * Drops view `viewName` after calling `f`. - */ - protected def withView(viewNames: String*)(f: => Unit): Unit = { - Utils.tryWithSafeFinally(f)( - viewNames.foreach { name => - spark.sql(s"DROP VIEW IF EXISTS $name") - } - ) - } - /** * Drops cache `cacheName` after calling `f`. */ @@ -463,6 +431,7 @@ trait QueryTestBase /** * Restores the current catalog/database after calling `f`. */ + @deprecated("rarely used", since = "4.2.0") protected def withCurrentCatalogAndNamespace(f: => Unit): Unit = { val curCatalog = sql("select current_catalog()").head().getString(0) val curDatabase = sql("select current_database()").head().getString(0) @@ -510,6 +479,7 @@ trait QueryTestBase /** * Strip Spark-side filtering in order to check if a datasource filters rows correctly. */ + @deprecated("Classic-only method, use classic.QueryTest", since = "4.2.0") protected def stripSparkFilter(df: DataFrame): DataFrame = { val schema = df.schema val withoutFilters = df.queryExecution.executedPlan.transform { @@ -524,6 +494,7 @@ trait QueryTestBase * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier * way to construct `DataFrame` directly out of local data without relying on implicits. 
*/ + @deprecated("Classic-only method, use classic.QueryTest", since = "4.2.0") protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): classic.DataFrame = { classic.Dataset.ofRows(spark.asInstanceOf[classic.SparkSession], plan) } @@ -533,6 +504,7 @@ trait QueryTestBase * does not contain a scheme, this path will not be changed after the default * FileSystem is changed. */ + @deprecated("Classic-only method, use classic.QueryTest", since = "4.2.0") def makeQualifiedPath(path: String): URI = { val hadoopPath = new Path(path) val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) @@ -823,7 +795,18 @@ trait QueryTest extends SparkFunSuite with QueryTestBase { } } -object QueryTest extends Assertions { +@deprecated("superseded by CheckAnswerHelper", since = "4.2.0") +object QueryTest extends CheckAnswerHelper { + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. + */ + def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { + checkAnswer(df, expectedAnswer, checkToRDD = true) + } + /** * Runs the plan and makes sure the answer matches the expected result. * @@ -831,13 +814,26 @@ object QueryTest extends Assertions { * @param expectedAnswer the expected result in a Seq of Rows. * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice. 
*/ - def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean = true): Unit = { - getErrorMessageInCheckAnswer(df, expectedAnswer, checkToRDD) match { + def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean): Unit = { + if (checkToRDD) { + SQLExecution.withSQLConfPropagated(df.sparkSession) { + df.materializedRdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] + } + } + + super.checkAnswer(df, expectedAnswer) + } + + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer.asScala.toSeq) match { case Some(errorMessage) => fail(errorMessage) case None => } } + override protected def isDfSorted(df: DataFrame): Boolean = + df.logicalPlan.collectFirst { case s: logical.Sort => s }.nonEmpty + /** * Runs the plan and makes sure the answer matches the expected result. * If there was exception during the execution or the contents of the DataFrame does not @@ -1054,13 +1050,6 @@ object QueryTest extends Assertions { } } - def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = { - getErrorMessageInCheckAnswer(df, expectedAnswer.asScala.toSeq) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } - } - def withQueryExecutionsCaptured(spark: SparkSession)(thunk: => Unit): Seq[QueryExecution] = { var capturedQueryExecutions = Seq.empty[QueryExecution] @@ -1211,7 +1200,7 @@ object QueryTest extends Assertions { } -class QueryTestSuite extends test.SharedSparkSession { +class QueryTestSuite extends QueryTest with SparkSessionBinder { test("SPARK-16940: checkAnswer should raise TestFailedException for wrong results") { intercept[org.scalatest.exceptions.TestFailedException] { checkAnswer(sql("SELECT 1"), Row(2) :: Nil) @@ -1223,4 +1212,19 @@ class QueryTestSuite extends test.SharedSparkSession { "from range(2)"), Seq(Row(Row(null)), Row(Row("null")))) } + + test("checkAnswer demands correct result order for 
ordered queries") { + val e = intercept[org.scalatest.exceptions.TestFailedException] { + checkAnswer( + sql("SELECT col1 FROM VALUES 1, 2, 1, 3 ORDER BY col1"), + Seq(Row(3), Row(1), Row(1), Row(2))) + } + assert(e.getMessage().contains("Results do not match for query")) + } + + test("checkAnswer ignores result order for unordered queries") { + checkAnswer( + sql("SELECT col1 FROM VALUES 1, 2, 1, 3"), + Seq(Row(3), Row(1), Row(1), Row(2))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTest.scala new file mode 100644 index 0000000000000..a80c77ba877db --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTest.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite + +/** + * Provides connect-compatible test utils to write suites that have 'connect variants': + * {{{ + * // in sql/core + * FooSuite extends SessionQueryTest { test("") { ... 
/**
 * Provides connect-compatible test utils to write suites that have 'connect variants':
 * {{{
 *   // in sql/core
 *   FooSuite extends SessionQueryTest { test("") { ... } }
 *
 *   // in sql/connect
 *   FooConnectSuite extends connect.SessionQueryTest
 * }}}
 *
 * While this trait internally uses a [[classic.SparkSession]] when executing tests,
 * the session is exposed as a [[SparkSession sql.SparkSession]] so that it can be
 * overridden on the connect side.
 *
 * For classic-specific tests, use [[classic.SessionQueryTest]].
 */
trait SessionQueryTest extends SparkFunSuite with SessionQueryTestBase with SparkSessionBinder {

  // Suites mixing in this trait run against the classic session.
  override def sessionType: String = "classic"
}
/**
 * TODO should be moved to sql/api
 *
 * Base for tests that are fully independent of sql/core, i.e. this trait could be moved to
 * sql/api and then used in sql/connect/client.
 */
trait SessionQueryTestBase
  extends AnyFunSuite
  with SparkSessionProvider
  with CheckAnswerHelper
  with CheckErrorHelper
  with SQLConfHelper
  with QueryCleanupHelper {

  /**
   * Sets all configurations specified in `pairs`, calls `f`, and then restores all
   * configurations: keys that were set before get their previous value back, all other
   * keys are unset.
   *
   * @param pairs the (key, value) configuration entries to apply while `f` runs
   * @param f the body to run under the temporary configuration
   * @return the result of `f`
   */
  protected def withConf[T](pairs: (String, String)*)(f: => T): T = {
    // Snapshot every key first, before any of them is modified.
    val snapshot: Seq[(String, Option[String])] = pairs.map { case (key, _) =>
      key -> Option.when(spark.conf.contains(key))(spark.conf.get(key))
    }
    pairs.foreach { case (key, value) =>
      spark.conf.set(key, value)
    }
    try {
      f
    } finally {
      snapshot.foreach { case (key, previous) =>
        previous match {
          case Some(value) => spark.conf.set(key, value)
          case None => spark.conf.unset(key)
        }
      }
    }
  }

  /**
   * Names the kind of session the suite runs against, so that tests can handle and
   * document session-specific behaviour:
   *
   * {{{
   *   test(...) {
   *     val df = // query with connect-specific behaviour
   *     if (sessionType == "connect") {
   *       checkError(...)
   *     } else {
   *       checkAnswer(df, ...)
   *     }
   *   }
   * }}}
   */
  def sessionType: String
}
/**
 * Provides a [[spark]] implementation by creating a [[classic.SparkSession]].
 *
 * Counterpart to [[SparkSessionProvider]], used in [[org.apache.spark.sql.test.SharedSparkSession]]
 */
trait SparkSessionBinder extends SparkSessionBinderBase { self: SparkFunSuite =>

  /**
   * Suites extending this trait share resources (e.g. the SparkSession) between their
   * tests. The session is initialized in [[beforeAll()]], i.e. before the automatic thread
   * snapshot would be taken, so the audit could fail to report threads leaked by that
   * shared session.
   *
   * The automatic audit is therefore disabled here and the snapshot is taken manually,
   * before the spark session is initialized.
   */
  override protected val enableAutoThreadAudit: Boolean = false

  protected override def beforeAll(): Unit = {
    // Snapshot first, then let the base trait create the session.
    doThreadPreAudit()
    super.beforeAll()
  }

  protected override def afterAll(): Unit = {
    try super.afterAll()
    finally doThreadPostAudit()
  }
}

/**
 * [[SparkSessionBinderBase]] is needed for now as
 * [[test.SharedSparkSessionBase SharedSparkSessionBase]] is still used by e.g.
 * [[test.GenericWordSpecSuite]].
 *
 * This base trait might be merged into [[SparkSessionBinder]] once it is no longer required.
 *
 * TODO: migrate SharedSparkSessionBase users so this can be removed
 */
@Experimental
trait SparkSessionBinderBase
  extends SparkSessionProvider
  with BeforeAndAfterEach
  with BeforeAndAfterAll
  with Eventually { self: Suite =>

  /** The configuration used to create the test session. */
  protected def sparkConf: SparkConf = {
    val conf = new SparkConf()
    conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
    conf.set(UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
    conf.set(SQLConf.CODEGEN_FALLBACK.key, "false")
    conf.set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
    // Disable ConvertToLocalRelation for better test coverage. Test cases built on
    // LocalRelation will exercise the optimization rules better by disabling it as
    // this rule may potentially block testing of other optimization rules such as
    // ConstantPropagation etc.
    conf.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)

    // Per-suite warehouse directory below the configured warehouse path.
    val suiteWarehouse = s"${conf.get(StaticSQLConf.WAREHOUSE_PATH)}/${getClass.getCanonicalName}"
    conf.set(StaticSQLConf.WAREHOUSE_PATH, suiteWarehouse)
    conf.set(StaticSQLConf.LOAD_SESSION_EXTENSIONS_FROM_CLASSPATH, false)

    // Thread thresholds can be overridden through environment variables.
    val shuffleThreads = sys.env.getOrElse(
      "SPARK_TEST_SQL_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD",
      StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt
    conf.set(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD, shuffleThreads)

    val resultStageThreads = sys.env.getOrElse(
      "SPARK_TEST_SQL_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD",
      StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt
    conf.set(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD, resultStageThreads)
  }

  /**
   * The [[TestSparkSession]] to use for all tests in this suite.
   *
   * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
   * mode with the default test configurations.
   */
  private var _spark: classic.SparkSession = null

  protected override def spark: SparkSession = _spark

  /**
   * The [[SQLContext]] to use for all tests in this suite.
   */
  protected implicit def sqlContext: SQLContext = _spark.sqlContext

  /** Cleans up any existing session and creates a new [[TestSparkSession]]. */
  protected def createSparkSession: classic.SparkSession = {
    classic.SparkSession.cleanupAnyExistingSession()
    new TestSparkSession(sparkConf)
  }

  protected def sqlConf: SQLConf = _spark.sessionState.conf

  /**
   * Initialize the [[TestSparkSession]]. Generally, this is just called from
   * beforeAll; however, in tests using styles other than FunSuite, there is
   * often code that relies on the session between test group constructs and
   * the actual tests, which may need this session. It is purely a semantic
   * difference, but semantically, it makes more sense to call
   * 'initializeSession' between a 'describe' and an 'it' call than it does to
   * call 'beforeAll'.
   */
  protected def initializeSession(): Unit = {
    // Idempotent: an already created session is kept.
    if (_spark == null) _spark = createSparkSession
  }

  /**
   * Make sure the [[TestSparkSession]] is initialized before any tests are run.
   */
  protected override def beforeAll(): Unit = {
    // Ensure we have initialized the context before calling parent code
    initializeSession()
    super.beforeAll()
  }

  /**
   * Stop the underlying [[org.apache.spark.SparkContext]], if any.
   */
  protected override def afterAll(): Unit = {
    try {
      super.afterAll()
    } finally {
      try {
        Option(_spark).foreach { session =>
          try {
            session.sessionState.catalog.reset()
          } finally {
            // Stop the session even if resetting the catalog failed.
            session.stop()
            _spark = null
          }
        }
      } finally {
        SparkSession.clearActiveSession()
        SparkSession.clearDefaultSession()
      }
    }
  }

  protected override def beforeEach(): Unit = {
    super.beforeEach()
    DebugFilesystem.clearOpenStreams()
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Clear all persistent datasets after each test
    _spark.sharedState.cacheManager.clearCache()
    // files can be closed from other threads, so wait a bit
    // normally this doesn't take more than 1s
    eventually(timeout(10.seconds), interval(2.seconds)) {
      DebugFilesystem.assertNoOpenStreams()
    }
  }

}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import scala.language.implicitConversions + +import org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.FilterExec + +/** + * Extends [[org.apache.spark.sql.QueryTest sql.QueryTest]] to provide classic-only helpers. + */ +trait QueryTest extends sql.QueryTest with SparkSessionProvider { + + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val withoutFilters = df.queryExecution.executedPlan.transform { + case FilterExec(_, child) => child + } + + spark.internalCreateDataFrame(withoutFilters.execute(), schema) + } + + /** + * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier + * way to construct `DataFrame` directly out of local data without relying on implicits. + */ + protected implicit override def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + Dataset.ofRows(spark, plan) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SessionQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SessionQueryTest.scala new file mode 100644 index 0000000000000..d56146b05d23b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SessionQueryTest.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql + +/** + * Override of [[sql.SessionQueryTest]] that provides [[SparkSession classic.SparkSession]]. + * + * Can be used to declare classic-specific tests: + * {{{ + * class FooSuite extends sql.SessionQueryTest { + * // shared classic/connect-agnostic testcases + * } + * + * // no need to extend FooSuite as sql.SessionQueryTest + * // already executes shared tests via classic internally. + * class FooClassicSuite extends classic.SessionQueryTest { + * test("classic-only test") { + * // classic-only APIs are visible here + * spark.sessionState.conf + * } + * } + * }}} + */ +trait SessionQueryTest extends sql.SessionQueryTest with SparkSessionBinder + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala new file mode 100644 index 0000000000000..2f79876d841d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.{sql, SparkFunSuite} + +/** + * Overrides [[spark]] to provide a [[SparkSession classic.SparkSession]] + */ +trait SparkSessionBinder extends sql.SparkSessionBinder { self: SparkFunSuite => + override protected def spark: SparkSession = super.spark.asInstanceOf[SparkSession] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala new file mode 100644 index 0000000000000..77de0db4bf68b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql + +trait SparkSessionProvider extends sql.SparkSessionProvider { + override protected def spark: SparkSession +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2CacheTableReadTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2CacheTableReadTests.scala index ac6ffcc6ecc0d..79c101d524a07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2CacheTableReadTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2CacheTableReadTests.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.connector -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{CachingInMemoryTableCatalog, Column, InMemoryTableCatalog, TableChange, TableInfo} import org.apache.spark.sql.types.IntegerType @@ -49,223 +49,209 @@ import org.apache.spark.sql.types.IntegerType * (via the CacheManager), making a session drop+recreate scenario trivially different from * the external variant. * - * NOTE: All `session.sql(...)` calls append `.collect()` because Connect client DataFrames + * NOTE: All `spark.sql(...)` calls append `.collect()` because Connect client DataFrames * are lazy and require an action to trigger execution. In classic mode `.collect()` on * DDL / DML is a no-op (these execute eagerly), so this is harmless. 
*/ trait DSv2CacheTableReadTests extends DSv2ExternalMutationTestBase { - private def assertTableCached(session: SparkSession, tableName: String): Unit = - assert(session.catalog.isCached(tableName)) + private def assertTableCached(tableName: String): Unit = + assert(spark.catalog.isCached(tableName)) test(s"${testPrefix}SPARK-54022: cached table pinned against external data write") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - session.table(testTable).cache() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100))) + spark.table(testTable).cache() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100))) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") - externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") + externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100))) + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100))) - session.sql(s"REFRESH TABLE $testTable").collect() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100), Row(2, 200))) - } + spark.sql(s"REFRESH TABLE $testTable").collect() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100), Row(2, 200))) } } test(s"${testPrefix}SPARK-54022: connector w/ cache: cached table pinned, " + "REFRESH clears both layers") { - withTestSession { session => - 
withTestTableAndViews(session, cachingTestTable) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100)").collect() - - session.table(cachingTestTable).cache() - assertTableCached(session, cachingTestTable) - checkRows(session.table(cachingTestTable), Seq(Row(1, 100))) - - val catalog = - getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") - externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) - - // Both CacheManager and connector cache are stale: external write invisible - assertTableCached(session, cachingTestTable) - checkRows(session.table(cachingTestTable), Seq(Row(1, 100))) - - // REFRESH TABLE calls invalidateTable (clears connector cache) and rebuilds - // the CacheManager entry, so the external write becomes visible. - session.sql(s"REFRESH TABLE $cachingTestTable").collect() - assertTableCached(session, cachingTestTable) - checkRows(session.table(cachingTestTable), Seq(Row(1, 100), Row(2, 200))) - } + withTable(cachingTestTable) { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100)").collect() + + spark.table(cachingTestTable).cache() + assertTableCached(cachingTestTable) + checkAnswer(spark.table(cachingTestTable), Seq(Row(1, 100))) + + val catalog = + getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") + externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) + + // Both CacheManager and connector cache are stale: external write invisible + assertTableCached(cachingTestTable) + checkAnswer(spark.table(cachingTestTable), Seq(Row(1, 100))) + + // REFRESH TABLE calls invalidateTable (clears connector cache) and rebuilds + // the CacheManager entry, so the external write becomes visible. 
+ spark.sql(s"REFRESH TABLE $cachingTestTable").collect() + assertTableCached(cachingTestTable) + checkAnswer(spark.table(cachingTestTable), Seq(Row(1, 100), Row(2, 200))) } } test(s"${testPrefix}SPARK-54022: session write invalidates cache, " + "then external write invisible") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - session.table(testTable).cache() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100))) + spark.table(testTable).cache() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100))) - session.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100), Row(2, 200))) + spark.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100), Row(2, 200))) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") - externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(3, 300)) + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") + externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(3, 300)) - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100), Row(2, 200))) + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100), Row(2, 200))) - session.sql(s"REFRESH TABLE $testTable").collect() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100), Row(2, 200), Row(3, 300))) - } + spark.sql(s"REFRESH TABLE 
$testTable").collect() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100), Row(2, 200), Row(3, 300))) } } test(s"${testPrefix}SPARK-54022: cached table pinned against external schema change") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - - session.table(testTable).cache() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100))) - - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") - val addCol = TableChange.addColumn(Array("new_column"), IntegerType, true) - catalog.alterTable(testIdent, addCol) - externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200, -1)) - - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100))) - - session.sql(s"REFRESH TABLE $testTable").collect() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100, null), Row(2, 200, -1))) - } + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + + spark.table(testTable).cache() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100))) + + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") + val addCol = TableChange.addColumn(Array("new_column"), IntegerType, true) + catalog.alterTable(testIdent, addCol) + externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200, -1)) + + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100))) + + spark.sql(s"REFRESH TABLE $testTable").collect() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100, null), Row(2, 200, -1))) } } 
test(s"${testPrefix}SPARK-54022: session schema change invalidates cache, " + "external write invisible") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - session.table(testTable).cache() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100))) + spark.table(testTable).cache() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100))) - session.sql(s"ALTER TABLE $testTable ADD COLUMN new_column INT").collect() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100, null))) + spark.sql(s"ALTER TABLE $testTable ADD COLUMN new_column INT").collect() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100, null))) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") - externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200, -1)) + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") + externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200, -1)) - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100, null))) + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100, null))) - session.sql(s"REFRESH TABLE $testTable").collect() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100, null), Row(2, 200, -1))) - } + spark.sql(s"REFRESH TABLE $testTable").collect() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100, null), Row(2, 200, -1))) } } test(s"${testPrefix}SPARK-54022: cached 
table after external drop and " + "recreate sees empty table") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - - session.table(testTable).cache() - assertTableCached(session, testTable) - checkRows(session.table(testTable), Seq(Row(1, 100))) - - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") - val originalTableId = catalog.loadTable(testIdent).id - - catalog.dropTable(testIdent) - catalog.createTable( - testIdent, - new TableInfo.Builder() - .withColumns(Array( - Column.create("id", IntegerType), - Column.create("salary", IntegerType))) - .build()) - - val newTableId = catalog.loadTable(testIdent).id - assert(originalTableId != newTableId) - - val result = session.table(testTable) - assert(result.schema.fieldNames.toSeq == Seq("id", "salary")) - checkRows(result, Seq.empty) - - // External drop+recreate produces a new table identity, so the prior cache entry - // is unreachable via name lookup (unlike external write/schema change where the - // cache stays pinned). 
- assert(!session.catalog.isCached(testTable)) - - session.sql(s"REFRESH TABLE $testTable").collect() - checkRows(session.table(testTable), Seq.empty) - } + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + + spark.table(testTable).cache() + assertTableCached(testTable) + checkAnswer(spark.table(testTable), Seq(Row(1, 100))) + + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") + val originalTableId = catalog.loadTable(testIdent).id + + catalog.dropTable(testIdent) + catalog.createTable( + testIdent, + new TableInfo.Builder() + .withColumns(Array( + Column.create("id", IntegerType), + Column.create("salary", IntegerType))) + .build()) + + val newTableId = catalog.loadTable(testIdent).id + assert(originalTableId != newTableId) + + val result = spark.table(testTable) + assert(result.schema.fieldNames.toSeq == Seq("id", "salary")) + checkAnswer(result, Seq.empty) + + // External drop+recreate produces a new table identity, so the prior cache entry + // is unreachable via name lookup (unlike external write/schema change where the + // cache stays pinned). 
+ assert(!spark.catalog.isCached(testTable)) + + spark.sql(s"REFRESH TABLE $testTable").collect() + checkAnswer(spark.table(testTable), Seq.empty) } } test(s"${testPrefix}SPARK-54022: connector w/ cache: cached table stale after " + "external drop and recreate") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100)").collect() - - session.table(cachingTestTable).cache() - assertTableCached(session, cachingTestTable) - checkRows(session.table(cachingTestTable), Seq(Row(1, 100))) - - val catalog = - getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") - val originalTableId = catalog.loadTable(testIdent).id - - catalog.dropTable(testIdent) - catalog.createTable( - testIdent, - new TableInfo.Builder() - .withColumns(Array( - Column.create("id", IntegerType), - Column.create("salary", IntegerType))) - .build()) - - // CachingInMemoryTableCatalog does not invalidate on drop/create, so loadTable - // still returns the old cached table object. CacheManager still matches and - // serves the stale cached data. - assertTableCached(session, cachingTestTable) - checkRows(session.table(cachingTestTable), Seq(Row(1, 100))) - - // REFRESH TABLE calls invalidateTable (clears connector cache) and rebuilds - // the CacheManager entry, so the new empty table becomes visible. 
- session.sql(s"REFRESH TABLE $cachingTestTable").collect() - checkRows(session.table(cachingTestTable), Seq.empty) - } + withTable(cachingTestTable) { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100)").collect() + + spark.table(cachingTestTable).cache() + assertTableCached(cachingTestTable) + checkAnswer(spark.table(cachingTestTable), Seq(Row(1, 100))) + + val catalog = + getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") + val originalTableId = catalog.loadTable(testIdent).id + + catalog.dropTable(testIdent) + catalog.createTable( + testIdent, + new TableInfo.Builder() + .withColumns(Array( + Column.create("id", IntegerType), + Column.create("salary", IntegerType))) + .build()) + + // CachingInMemoryTableCatalog does not invalidate on drop/create, so loadTable + // still returns the old cached table object. CacheManager still matches and + // serves the stale cached data. + assertTableCached(cachingTestTable) + checkAnswer(spark.table(cachingTestTable), Seq(Row(1, 100))) + + // REFRESH TABLE calls invalidateTable (clears connector cache) and rebuilds + // the CacheManager entry, so the new empty table becomes visible. 
+ spark.sql(s"REFRESH TABLE $cachingTestTable").collect() + checkAnswer(spark.table(cachingTestTable), Seq.empty) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2ExternalMutationTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2ExternalMutationTestBase.scala index 0b2a50534447c..73c69f8a9de41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2ExternalMutationTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2ExternalMutationTestBase.scala @@ -21,7 +21,7 @@ import java.util import scala.reflect.ClassTag -import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession} +import org.apache.spark.sql.{SessionQueryTestBase, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{BufferedRows, CatalogV2Util, Identifier, InMemoryBaseTable, TableCatalog, TableWritePrivilege} @@ -37,7 +37,7 @@ import org.apache.spark.sql.connector.catalog.{BufferedRows, CatalogV2Util, Iden * [[DSv2TempViewWithStoredPlanTests]], [[DSv2RepeatedTableAccessTests]], * [[DSv2IncrementallyConstructedQueryTests]], or [[DSv2CacheTableReadTests]]. */ -trait DSv2ExternalMutationTestBase extends QueryTest { +trait DSv2ExternalMutationTestBase extends SessionQueryTestBase { /** Fully qualified table name under the non-caching test catalog. */ protected val testTable: String = "testcat.ns1.ns2.tbl" @@ -51,17 +51,6 @@ trait DSv2ExternalMutationTestBase extends QueryTest { /** Prefix for test names, e.g. "" or "[connect] ". */ protected def testPrefix: String - /** Whether this suite runs under Spark Connect. */ - protected def isConnect: Boolean - - /** Execute a test body with a session. */ - protected def withTestSession(fn: SparkSession => Unit): Unit - - /** - * Assert that a DataFrame's rows match the expected rows (order-agnostic). 
- */ - protected def checkRows(df: => DataFrame, expected: Seq[Row]): Unit - /** * Get a [[TableCatalog]] by name from the underlying session. */ @@ -69,12 +58,6 @@ trait DSv2ExternalMutationTestBase extends QueryTest { session: SparkSession, catalogName: String): C - /** Cleanup wrapper: drop views and the table after the test body, even on failure. */ - protected def withTestTableAndViews( - session: SparkSession, - table: String, - views: Seq[String] = Seq.empty)(fn: => Unit): Unit - /** Appends a row to a DSv2 table via the catalog API, bypassing the session. */ protected def externalAppend( catalog: TableCatalog, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2IncrementallyConstructedQueryTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2IncrementallyConstructedQueryTests.scala index 1dbaad18e3e71..a6de3a0139452 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2IncrementallyConstructedQueryTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2IncrementallyConstructedQueryTests.scala @@ -32,7 +32,7 @@ import org.apache.spark.unsafe.types.UTF8String * mode, resolution is deferred until execution, so both sides of a join always see the * latest table state. * - * NOTE: All `session.sql(...)` calls append `.collect()` because Connect client DataFrames + * NOTE: All `spark.sql(...)` calls append `.collect()` because Connect client DataFrames * are lazy and require an action to trigger execution. In classic mode `.collect()` on * eager statements (DDL, INSERT) is a no-op, so this is harmless. 
*/ @@ -45,44 +45,40 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas test(s"${testPrefix}SPARK-54157: join refreshes both sides after external insert" + " (table with both table and column ID support)") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - val df1 = session.table(testTable) + val df1 = spark.table(testTable) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) - val df2 = session.table(testTable) + val df2 = spark.table(testTable) - checkRows( + checkAnswer( df1.join(df2, df1("id") === df2("id")), Seq(Row(1, 100, 1, 100), Row(2, 200, 2, 200))) } } - } test(s"${testPrefix}SPARK-54157: join refreshes both sides after same-session insert" + " (table with both table and column ID support)") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - val df1 = session.table(testTable) + val df1 = spark.table(testTable) - session.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() + spark.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() - val df2 = session.table(testTable) + val df2 = spark.table(testTable) - checkRows( + checkAnswer( df1.join(df2, df1("id") 
=== df2("id")), Seq(Row(1, 100, 1, 100), Row(2, 200, 2, 200))) } } - } // --------------------------------------------------------------------------- // Scenario 2: join after ADD COLUMN. @@ -92,70 +88,66 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas test(s"${testPrefix}SPARK-54157: join after external ADD COLUMN" + " (table with both table and column ID support)") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - val df1 = session.table(testTable) + val df1 = spark.table(testTable) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") catalog.alterTable( testIdent, TableChange.addColumn(Array("new_column"), IntegerType, true)) externalAppend( catalog = catalog, ident = testIdent, row = InternalRow(2, 200, -1)) - val df2 = session.table(testTable) + val df2 = spark.table(testTable) val selfJoin = df1.join(df2, df1("id") === df2("id")) - if (isConnect) { + if (sessionType == "connect") { // Connect re-resolves df1 with the new 3-column schema (id, salary, new_column). assert(selfJoin.columns.length == 6, s"Expected 6 columns (3 + 3) but got: ${selfJoin.columns.mkString(", ")}") - checkRows(selfJoin, + checkAnswer(selfJoin, Seq(Row(1, 100, null, 1, 100, null), Row(2, 200, -1, 2, 200, -1))) } else { // Classic: df1 keeps its original 2-column schema (id, salary). 
assert(selfJoin.columns.length == 5, s"Expected 5 columns (2 + 3) but got: ${selfJoin.columns.mkString(", ")}") - checkRows(selfJoin, + checkAnswer(selfJoin, Seq(Row(1, 100, 1, 100, null), Row(2, 200, 2, 200, -1))) } } } - } test(s"${testPrefix}SPARK-54157: join after same-session ADD COLUMN" + " (table with both table and column ID support)") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - val df1 = session.table(testTable) + val df1 = spark.table(testTable) - session.sql(s"ALTER TABLE $testTable ADD COLUMN new_column INT").collect() - session.sql(s"INSERT INTO $testTable VALUES (2, 200, -1)").collect() + spark.sql(s"ALTER TABLE $testTable ADD COLUMN new_column INT").collect() + spark.sql(s"INSERT INTO $testTable VALUES (2, 200, -1)").collect() - val df2 = session.table(testTable) + val df2 = spark.table(testTable) val selfJoin = df1.join(df2, df1("id") === df2("id")) - if (isConnect) { + if (sessionType == "connect") { // Connect re-resolves df1 with the new 3-column schema (id, salary, new_column). assert(selfJoin.columns.length == 6, s"Expected 6 columns (3 + 3) but got: ${selfJoin.columns.mkString(", ")}") - checkRows(selfJoin, + checkAnswer(selfJoin, Seq(Row(1, 100, null, 1, 100, null), Row(2, 200, -1, 2, 200, -1))) } else { // Classic: df1 keeps its original 2-column schema (id, salary). 
assert(selfJoin.columns.length == 5, s"Expected 5 columns (2 + 3) but got: ${selfJoin.columns.mkString(", ")}") - checkRows(selfJoin, + checkAnswer(selfJoin, Seq(Row(1, 100, 1, 100, null), Row(2, 200, 2, 200, -1))) } } } - } // --------------------------------------------------------------------------- // Scenario 3: join after DROP COLUMN. @@ -165,23 +157,22 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas test(s"${testPrefix}SPARK-54157: join after external DROP COLUMN" + " (table with both table and column ID support)") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - val df1 = session.table(testTable) + val df1 = spark.table(testTable) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") catalog.alterTable( testIdent, TableChange.deleteColumn(Array("salary"), false)) externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2)) - val df2 = session.table(testTable) + val df2 = spark.table(testTable) - if (isConnect) { + if (sessionType == "connect") { // Connect re-resolves df1 without the dropped column. 
- checkRows( + checkAnswer( df1.join(df2, df1("id") === df2("id")), Seq(Row(1, 1), Row(2, 2))) } else { @@ -196,25 +187,23 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas } } } - } test(s"${testPrefix}SPARK-54157: join after same-session DROP COLUMN" + " (table with both table and column ID support)") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - val df1 = session.table(testTable) + val df1 = spark.table(testTable) - session.sql(s"ALTER TABLE $testTable DROP COLUMN salary").collect() - session.sql(s"INSERT INTO $testTable VALUES (2)").collect() + spark.sql(s"ALTER TABLE $testTable DROP COLUMN salary").collect() + spark.sql(s"INSERT INTO $testTable VALUES (2)").collect() - val df2 = session.table(testTable) + val df2 = spark.table(testTable) - if (isConnect) { + if (sessionType == "connect") { // Connect re-resolves df1 without the dropped column. - checkRows( + checkAnswer( df1.join(df2, df1("id") === df2("id")), Seq(Row(1, 1), Row(2, 2))) } else { @@ -229,7 +218,6 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas } } } - } // --------------------------------------------------------------------------- // Scenario 4: external drop and recreate table. 
@@ -240,13 +228,12 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas test(s"${testPrefix}SPARK-54157: join after external table drop and recreate" + " (table with both table and column ID support)") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - val df1 = session.table(testTable) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val df1 = spark.table(testTable) + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") val originTableId = catalog.loadTable(testIdent).id catalog.dropTable(testIdent) @@ -259,13 +246,13 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas .build()) externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) - val df2 = session.table(testTable) + val df2 = spark.table(testTable) val newTableId = catalog.loadTable(testIdent).id assert(originTableId != newTableId) - if (isConnect) { + if (sessionType == "connect") { // Connect re-resolves both sides to the recreated table. 
- checkRows( + checkAnswer( df1.join(df2, df1("id") === df2("id")), Seq(Row(2, 200, 2, 200))) } else { @@ -283,18 +270,16 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas } } } - } test(s"${testPrefix}SPARK-54157: join after external drop/recreate" + " (table without table ID support, but with column ID support)") { val nullIdT = "nullidcat.ns1.ns2.tbl" - withTestSession { session => - withTestTableAndViews(session, nullIdT) { - session.sql(s"CREATE TABLE $nullIdT (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $nullIdT VALUES (1, 100)").collect() + withTable(nullIdT) { + spark.sql(s"CREATE TABLE $nullIdT (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $nullIdT VALUES (1, 100)").collect() - val df1 = session.table(nullIdT) - val catalog = getTableCatalog[TableCatalog](session, "nullidcat") + val df1 = spark.table(nullIdT) + val catalog = getTableCatalog[TableCatalog](spark, "nullidcat") assert(catalog.loadTable(testIdent).id == null, "NullTableIdInMemoryTableCatalog should produce null table IDs") @@ -308,11 +293,11 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas .build()) externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) - val df2 = session.table(nullIdT) + val df2 = spark.table(nullIdT) - if (isConnect) { + if (sessionType == "connect") { // Connect re-resolves both sides to the recreated table. 
- checkRows( + checkAnswer( df1.join(df2, df1("id") === df2("id")), Seq(Row(2, 200, 2, 200))) } else { @@ -327,19 +312,17 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas } } } - } test(s"${testPrefix}SPARK-54157: join does not detect external table drop and recreate" + " (table without table ID support and without column ID support)") { val nullBothT = "nullbothidscat.ns1.ns2.tbl" - withTestSession { session => - withTestTableAndViews(session, nullBothT) { - session.sql(s"CREATE TABLE $nullBothT (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $nullBothT VALUES (1, 100)").collect() + withTable(nullBothT) { + spark.sql(s"CREATE TABLE $nullBothT (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $nullBothT VALUES (1, 100)").collect() - val df1 = session.table(nullBothT) + val df1 = spark.table(nullBothT) val catalog = getTableCatalog[TableCatalog]( - session, "nullbothidscat") + spark, "nullbothidscat") assert(catalog.loadTable(testIdent).id == null, "NullTableIdAndNullColumnIdInMemoryTableCatalog should produce null table IDs") assert(catalog.loadTable(testIdent).columns().forall(_.id() == null), @@ -355,12 +338,12 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas .build()) externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) - val df2 = session.table(nullBothT) + val df2 = spark.table(nullBothT) - if (isConnect) { + if (sessionType == "connect") { // Connect re-resolves both sides to the recreated table, so the join // sees the row appended after recreate. - checkRows( + checkAnswer( df1.join(df2, df1("id") === df2("id")), Seq(Row(2, 200, 2, 200))) } else { @@ -368,13 +351,12 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas // drop and recreate goes undetected. 
df1 keeps its pre-drop snapshot // (1, 100) while df2 reads the recreated table (2, 200), so the join finds // no matching ids and returns no rows. - checkRows( + checkAnswer( df1.join(df2, df1("id") === df2("id")), Seq.empty) } } } - } // --------------------------------------------------------------------------- // Scenario 5: external drop+re-add column. @@ -385,24 +367,23 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas test(s"${testPrefix}SPARK-54157: join after external drop+re-add column" + " (table without table ID support, but with column ID support)") { val nullIdT = "nullidcat.ns1.ns2.tbl" - withTestSession { session => - withTestTableAndViews(session, nullIdT) { - session.sql(s"CREATE TABLE $nullIdT (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $nullIdT VALUES (1, 100)").collect() + withTable(nullIdT) { + spark.sql(s"CREATE TABLE $nullIdT (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $nullIdT VALUES (1, 100)").collect() - val df1 = session.table(nullIdT) + val df1 = spark.table(nullIdT) - val catalog = getTableCatalog[TableCatalog](session, "nullidcat") + val catalog = getTableCatalog[TableCatalog](spark, "nullidcat") catalog.alterTable( testIdent, TableChange.deleteColumn(Array("salary"), false)) catalog.alterTable( testIdent, TableChange.addColumn(Array("salary"), IntegerType, true)) - val df2 = session.table(nullIdT) + val df2 = spark.table(nullIdT) - if (isConnect) { + if (sessionType == "connect") { // Connect re-resolves both sides with the new column ID. 
- checkRows( + checkAnswer( df1.join(df2, df1("id") === df2("id")), Seq(Row(1, null, 1, null))) } else { @@ -417,35 +398,31 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas } } } - } test(s"${testPrefix}SPARK-54157: join does not detect external drop+re-add column" + " (table without table ID support and without column ID support)") { val nullBothT = "nullbothidscat.ns1.ns2.tbl" - withTestSession { session => - withTestTableAndViews(session, nullBothT) { - session.sql(s"CREATE TABLE $nullBothT (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $nullBothT VALUES (1, 100)").collect() + withTable(nullBothT) { + spark.sql(s"CREATE TABLE $nullBothT (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $nullBothT VALUES (1, 100)").collect() - val df1 = session.table(nullBothT) + val df1 = spark.table(nullBothT) - val catalog = getTableCatalog[TableCatalog]( - session, "nullbothidscat") + val catalog = getTableCatalog[TableCatalog](spark, "nullbothidscat") catalog.alterTable( testIdent, TableChange.deleteColumn(Array("salary"), false)) catalog.alterTable( testIdent, TableChange.addColumn(Array("salary"), IntegerType, true)) - val df2 = session.table(nullBothT) + val df2 = spark.table(nullBothT) // Neither TABLE_ID_MISMATCH nor COLUMN_ID_MISMATCH fires. // The change goes undetected and the join succeeds. - checkRows( + checkAnswer( df1.join(df2, df1("id") === df2("id")), Seq(Row(1, null, 1, null))) } } - } // --------------------------------------------------------------------------- // Scenario 6: external type change (drop INT column, add STRING column). 
@@ -457,14 +434,13 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas test(s"${testPrefix}SPARK-54157: join after external drop+re-add different-type column" + " (table with both table and column ID support)") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - val df1 = session.table(testTable) + val df1 = spark.table(testTable) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") catalog.alterTable( testIdent, TableChange.deleteColumn(Array("salary"), false)) catalog.alterTable( @@ -472,11 +448,11 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, UTF8String.fromString("high"))) - val df2 = session.table(testTable) + val df2 = spark.table(testTable) - if (isConnect) { + if (sessionType == "connect") { // Connect re-resolves both sides with the new column type. 
- checkRows( + checkAnswer( df1.join(df2, df1("id") === df2("id")), Seq(Row(1, null, 1, null), Row(2, "high", 2, "high"))) } else { @@ -491,5 +467,4 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas } } } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2RepeatedTableAccessTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2RepeatedTableAccessTests.scala index 533d10a949796..fb22a8bb7ab79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2RepeatedTableAccessTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2RepeatedTableAccessTests.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.IntegerType * Each scenario includes a session mutation baseline, an external mutation test, and a * caching-connector variant showing stale results until `REFRESH TABLE`. * - * NOTE: All `session.sql(...)` calls append `.collect()` because Connect client DataFrames + * NOTE: All `spark.sql(...)` calls append `.collect()` because Connect client DataFrames * are lazy and require an action to trigger execution. In classic mode `.collect()` on * DDL / DML is a no-op (these execute eagerly), so this is harmless. 
*/ @@ -45,178 +45,160 @@ trait DSv2RepeatedTableAccessTests extends DSv2ExternalMutationTestBase { // Scenario 1: data changes via writes test(s"${testPrefix}repeated sql() reflects session write") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - checkRows(session.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) - - session.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() - checkRows(session.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100), Row(2, 200))) - } + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + checkAnswer(spark.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) + + spark.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() + checkAnswer(spark.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100), Row(2, 200))) } } test(s"${testPrefix}repeated sql() reflects external write") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - checkRows(session.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + checkAnswer(spark.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") - externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") + externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) - 
checkRows(session.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100), Row(2, 200))) - } + checkAnswer(spark.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100), Row(2, 200))) } } test(s"${testPrefix}connector w/ cache: repeated sql() stale after external write") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100)").collect() - checkRows(session.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) - - val catalog = getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") - externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) - - // Caching connector returns stale table: external write invisible - checkRows(session.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) - - // REFRESH TABLE invalidates the connector cache, external write becomes visible - session.sql(s"REFRESH TABLE $cachingTestTable").collect() - checkRows(session.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100), Row(2, 200))) - } + withTable(cachingTestTable) { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100)").collect() + checkAnswer(spark.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) + + val catalog = getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") + externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) + + // Caching connector returns stale table: external write invisible + checkAnswer(spark.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) + + // REFRESH TABLE invalidates the connector cache, external write becomes visible + spark.sql(s"REFRESH TABLE $cachingTestTable").collect() + checkAnswer(spark.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100), Row(2, 200))) } } // Scenario 2: schema 
changes test(s"${testPrefix}repeated sql() reflects session schema change") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - checkRows(session.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) - - session.sql(s"ALTER TABLE $testTable ADD COLUMN new_col INT").collect() - session.sql(s"INSERT INTO $testTable VALUES (2, 200, -1)").collect() - checkRows( - session.sql(s"SELECT * FROM $testTable"), - Seq(Row(1, 100, null), Row(2, 200, -1))) - } + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + checkAnswer(spark.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) + + spark.sql(s"ALTER TABLE $testTable ADD COLUMN new_col INT").collect() + spark.sql(s"INSERT INTO $testTable VALUES (2, 200, -1)").collect() + checkAnswer( + spark.sql(s"SELECT * FROM $testTable"), + Seq(Row(1, 100, null), Row(2, 200, -1))) } } test(s"${testPrefix}repeated sql() reflects external schema change") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - checkRows(session.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) - - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") - val addCol = TableChange.addColumn(Array("new_col"), IntegerType, true) - catalog.alterTable(testIdent, addCol) - - externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200, -1)) - - checkRows( - session.sql(s"SELECT * FROM $testTable"), - Seq(Row(1, 100, null), Row(2, 200, -1))) - } + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT 
INTO $testTable VALUES (1, 100)").collect() + checkAnswer(spark.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) + + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") + val addCol = TableChange.addColumn(Array("new_col"), IntegerType, true) + catalog.alterTable(testIdent, addCol) + + externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200, -1)) + + checkAnswer( + spark.sql(s"SELECT * FROM $testTable"), + Seq(Row(1, 100, null), Row(2, 200, -1))) } } test(s"${testPrefix}connector w/ cache: repeated sql() stale after external schema change") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100)").collect() - checkRows(session.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) - - val catalog = getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") - val addCol = TableChange.addColumn(Array("new_col"), IntegerType, true) - catalog.alterTable(testIdent, addCol) - - externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200, -1)) - - // Caching connector returns stale table: external changes invisible - checkRows(session.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) - - // REFRESH TABLE invalidates the connector cache, schema change + data visible - session.sql(s"REFRESH TABLE $cachingTestTable").collect() - checkRows( - session.sql(s"SELECT * FROM $cachingTestTable"), - Seq(Row(1, 100, null), Row(2, 200, -1))) - } + withTable(cachingTestTable) { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100)").collect() + checkAnswer(spark.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) + + val catalog = getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") + val addCol = 
TableChange.addColumn(Array("new_col"), IntegerType, true) + catalog.alterTable(testIdent, addCol) + + externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200, -1)) + + // Caching connector returns stale table: external changes invisible + checkAnswer(spark.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) + + // REFRESH TABLE invalidates the connector cache, schema change + data visible + spark.sql(s"REFRESH TABLE $cachingTestTable").collect() + checkAnswer( + spark.sql(s"SELECT * FROM $cachingTestTable"), + Seq(Row(1, 100, null), Row(2, 200, -1))) } } // Scenario 3: drop and recreate table test(s"${testPrefix}repeated sql() reflects session drop/recreate") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - checkRows(session.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) - - session.sql(s"DROP TABLE $testTable").collect() - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - checkRows(session.sql(s"SELECT * FROM $testTable"), Seq.empty) - } + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + checkAnswer(spark.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) + + spark.sql(s"DROP TABLE $testTable").collect() + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + checkAnswer(spark.sql(s"SELECT * FROM $testTable"), Seq.empty) } } test(s"${testPrefix}repeated sql() reflects external drop/recreate") { - withTestSession { session => - withTestTableAndViews(session, testTable) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() - checkRows(session.sql(s"SELECT * FROM $testTable"), 
Seq(Row(1, 100))) - - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") - catalog.dropTable(testIdent) - catalog.createTable( - testIdent, - new TableInfo.Builder() - .withColumns(Array( - Column.create("id", IntegerType), - Column.create("salary", IntegerType))) - .build()) - - checkRows(session.sql(s"SELECT * FROM $testTable"), Seq.empty) - } + withTable(testTable) { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100)").collect() + checkAnswer(spark.sql(s"SELECT * FROM $testTable"), Seq(Row(1, 100))) + + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") + catalog.dropTable(testIdent) + catalog.createTable( + testIdent, + new TableInfo.Builder() + .withColumns(Array( + Column.create("id", IntegerType), + Column.create("salary", IntegerType))) + .build()) + + checkAnswer(spark.sql(s"SELECT * FROM $testTable"), Seq.empty) } } test(s"${testPrefix}connector w/ cache: repeated sql() stale after external drop/recreate") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100)").collect() - checkRows(session.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) - - val catalog = getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") - catalog.dropTable(testIdent) - catalog.createTable( - testIdent, - new TableInfo.Builder() - .withColumns(Array( - Column.create("id", IntegerType), - Column.create("salary", IntegerType))) - .build()) - - // Caching connector returns stale table: drop/recreate invisible - checkRows(session.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) - - // REFRESH TABLE invalidates the connector cache, new empty table visible - session.sql(s"REFRESH TABLE $cachingTestTable").collect() - checkRows(session.sql(s"SELECT * FROM 
$cachingTestTable"), Seq.empty) - } + withTable(cachingTestTable) { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100)").collect() + checkAnswer(spark.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) + + val catalog = getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") + catalog.dropTable(testIdent) + catalog.createTable( + testIdent, + new TableInfo.Builder() + .withColumns(Array( + Column.create("id", IntegerType), + Column.create("salary", IntegerType))) + .build()) + + // Caching connector returns stale table: drop/recreate invisible + checkAnswer(spark.sql(s"SELECT * FROM $cachingTestTable"), Seq(Row(1, 100))) + + // REFRESH TABLE invalidates the connector cache, new empty table visible + spark.sql(s"REFRESH TABLE $cachingTestTable").collect() + checkAnswer(spark.sql(s"SELECT * FROM $cachingTestTable"), Seq.empty) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2TempViewWithStoredPlanTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2TempViewWithStoredPlanTests.scala index 9f8a93e30550f..e473968794c37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2TempViewWithStoredPlanTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2TempViewWithStoredPlanTests.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType} * backed by DSv2 tables correctly handle data changes, schema changes, and table recreation, * both via session SQL and external catalog mutations. * - * NOTE: All `session.sql(...)` calls append `.collect()` because Connect client DataFrames + * NOTE: All `spark.sql(...)` calls append `.collect()` because Connect client DataFrames * are lazy and require an action to trigger execution. In classic mode `.collect()` on DDL * is a no-op (DDL executes eagerly), so this is harmless. 
*/ @@ -35,143 +35,143 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 1.1 (session write) test(s"${testPrefix}temp view with stored plan reflects session write") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - session.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() - checkRows(session.table("v"), Seq(Row(1, 100), Row(2, 200))) + spark.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() + checkAnswer(spark.table("v"), Seq(Row(1, 100), Row(2, 200))) } } } // Scenario 1.2 (external write) test(s"${testPrefix}temp view with stored plan reflects external write") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val 
catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) - checkRows(session.table("v"), Seq(Row(1, 100), Row(2, 200))) + checkAnswer(spark.table("v"), Seq(Row(1, 100), Row(2, 200))) } } } // Scenario 1.2 connector w/ cache (external write, caching connector) test(s"${testPrefix}connector w/ cache: temp view stale after external write") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable, Seq("v")) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() + withTable(cachingTestTable) { + withView("v") { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() - session.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") + val catalog = getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200)) // Caching connector returns stale table: external write invisible - checkRows(session.table("v"), Seq(Row(1, 100))) + checkAnswer(spark.table("v"), Seq(Row(1, 100))) // REFRESH TABLE invalidates the connector cache, external write becomes visible - session.sql(s"REFRESH TABLE $cachingTestTable").collect() - checkRows(session.table("v"), Seq(Row(1, 100), Row(2, 200))) + spark.sql(s"REFRESH TABLE $cachingTestTable").collect() + checkAnswer(spark.table("v"), Seq(Row(1, 100), Row(2, 
200))) } } } // Scenario 2.1 (session ADD COLUMN) test(s"${testPrefix}temp view with stored plan preserves schema after session ADD COLUMN") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - session.sql(s"ALTER TABLE $testTable ADD COLUMN new_column INT").collect() - session.sql(s"INSERT INTO $testTable VALUES (2, 200, -1)").collect() + spark.sql(s"ALTER TABLE $testTable ADD COLUMN new_column INT").collect() + spark.sql(s"INSERT INTO $testTable VALUES (2, 200, -1)").collect() // view preserves original 2-column schema, filter still applied - checkRows(session.table("v"), Seq(Row(1, 100), Row(2, 200))) + checkAnswer(spark.table("v"), Seq(Row(1, 100), Row(2, 200))) } } } // Scenario 2.2 (external ADD COLUMN) test(s"${testPrefix}temp view with stored plan preserves schema after external ADD COLUMN") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") 
- checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) // external schema change via catalog API - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") val addCol = TableChange.addColumn(Array("new_column"), IntegerType, true) catalog.alterTable(testIdent, addCol) externalAppend(catalog = catalog, ident = testIdent, row = InternalRow(2, 200, -1)) // view preserves original 2-column schema, filter still applied - checkRows(session.table("v"), Seq(Row(1, 100), Row(2, 200))) + checkAnswer(spark.table("v"), Seq(Row(1, 100), Row(2, 200))) } } } // Scenario 2.2 connector w/ cache (external ADD COLUMN, caching connector) test(s"${testPrefix}connector w/ cache: temp view stale after external ADD COLUMN") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable, Seq("v")) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() + withTable(cachingTestTable) { + withView("v") { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() - session.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") + val catalog = getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") val addCol = TableChange.addColumn(Array("new_column"), IntegerType, true) catalog.alterTable(testIdent, addCol) externalAppend(catalog = catalog, 
ident = testIdent, row = InternalRow(2, 200, -1)) // Caching connector returns stale table: external changes invisible - checkRows(session.table("v"), Seq(Row(1, 100))) + checkAnswer(spark.table("v"), Seq(Row(1, 100))) // REFRESH TABLE invalidates the connector cache, view preserves original 2-column schema - session.sql(s"REFRESH TABLE $cachingTestTable").collect() - checkRows(session.table("v"), Seq(Row(1, 100), Row(2, 200))) + spark.sql(s"REFRESH TABLE $cachingTestTable").collect() + checkAnswer(spark.table("v"), Seq(Row(1, 100), Row(2, 200))) } } } // Scenario 3.1 (session column removal) test(s"${testPrefix}temp view with stored plan detects session column removal") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - session.sql(s"ALTER TABLE $testTable DROP COLUMN salary").collect() + spark.sql(s"ALTER TABLE $testTable DROP COLUMN salary").collect() checkError( - exception = intercept[AnalysisException] { session.table("v").collect() }, + exception = intercept[AnalysisException] { spark.table("v").collect() }, condition = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION", parameters = Map( "viewName" -> "`v`", @@ -184,20 +184,20 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 3.2 (external column removal) test(s"${testPrefix}temp view with stored plan detects 
external column removal") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") val dropCol = TableChange.deleteColumn(Array("salary"), false) catalog.alterTable(testIdent, dropCol) checkError( - exception = intercept[AnalysisException] { session.table("v").collect() }, + exception = intercept[AnalysisException] { spark.table("v").collect() }, condition = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION", parameters = Map( "viewName" -> "`v`", @@ -210,25 +210,25 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 3.2 connector w/ cache (external column removal, caching connector) test(s"${testPrefix}connector w/ cache: temp view stale after external column removal") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable, Seq("v")) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() + withTable(cachingTestTable) { + withView("v") { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), 
(10, 1000)").collect() - session.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") + val catalog = getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") val dropCol = TableChange.deleteColumn(Array("salary"), false) catalog.alterTable(testIdent, dropCol) // Caching connector returns stale table: column removal invisible, no error - checkRows(session.table("v"), Seq(Row(1, 100))) + checkAnswer(spark.table("v"), Seq(Row(1, 100))) // REFRESH TABLE invalidates the connector cache, column removal detected - session.sql(s"REFRESH TABLE $cachingTestTable").collect() + spark.sql(s"REFRESH TABLE $cachingTestTable").collect() checkError( - exception = intercept[AnalysisException] { session.table("v").collect() }, + exception = intercept[AnalysisException] { spark.table("v").collect() }, condition = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION", parameters = Map( "viewName" -> "`v`", @@ -241,43 +241,43 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 4.1 (session drop and recreate table) test(s"${testPrefix}temp view with stored plan resolves to session-recreated table") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - 
checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") val originalTableId = catalog.loadTable(testIdent).id - session.sql(s"DROP TABLE $testTable").collect() - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"DROP TABLE $testTable").collect() + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() val newTableId = catalog.loadTable(testIdent).id assert(originalTableId != newTableId) // view resolves to the new empty table - checkRows(session.table("v"), Seq.empty) + checkAnswer(spark.table("v"), Seq.empty) - session.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() - checkRows(session.table("v"), Seq(Row(2, 200))) + spark.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() + checkAnswer(spark.table("v"), Seq(Row(2, 200))) } } } // Scenario 4.2 (external drop and recreate table) test(s"${testPrefix}temp view with stored plan resolves to externally recreated table") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[InMemoryTableCatalog](session, 
"testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") val originalTableId = catalog.loadTable(testIdent).id catalog.dropTable(testIdent) @@ -293,25 +293,25 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { assert(originalTableId != newTableId) // view resolves to the new empty table - checkRows(session.table("v"), Seq.empty) + checkAnswer(spark.table("v"), Seq.empty) - session.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() - checkRows(session.table("v"), Seq(Row(2, 200))) + spark.sql(s"INSERT INTO $testTable VALUES (2, 200)").collect() + checkAnswer(spark.table("v"), Seq(Row(2, 200))) } } } // Scenario 4.2 connector w/ cache (external drop/recreate, caching connector) test(s"${testPrefix}connector w/ cache: temp view stale after external drop/recreate") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable, Seq("v")) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() + withTable(cachingTestTable) { + withView("v") { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() - session.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") + val catalog = getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") catalog.dropTable(testIdent) catalog.createTable( testIdent, @@ -322,11 +322,11 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { .build()) // Caching connector returns stale table: drop/recreate invisible 
- checkRows(session.table("v"), Seq(Row(1, 100))) + checkAnswer(spark.table("v"), Seq(Row(1, 100))) // REFRESH TABLE invalidates the connector cache, view resolves to new empty table - session.sql(s"REFRESH TABLE $cachingTestTable").collect() - checkRows(session.table("v"), Seq.empty) + spark.sql(s"REFRESH TABLE $cachingTestTable").collect() + checkAnswer(spark.table("v"), Seq.empty) } } } @@ -334,29 +334,29 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 5.1 (session drop and re-add column with same type, multiple views) test(s"${testPrefix}temp view with stored plan after session drop and re-add column same type" + " with unfiltered view") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v", "v_no_filter", "v_filter_is_null")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - session.table(testTable).createOrReplaceTempView("v_no_filter") - session.table(testTable).filter("salary IS NULL") + withTable(testTable) { + withView("v", "v_no_filter", "v_filter_is_null") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + spark.table(testTable).createOrReplaceTempView("v_no_filter") + spark.table(testTable).filter("salary IS NULL") .createOrReplaceTempView("v_filter_is_null") - checkRows(session.table("v"), Seq(Row(1, 100))) - checkRows(session.table("v_no_filter"), Seq(Row(1, 100), Row(10, 1000))) - checkRows(session.table("v_filter_is_null"), Seq.empty) + checkAnswer(spark.table("v"), Seq(Row(1, 100))) + checkAnswer(spark.table("v_no_filter"), Seq(Row(1, 100), Row(10, 1000))) + 
checkAnswer(spark.table("v_filter_is_null"), Seq.empty) // drop and re-add column with same name and type - session.sql(s"ALTER TABLE $testTable DROP COLUMN salary").collect() - session.sql(s"ALTER TABLE $testTable ADD COLUMN salary INT").collect() + spark.sql(s"ALTER TABLE $testTable DROP COLUMN salary").collect() + spark.sql(s"ALTER TABLE $testTable ADD COLUMN salary INT").collect() // salary values are now null, so the filtered view returns nothing - checkRows(session.table("v"), Seq.empty) + checkAnswer(spark.table("v"), Seq.empty) // unfiltered view returns rows with null salary - checkRows(session.table("v_no_filter"), Seq(Row(1, null), Row(10, null))) + checkAnswer(spark.table("v_no_filter"), Seq(Row(1, null), Row(10, null))) // IS NULL filter now matches all rows - checkRows(session.table("v_filter_is_null"), Seq(Row(1, null), Row(10, null))) + checkAnswer(spark.table("v_filter_is_null"), Seq(Row(1, null), Row(10, null))) } } } @@ -364,31 +364,31 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 5.2 (external drop and re-add column with same type) test(s"${testPrefix}temp view with stored plan after external drop and re-add column " + "same type") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v", "v_no_filter", "v_filter_is_null")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - session.table(testTable).createOrReplaceTempView("v_no_filter") - session.table(testTable).filter("salary IS NULL") + withTable(testTable) { + withView("v", "v_no_filter", "v_filter_is_null") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + + spark.table(testTable).filter("salary < 
999").createOrReplaceTempView("v") + spark.table(testTable).createOrReplaceTempView("v_no_filter") + spark.table(testTable).filter("salary IS NULL") .createOrReplaceTempView("v_filter_is_null") - checkRows(session.table("v"), Seq(Row(1, 100))) - checkRows(session.table("v_no_filter"), Seq(Row(1, 100), Row(10, 1000))) - checkRows(session.table("v_filter_is_null"), Seq.empty) + checkAnswer(spark.table("v"), Seq(Row(1, 100))) + checkAnswer(spark.table("v_no_filter"), Seq(Row(1, 100), Row(10, 1000))) + checkAnswer(spark.table("v_filter_is_null"), Seq.empty) // external drop and re-add column via catalog API - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") val dropCol = TableChange.deleteColumn(Array("salary"), false) val addCol = TableChange.addColumn(Array("salary"), IntegerType, true) catalog.alterTable(testIdent, dropCol, addCol) // salary values are now null, so the filtered view returns nothing - checkRows(session.table("v"), Seq.empty) + checkAnswer(spark.table("v"), Seq.empty) // unfiltered view returns rows with null salary - checkRows(session.table("v_no_filter"), Seq(Row(1, null), Row(10, null))) + checkAnswer(spark.table("v_no_filter"), Seq(Row(1, null), Row(10, null))) // IS NULL filter now matches all rows - checkRows(session.table("v_filter_is_null"), Seq(Row(1, null), Row(10, null))) + checkAnswer(spark.table("v_filter_is_null"), Seq(Row(1, null), Row(10, null))) } } } @@ -396,44 +396,44 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 5.2 connector w/ cache (external drop/re-add column, caching connector) test(s"${testPrefix}connector w/ cache: temp view stale after external drop/re-add column " + "same type") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable, Seq("v")) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT 
INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() + withTable(cachingTestTable) { + withView("v") { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() - session.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") + val catalog = getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") val dropCol = TableChange.deleteColumn(Array("salary"), false) val addCol = TableChange.addColumn(Array("salary"), IntegerType, true) catalog.alterTable(testIdent, dropCol, addCol) // Caching connector returns stale table: column drop/re-add invisible - checkRows(session.table("v"), Seq(Row(1, 100))) + checkAnswer(spark.table("v"), Seq(Row(1, 100))) // REFRESH TABLE invalidates the connector cache, salary values are null - session.sql(s"REFRESH TABLE $cachingTestTable").collect() - checkRows(session.table("v"), Seq.empty) + spark.sql(s"REFRESH TABLE $cachingTestTable").collect() + checkAnswer(spark.table("v"), Seq.empty) } } } // Scenario 6.1 (session drop and re-add column with different type) test(s"${testPrefix}temp view with stored plan detects session column type change") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - 
session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - session.sql(s"ALTER TABLE $testTable DROP COLUMN salary").collect() - session.sql(s"ALTER TABLE $testTable ADD COLUMN salary STRING").collect() + spark.sql(s"ALTER TABLE $testTable DROP COLUMN salary").collect() + spark.sql(s"ALTER TABLE $testTable ADD COLUMN salary STRING").collect() checkError( - exception = intercept[AnalysisException] { session.table("v").collect() }, + exception = intercept[AnalysisException] { spark.table("v").collect() }, condition = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION", parameters = Map( "viewName" -> "`v`", @@ -446,21 +446,21 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 6.2 (external drop and re-add column with different type) test(s"${testPrefix}temp view with stored plan detects external column type change") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") val dropCol = TableChange.deleteColumn(Array("salary"), false) val addCol 
= TableChange.addColumn(Array("salary"), StringType, true) catalog.alterTable(testIdent, dropCol, addCol) checkError( - exception = intercept[AnalysisException] { session.table("v").collect() }, + exception = intercept[AnalysisException] { spark.table("v").collect() }, condition = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION", parameters = Map( "viewName" -> "`v`", @@ -473,26 +473,26 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 6.2 connector w/ cache (external column type change, caching connector) test(s"${testPrefix}connector w/ cache: temp view stale after external column type change") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable, Seq("v")) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() + withTable(cachingTestTable) { + withView("v") { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() - session.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") + val catalog = getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") val dropCol = TableChange.deleteColumn(Array("salary"), false) val addCol = TableChange.addColumn(Array("salary"), StringType, true) catalog.alterTable(testIdent, dropCol, addCol) // Caching connector returns stale table: type change invisible, no error - checkRows(session.table("v"), Seq(Row(1, 100))) + checkAnswer(spark.table("v"), Seq(Row(1, 100))) // REFRESH TABLE invalidates the connector cache, type 
change detected - session.sql(s"REFRESH TABLE $cachingTestTable").collect() + spark.sql(s"REFRESH TABLE $cachingTestTable").collect() checkError( - exception = intercept[AnalysisException] { session.table("v").collect() }, + exception = intercept[AnalysisException] { spark.table("v").collect() }, condition = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION", parameters = Map( "viewName" -> "`v`", @@ -505,18 +505,18 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 7.1 (session type widening from INT to BIGINT) test(s"${testPrefix}temp view with stored plan detects session type widening") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - session.sql(s"ALTER TABLE $testTable ALTER COLUMN salary TYPE LONG").collect() + spark.sql(s"ALTER TABLE $testTable ALTER COLUMN salary TYPE LONG").collect() checkError( - exception = intercept[AnalysisException] { session.table("v").collect() }, + exception = intercept[AnalysisException] { spark.table("v").collect() }, condition = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION", parameters = Map( "viewName" -> "`v`", @@ -529,20 +529,20 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 7.2 (external type widening from INT to BIGINT) test(s"${testPrefix}temp view with stored plan 
detects external type widening") { - withTestSession { session => - withTestTableAndViews(session, testTable, Seq("v")) { - session.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() + withTable(testTable) { + withView("v") { + spark.sql(s"CREATE TABLE $testTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $testTable VALUES (1, 100), (10, 1000)").collect() - session.table(testTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(testTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[InMemoryTableCatalog](session, "testcat") + val catalog = getTableCatalog[InMemoryTableCatalog](spark, "testcat") val updateType = TableChange.updateColumnType(Array("salary"), LongType) catalog.alterTable(testIdent, updateType) checkError( - exception = intercept[AnalysisException] { session.table("v").collect() }, + exception = intercept[AnalysisException] { spark.table("v").collect() }, condition = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION", parameters = Map( "viewName" -> "`v`", @@ -555,25 +555,25 @@ trait DSv2TempViewWithStoredPlanTests extends DSv2ExternalMutationTestBase { // Scenario 7.2 connector w/ cache (external type widening, caching connector) test(s"${testPrefix}connector w/ cache: temp view stale after external type widening") { - withTestSession { session => - withTestTableAndViews(session, cachingTestTable, Seq("v")) { - session.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() - session.sql(s"INSERT INTO $cachingTestTable VALUES (1, 100), (10, 1000)").collect() + withTable(cachingTestTable) { + withView("v") { + spark.sql(s"CREATE TABLE $cachingTestTable (id INT, salary INT) USING foo").collect() + spark.sql(s"INSERT INTO $cachingTestTable 
VALUES (1, 100), (10, 1000)").collect() - session.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") - checkRows(session.table("v"), Seq(Row(1, 100))) + spark.table(cachingTestTable).filter("salary < 999").createOrReplaceTempView("v") + checkAnswer(spark.table("v"), Seq(Row(1, 100))) - val catalog = getTableCatalog[CachingInMemoryTableCatalog](session, "cachingcat") + val catalog = getTableCatalog[CachingInMemoryTableCatalog](spark, "cachingcat") val updateType = TableChange.updateColumnType(Array("salary"), LongType) catalog.alterTable(testIdent, updateType) // Caching connector returns stale table: type change invisible, no error - checkRows(session.table("v"), Seq(Row(1, 100))) + checkAnswer(spark.table("v"), Seq(Row(1, 100))) // REFRESH TABLE invalidates the connector cache, type change detected - session.sql(s"REFRESH TABLE $cachingTestTable").collect() + spark.sql(s"REFRESH TABLE $cachingTestTable").collect() checkError( - exception = intercept[AnalysisException] { session.table("v").collect() }, + exception = intercept[AnalysisException] { spark.table("v").collect() }, condition = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION", parameters = Map( "viewName" -> "`v`", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index 71632e07c78b7..4d4b96406bd28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -24,7 +24,7 @@ import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SessionQueryTest, SparkSession} import 
org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect} @@ -47,6 +47,7 @@ import org.apache.spark.unsafe.types.UTF8String class DataSourceV2DataFrameSuite extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = false) + with SessionQueryTest with DSv2TempViewWithStoredPlanTests with DSv2RepeatedTableAccessTests with DSv2IncrementallyConstructedQueryTests @@ -97,12 +98,6 @@ class DataSourceV2DataFrameSuite // DSv2ExternalMutationTestBase implementations for classic mode override protected def testPrefix: String = "" - override protected def isConnect: Boolean = false - - override protected def withTestSession(fn: SparkSession => Unit): Unit = fn(spark) - - override protected def checkRows(df: => DataFrame, expected: Seq[Row]): Unit = - checkAnswer(df, expected) override protected def getTableCatalog[C <: TableCatalog: ClassTag]( session: SparkSession, @@ -115,16 +110,6 @@ class DataSourceV2DataFrameSuite c.asInstanceOf[C] } - override protected def withTestTableAndViews( - session: SparkSession, - table: String, - views: Seq[String] = Seq.empty)(fn: => Unit): Unit = { - withTable(table) { - try { fn } - finally { views.foreach(v => session.sql(s"DROP VIEW IF EXISTS $v")) } - } - } - override def verifyTable(tableName: String, expected: DataFrame): Unit = { checkAnswer(spark.table(tableName), expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index b20b6d397fd17..5c2bc6829ea59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.classic import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.{SchemaColumnConvertNotSupportedException, SQLHadoopMapReduceCommitProtocol} import org.apache.spark.sql.execution.datasources.parquet.TestingUDT._ @@ -37,14 +38,15 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.functions.struct import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. 
*/ -abstract class ParquetQuerySuite extends ParquetTest with SharedSparkSession { +abstract class ParquetQuerySuite extends ParquetTest + with QueryTest + with classic.SparkSessionBinder { import testImplicits._ test("simple select queries") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index fb26d3311ebef..d8065b21a6188 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -19,43 +19,16 @@ package org.apache.spark.sql.test import scala.concurrent.duration._ -import org.scalatest.{BeforeAndAfterEach, Suite} -import org.scalatest.concurrent.Eventually +import org.scalatest.Suite -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK -import org.apache.spark.sql.{classic, QueryTest, QueryTestBase, SparkSession, SparkSessionProvider, SQLContext} -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode -import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.{QueryTest, QueryTestBase, SparkSessionBinderBase} +import org.apache.spark.sql.classic -trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { - - /** - * Suites extending [[SharedSparkSession]] are sharing resources (e.g. SparkSession) in their - * tests. That trait initializes the spark session in its [[beforeAll()]] implementation before - * the automatic thread snapshot is performed, so the audit code could fail to report threads - * leaked by that shared session. - * - * The behavior is overridden here to take the snapshot before the spark session is initialized. 
- */ - override protected val enableAutoThreadAudit = false - - protected override def beforeAll(): Unit = { - doThreadPreAudit() - super.beforeAll() - } - - protected override def afterAll(): Unit = { - try { - super.afterAll() - } finally { - doThreadPostAudit() - } - } +trait SharedSparkSession extends QueryTest with classic.SparkSessionBinder { // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that // execution as a map keyed by (planNodeId, planNodeName, metricName) -> metricValue. + @deprecated("rarely used", "4.2.0") def runAndFetchMetrics(func: => Unit): Map[(Long, String, String), String] = { val statusStore = spark.sharedState.statusStore val oldCount = statusStore.executionsList().size @@ -82,124 +55,12 @@ trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { } } + /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ -trait SharedSparkSessionBase - extends QueryTestBase - with SparkSessionProvider - with BeforeAndAfterEach - with Eventually { self: Suite => - - protected def sparkConf = { - val conf = new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set(UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) - // Disable ConvertToLocalRelation for better test coverage. Test cases built on - // LocalRelation will exercise the optimization rules better by disabling it as - // this rule may potentially block testing of other optimization rules such as - // ConstantPropagation etc. 
- .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) - conf.set( - StaticSQLConf.WAREHOUSE_PATH, - conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) - conf.set(StaticSQLConf.LOAD_SESSION_EXTENSIONS_FROM_CLASSPATH, false) - conf.set(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD, - sys.env.getOrElse("SPARK_TEST_SQL_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD", - StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) - conf.set(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD, - sys.env.getOrElse("SPARK_TEST_SQL_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD", - StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected override def spark: classic.SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - classic.SparkSession.cleanupAnyExistingSession() - new TestSparkSession(sparkConf) - } +trait SharedSparkSessionBase extends QueryTestBase with SparkSessionBinderBase { self: Suite => - protected def sqlConf: SQLConf = _spark.sessionState.conf - - /** - * Initialize the [[TestSparkSession]]. Generally, this is just called from - * beforeAll; however, in test using styles other than FunSuite, there is - * often code that relies on the session between test group constructs and - * the actual tests, which may need this session. 
It is purely a semantic - * difference, but semantically, it makes more sense to call - * 'initializeSession' between a 'describe' and an 'it' call than it does to - * call 'beforeAll'. - */ - protected def initializeSession(): Unit = { - if (_spark == null) { - _spark = createSparkSession - } - } - - /** - * Make sure the [[TestSparkSession]] is initialized before any tests are run. - */ - protected override def beforeAll(): Unit = { - initializeSession() - - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - try { - super.afterAll() - } finally { - try { - if (_spark != null) { - try { - _spark.sessionState.catalog.reset() - } finally { - _spark.stop() - _spark = null - } - } - } finally { - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds), interval(2.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } + protected override def spark: classic.SparkSession = + super.spark.asInstanceOf[classic.SparkSession] } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index 47cc9853f754d..172d374385474 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.hive.test import 
org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSessionProvider -import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.classic.{SparkSession, SparkSessionProvider} import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.client.HiveClient