diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index e05828606d098..8d84dffc9d5bd 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -21,7 +21,9 @@ import java.util.concurrent.TimeUnit
 
 import scala.io.Source
 
-import org.scalatest.BeforeAndAfterAll
+import org.apache.commons.lang3.{JavaVersion, SystemUtils}
+import org.scalactic.source.Position
+import org.scalatest.{BeforeAndAfterAll, Tag}
 import sys.process._
 
 import org.apache.spark.sql.SparkSession
@@ -170,41 +172,44 @@ trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll {
 
   protected lazy val serverPort: Int = port
 
   override def beforeAll(): Unit = {
-    super.beforeAll()
-    SparkConnectServerUtils.start()
-    spark = SparkSession
-      .builder()
-      .client(SparkConnectClient.builder().port(serverPort).build())
-      .create()
-
-    // Retry and wait for the server to start
-    val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min
-    var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff
-    var success = false
-    val error = new RuntimeException(s"Failed to start the test server on port $serverPort.")
-
-    while (!success && System.nanoTime() < stop) {
-      try {
-        // Run a simple query to verify the server is really up and ready
-        val result = spark
-          .sql("select val from (values ('Hello'), ('World')) as t(val)")
-          .collect()
-        assert(result.length == 2)
-        success = true
-        debug("Spark Connect Server is up.")
-      } catch {
-        // ignored the error
-        case e: Throwable =>
-          error.addSuppressed(e)
-          Thread.sleep(sleepInternalMs)
-          sleepInternalMs *= 2
+    // TODO(SPARK-44121) Remove this check condition
+    if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) {
+      super.beforeAll()
+      SparkConnectServerUtils.start()
+      spark = SparkSession
+        .builder()
+        .client(SparkConnectClient.builder().port(serverPort).build())
+        .create()
+
+      // Retry and wait for the server to start
+      val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min
+      var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff
+      var success = false
+      val error = new RuntimeException(s"Failed to start the test server on port $serverPort.")
+
+      while (!success && System.nanoTime() < stop) {
+        try {
+          // Run a simple query to verify the server is really up and ready
+          val result = spark
+            .sql("select val from (values ('Hello'), ('World')) as t(val)")
+            .collect()
+          assert(result.length == 2)
+          success = true
+          debug("Spark Connect Server is up.")
+        } catch {
+          // ignored the error
+          case e: Throwable =>
+            error.addSuppressed(e)
+            Thread.sleep(sleepInternalMs)
+            sleepInternalMs *= 2
+        }
       }
-    }
-    // Throw error if failed
-    if (!success) {
-      debug(error)
-      throw error
+      // Throw error if failed
+      if (!success) {
+        debug(error)
+        throw error
+      }
     }
   }
 
@@ -217,4 +222,17 @@
     spark = null
     super.afterAll()
   }
+
+  /**
+   * SPARK-44259: override test function to skip `RemoteSparkSession-based` tests as default, we
+   * should delete this function after SPARK-44121 is completed.
+   */
+  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
+      pos: Position): Unit = {
+    super.test(testName, testTags: _*) {
+      // TODO(SPARK-44121) Re-enable Arrow-based connect tests in Java 21
+      assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
+      testFun
+    }
+  }
 }