diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala index fed9334dbbab4..715df54e573c3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala @@ -361,11 +361,13 @@ private[spark] class Client( DEFAULT_BLOCKMANAGER_PORT.toString) val driverSubmitter = buildDriverSubmissionClient(kubernetesClient, service, driverSubmitSslOptions) - val ping = Retry.retry(5, 5.seconds) { + val ping = Retry.retry(5, 5.seconds, + Some("Failed to contact the driver server")) { driverSubmitter.ping() } ping onFailure { case t: Throwable => + logError("Ping failed to the driver server", t) submitCompletedFuture.setException(t) kubernetesClient.services().delete(service) } @@ -532,17 +534,6 @@ private[spark] class Client( kubernetesClient: KubernetesClient, service: Service, driverSubmitSslOptions: SSLOptions): KubernetesSparkRestApi = { - val servicePort = service - .getSpec - .getPorts - .asScala - .filter(_.getName == SUBMISSION_SERVER_PORT_NAME) - .head - .getNodePort - // NodePort is exposed on every node, so just pick one of them. - // TODO be resilient to node failures and try all of them - val node = kubernetesClient.nodes.list.getItems.asScala.head - val nodeAddress = node.getStatus.getAddresses.asScala.head.getAddress val urlScheme = if (driverSubmitSslOptions.enabled) { "https" } else { @@ -551,15 +542,23 @@ private[spark] class Client( " to secure this step.") "http" } + val servicePort = service.getSpec.getPorts.asScala + .filter(_.getName == SUBMISSION_SERVER_PORT_NAME) + .head.getNodePort + val nodeUrls = kubernetesClient.nodes.list.getItems.asScala + .filterNot(_.getSpec.getUnschedulable) + .flatMap(_.getStatus.getAddresses.asScala.map(address => { + s"$urlScheme://${address.getAddress}:$servicePort" + })).toArray + require(nodeUrls.nonEmpty, "No nodes found to contact the driver!") val (trustManager, sslContext): (X509TrustManager, SSLContext) = if (driverSubmitSslOptions.enabled) { buildSslConnectionConfiguration(driverSubmitSslOptions) } else { (null, SSLContext.getDefault) } - val url = s"$urlScheme://$nodeAddress:$servicePort" HttpClientUtil.createClient[KubernetesSparkRestApi]( - url, + uris = nodeUrls, sslSocketFactory = sslContext.getSocketFactory, trustContext = trustManager) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Retry.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Retry.scala index e5ce0bcd606b2..378583b29c547 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Retry.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Retry.scala @@ -19,24 +19,36 @@ package org.apache.spark.deploy.kubernetes import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration -private[spark] object Retry { +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging + +private[spark] object Retry extends Logging { private def retryableFuture[T] - (times: Int, interval: Duration) + (attempt: Int, maxAttempts: Int, interval: Duration, retryMessage: Option[String]) (f: => Future[T]) (implicit executionContext: ExecutionContext): Future[T] = { f recoverWith { - case _ if times > 0 => { - Thread.sleep(interval.toMillis) - retryableFuture(times - 1, interval)(f) - } + case error: Throwable => + if (attempt <= maxAttempts) { + retryMessage.foreach { message => + logWarning(s"$message - attempt $attempt of $maxAttempts", error) + } + Thread.sleep(interval.toMillis) + retryableFuture(attempt + 1, maxAttempts, interval, retryMessage)(f) + } else { + Future.failed(retryMessage.map(message => + new SparkException(s"$message - reached $maxAttempts attempts," + + s" and aborting task.", error) + ).getOrElse(error)) + } } } def retry[T] - (times: Int, interval: Duration) + (times: Int, interval: Duration, retryMessage: Option[String] = None) (f: => T) (implicit executionContext: ExecutionContext): Future[T] = { - retryableFuture(times, interval)(Future[T] { f }) + retryableFuture(1, times, interval, retryMessage)(Future[T] { f }) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/kubernetes/HttpClientUtil.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/kubernetes/HttpClientUtil.scala index eb7d411700829..1cabfbad656eb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/kubernetes/HttpClientUtil.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/kubernetes/HttpClientUtil.scala @@ -20,7 +20,7 @@ import javax.net.ssl.{SSLContext, SSLSocketFactory, X509TrustManager} import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} import com.fasterxml.jackson.module.scala.DefaultScalaModule -import feign.Feign +import feign.{Client, Feign, Request, Response} import feign.Request.Options import feign.jackson.{JacksonDecoder, JacksonEncoder} import feign.jaxrs.JAXRSContract @@ -32,7 +32,7 @@ import org.apache.spark.status.api.v1.JacksonMessageWriter private[spark] object HttpClientUtil { def createClient[T: ClassTag]( - uri: String, + uris: Array[String], sslSocketFactory: SSLSocketFactory = SSLContext.getDefault.getSocketFactory, trustContext: X509TrustManager = null, readTimeoutMillis: Int = 20000, @@ -45,13 +45,24 @@ private[spark] object HttpClientUtil { .registerModule(new DefaultScalaModule) .setDateFormat(JacksonMessageWriter.makeISODateFormat) objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] + val target = new MultiServerFeignTarget[T](uris) + val baseHttpClient = new feign.okhttp.OkHttpClient(httpClientBuilder.build()) + val resetTargetHttpClient = new Client { + override def execute(request: Request, options: Options): Response = { + val response = baseHttpClient.execute(request, options) + if (response.status() >= 200 && response.status() < 300) { + target.reset() + } + response + } + } Feign.builder() - .client(new feign.okhttp.OkHttpClient(httpClientBuilder.build())) + .client(resetTargetHttpClient) .contract(new JAXRSContract) .encoder(new JacksonEncoder(objectMapper)) .decoder(new JacksonDecoder(objectMapper)) .options(new Options(connectTimeoutMillis, readTimeoutMillis)) - .target(clazz, uri) + .retryer(target) + .target(target) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/kubernetes/MultiServerFeignTarget.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/kubernetes/MultiServerFeignTarget.scala new file mode 100644 index 0000000000000..fea7f057cfa1b --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/kubernetes/MultiServerFeignTarget.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.rest.kubernetes + +import feign.{Request, RequestTemplate, RetryableException, Retryer, Target} +import scala.reflect.ClassTag +import scala.util.Random + +private[kubernetes] class MultiServerFeignTarget[T : ClassTag]( + private val servers: Seq[String]) extends Target[T] with Retryer { + require(servers.nonEmpty, "Must provide at least one server URI.") + + private val threadLocalShuffledServers = new ThreadLocal[Seq[String]] { + override def initialValue(): Seq[String] = Random.shuffle(servers) + } + + override def `type`(): Class[T] = { + implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] + } + + override def url(): String = threadLocalShuffledServers.get.head + + /** + * Cloning the target is done on every request, for use on the current + * thread - thus it's important that clone returns a "fresh" target. + */ + override def clone(): Retryer = { + reset() + this + } + + override def name(): String = { + s"${getClass.getSimpleName} with servers [${servers.mkString(",")}]" + } + + override def apply(requestTemplate: RequestTemplate): Request = { + if (!requestTemplate.url().startsWith("http")) { + requestTemplate.insert(0, url()) + } + requestTemplate.request() + } + + override def continueOrPropagate(e: RetryableException): Unit = { + threadLocalShuffledServers.set(threadLocalShuffledServers.get.drop(1)) + if (threadLocalShuffledServers.get.isEmpty) { + throw e + } + } + + def reset(): Unit = { + threadLocalShuffledServers.set(Random.shuffle(servers)) + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/minikube/Minikube.scala index 60c6564579a6e..b42f97952394e 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/minikube/Minikube.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/minikube/Minikube.scala @@ -123,7 +123,7 @@ private[spark] object Minikube extends Logging { .build() val sslContext = SSLUtils.sslContext(kubernetesConf) val trustManager = SSLUtils.trustManagers(kubernetesConf)(0).asInstanceOf[X509TrustManager] - HttpClientUtil.createClient[T](url, sslContext.getSocketFactory, trustManager) + HttpClientUtil.createClient[T](Array(url), sslContext.getSocketFactory, trustManager) } def executeMinikubeSsh(command: String): Unit = {