Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-44259][CONNECT][TESTS] Make connect-client-jvm pass on Java 21 except RemoteSparkSession-based tests #41805

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.io.{PipedInputStream, PipedOutputStream}
import java.util.concurrent.{Executors, Semaphore, TimeUnit}

import org.apache.commons.io.output.ByteArrayOutputStream
import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.connect.client.util.RemoteSparkSession
Expand All @@ -42,26 +43,29 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
}

override def beforeAll(): Unit = {
super.beforeAll()
ammoniteOut = new ByteArrayOutputStream()
testSuiteOut = new PipedOutputStream()
// Connect the `testSuiteOut` and `ammoniteIn` pipes
ammoniteIn = new PipedInputStream(testSuiteOut)
errorStream = new ByteArrayOutputStream()

val args = Array("--port", serverPort.toString)
val task = new Runnable {
override def run(): Unit = {
ConnectRepl.doMain(
args = args,
semaphore = Some(semaphore),
inputStream = ammoniteIn,
outputStream = ammoniteOut,
errorStream = errorStream)
// TODO(SPARK-44121) Remove this check condition
if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) {
super.beforeAll()
ammoniteOut = new ByteArrayOutputStream()
testSuiteOut = new PipedOutputStream()
// Connect the `testSuiteOut` and `ammoniteIn` pipes
ammoniteIn = new PipedInputStream(testSuiteOut)
errorStream = new ByteArrayOutputStream()

val args = Array("--port", serverPort.toString)
val task = new Runnable {
override def run(): Unit = {
ConnectRepl.doMain(
args = args,
semaphore = Some(semaphore),
inputStream = ammoniteIn,
outputStream = ammoniteOut,
errorStream = errorStream)
}
}
}

executorService.submit(task)
executorService.submit(task)
}
}

override def afterAll(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import java.util.concurrent.TimeUnit

import scala.io.Source

import org.scalatest.BeforeAndAfterAll
import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.scalactic.source.Position
import org.scalatest.{BeforeAndAfterAll, Tag}
import sys.process._

import org.apache.spark.sql.SparkSession
Expand Down Expand Up @@ -170,41 +172,44 @@ trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll {
protected lazy val serverPort: Int = port

override def beforeAll(): Unit = {
super.beforeAll()
SparkConnectServerUtils.start()
spark = SparkSession
.builder()
.client(SparkConnectClient.builder().port(serverPort).build())
.create()

// Retry and wait for the server to start
val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min
var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff
var success = false
val error = new RuntimeException(s"Failed to start the test server on port $serverPort.")

while (!success && System.nanoTime() < stop) {
try {
// Run a simple query to verify the server is really up and ready
val result = spark
.sql("select val from (values ('Hello'), ('World')) as t(val)")
.collect()
assert(result.length == 2)
success = true
debug("Spark Connect Server is up.")
} catch {
// ignored the error
case e: Throwable =>
error.addSuppressed(e)
Thread.sleep(sleepInternalMs)
sleepInternalMs *= 2
// TODO(SPARK-44121) Remove this check condition
if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) {
super.beforeAll()
SparkConnectServerUtils.start()
spark = SparkSession
.builder()
.client(SparkConnectClient.builder().port(serverPort).build())
.create()

// Retry and wait for the server to start
val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min
var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff
var success = false
val error = new RuntimeException(s"Failed to start the test server on port $serverPort.")

while (!success && System.nanoTime() < stop) {
try {
// Run a simple query to verify the server is really up and ready
val result = spark
.sql("select val from (values ('Hello'), ('World')) as t(val)")
.collect()
assert(result.length == 2)
success = true
debug("Spark Connect Server is up.")
} catch {
// ignored the error
case e: Throwable =>
error.addSuppressed(e)
Thread.sleep(sleepInternalMs)
sleepInternalMs *= 2
}
}
}

// Throw error if failed
if (!success) {
debug(error)
throw error
// Throw error if failed
if (!success) {
debug(error)
throw error
}
}
}

Expand All @@ -217,4 +222,17 @@ trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll {
spark = null
super.afterAll()
}

/**
* SPARK-44259: override the `test` function so that `RemoteSparkSession`-based tests are
* skipped by default on Java versions newer than 17. Remove this override once SPARK-44121
* is completed.
*/
override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
Copy link
Contributor Author

@LuciferYang LuciferYang Jun 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dongjoon-hyun @HyukjinKwon The new code overrides the `test` function so that all RemoteSparkSession-based tests are ignored by default when running on Java 21, so there is no need to add the `assume` condition to each test case one by one.

Copy link
Contributor Author

@LuciferYang LuciferYang Jun 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this seem less intrusive?

(implicit pos: Position): Unit = {
super.test(testName, testTags: _*) {
// TODO(SPARK-44121) Re-enable Arrow-based connect tests in Java 21
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
testFun
}
}
}