Skip to content

Commit

Permalink
Improve shutdown logic of language server
Browse files Browse the repository at this point in the history
This PR addresses problems mentioned in #7470 and #7729:
- shutting a language server explicitly will not lead to a soft shutdown
- `project/status` endpoint returns the state of the language server

`LanguageServerController` now also signed up for `ClientConnect`
messages. For it to be unambiguous, we need to carry around the port
number of the language server as a way of identifying the right one.

One can now use `project/status` to additionally determine the state of
the language server.

Also relies on a proper fix for #7765.
  • Loading branch information
hubertp committed Sep 12, 2023
1 parent 6fd2295 commit 9607278
Show file tree
Hide file tree
Showing 23 changed files with 478 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class JsonConnectionController(
}

override def receive: Receive = {
case JsonRpcServer.WebConnect(webActor) =>
case JsonRpcServer.WebConnect(webActor, _) =>
unstashAll()
context.become(connected(webActor))
case _ => stash()
Expand Down Expand Up @@ -180,7 +180,7 @@ class JsonConnectionController(
case Request(_, id, _) =>
sender() ! ResponseError(Some(id), SessionNotInitialisedError)

case MessageHandler.Disconnected =>
case MessageHandler.Disconnected(_) =>
context.stop(self)
}

Expand Down Expand Up @@ -304,7 +304,7 @@ class JsonConnectionController(
case Request(InitProtocolConnection, id, _) =>
sender() ! ResponseError(Some(id), SessionAlreadyInitialisedError)

case MessageHandler.Disconnected =>
case MessageHandler.Disconnected(_) =>
logger.info("Json session terminated [{}].", rpcSession.clientId)
context.system.eventStream.publish(JsonSessionTerminated(rpcSession))
context.stop(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,15 @@ class JsonRpcServer(

implicit val ec: ExecutionContext = system.dispatcher

private def newUser(): Flow[Message, Message, NotUsed] = {
private def newUser(port: Int): Flow[Message, Message, NotUsed] = {
val messageHandler =
system.actorOf(
Props(
new MessageHandlerSupervisor(clientControllerFactory, protocolFactory)
new MessageHandlerSupervisor(
clientControllerFactory,
protocolFactory,
port
)
),
s"message-handler-supervisor-${UUID.randomUUID()}"
)
Expand All @@ -61,9 +65,11 @@ class JsonRpcServer(
.to(
Sink.actorRef[MessageHandler.WebMessage](
messageHandler,
MessageHandler.Disconnected,
{ _: Any =>
MessageHandler.Disconnected
MessageHandler.Disconnected(port),
{ _: Throwable =>
// TODO: If enabled, the warning would produce too much noise in tests
// logger.warn(s"Connection closed abruptly: ${e.getMessage}", e)
MessageHandler.Disconnected(port)
}
)
)
Expand All @@ -77,7 +83,7 @@ class JsonRpcServer(
OverflowStrategy.fail
)
.mapMaterializedValue { outActor =>
messageHandler ! MessageHandler.Connected(outActor)
messageHandler ! MessageHandler.Connected(outActor, port)
NotUsed
}
.map((outMsg: MessageHandler.WebMessage) => TextMessage(outMsg.message))
Expand All @@ -88,10 +94,10 @@ class JsonRpcServer(
Flow.fromSinkAndSource(incomingMessages, outgoingMessages)
}

private val route: Route = {
private def route(port: Int): Route = {
val webSocketEndpoint =
path(config.path) {
get { handleWebSocketMessages(newUser()) }
get { handleWebSocketMessages(newUser(port)) }
}

optionalEndpoints.foldLeft(webSocketEndpoint) { (chain, next) =>
Expand All @@ -109,7 +115,7 @@ class JsonRpcServer(
def bind(interface: String, port: Int): Future[Http.ServerBinding] =
Http()
.newServerAt(interface, port)
.bind(route)
.bind(route(port))
}

object JsonRpcServer {
Expand Down Expand Up @@ -138,6 +144,6 @@ object JsonRpcServer {
Config(outgoingBufferSize = 1000, lazyMessageTimeout = 10.seconds)
}

case class WebConnect(webActor: ActorRef)
case class WebConnect(webActor: ActorRef, port: Int)

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MessageHandler(protocolFactory: ProtocolFactory, controller: ActorRef)
* @return the actor behavior.
*/
override def receive: Receive = {
case MessageHandler.Connected(webConnection) =>
case MessageHandler.Connected(webConnection, _) =>
unstashAll()
context.become(established(webConnection, Map()))
case _ => stash()
Expand All @@ -38,8 +38,8 @@ class MessageHandler(protocolFactory: ProtocolFactory, controller: ActorRef)
): Receive = {
case MessageHandler.WebMessage(msg) =>
handleWebMessage(msg, webConnection, awaitingResponses)
case MessageHandler.Disconnected =>
controller ! MessageHandler.Disconnected
case MessageHandler.Disconnected(port) =>
controller ! MessageHandler.Disconnected(port)
context.stop(self)
case request: Request[Method, Any] =>
issueRequest(request, webConnection, awaitingResponses)
Expand Down Expand Up @@ -192,10 +192,10 @@ object MessageHandler {
/** A control message used for [[MessageHandler]] initializations
* @param webConnection the actor representing the web.
*/
case class Connected(webConnection: ActorRef)
case class Connected(webConnection: ActorRef, port: Int)

/** A control message used to notify the controller about
* the connection being closed.
*/
case object Disconnected
case class Disconnected(port: Int)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ import java.util.UUID
*/
final class MessageHandlerSupervisor(
clientControllerFactory: ClientControllerFactory,
protocolFactory: ProtocolFactory
protocolFactory: ProtocolFactory,
port: Int
) extends Actor
with LazyLogging
with Stash {
Expand Down Expand Up @@ -58,7 +59,7 @@ final class MessageHandlerSupervisor(
Props(new MessageHandler(protocolFactory, clientActor)),
s"message-handler-$clientId"
)
clientActor ! JsonRpcServer.WebConnect(messageHandler)
clientActor ! JsonRpcServer.WebConnect(messageHandler, port)
context.become(initialized(messageHandler))
unstashAll()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class MessageHandlerSpec
handler = system.actorOf(
Props(new MessageHandler(MyProtocolFactory, controller.ref))
)
handler ! Connected(out.ref)
handler ! Connected(out.ref, 0)
}

"Message handler" must {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ project-manager {
request-timeout = 10 seconds
boot-timeout = 40 seconds
shutdown-timeout = 20 seconds
delayed-shutdown-timeout = 8 seconds
delayed-shutdown-timeout = 3 seconds
socket-close-timeout = 15 seconds
retries = 5
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package org.enso.projectmanager.data

case class LanguageServerStatus(open: Boolean, shuttingDown: Boolean)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package org.enso.projectmanager.data

case class RunningStatus(open: Boolean, shuttingDown: Boolean)
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ object ClientEvent {
/** Notifies the Language Server about a new client connecting.
*
* @param clientId an object representing a client
* @param port the port number to which the client connected
*/
case class ClientConnected(clientId: UUID) extends ClientEvent
case class ClientConnected(clientId: UUID, port: Int) extends ClientEvent

/** Notifies the Language Server about a client disconnecting.
* The client may not send any further messages after this one.
*
* @param clientId the internal id of this client
* @param port the port number from which the client disconnected
*/
case class ClientDisconnected(clientId: UUID) extends ClientEvent
case class ClientDisconnected(clientId: UUID, port: Int) extends ClientEvent

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ import nl.gn0s1s.bump.SemVer
import org.enso.logger.akka.ActorMessageLogging
import org.enso.projectmanager.boot.configuration._
import org.enso.projectmanager.data.{LanguageServerSockets, Socket}
import org.enso.projectmanager.event.ClientEvent.ClientDisconnected
import org.enso.projectmanager.event.ClientEvent.{
ClientConnected,
ClientDisconnected
}
import org.enso.projectmanager.event.ProjectEvent.ProjectClosed
import org.enso.projectmanager.infrastructure.http.AkkaBasedWebSocketConnectionFactory
import org.enso.projectmanager.infrastructure.languageserver.LanguageServerBootLoader.{
Expand All @@ -24,7 +27,11 @@ import org.enso.projectmanager.infrastructure.languageserver.LanguageServerBootL
}
import org.enso.projectmanager.infrastructure.languageserver.LanguageServerController._
import org.enso.projectmanager.infrastructure.languageserver.LanguageServerProtocol._
import org.enso.projectmanager.infrastructure.languageserver.LanguageServerRegistry.ServerShutDown
import org.enso.projectmanager.infrastructure.languageserver.LanguageServerRegistry.{
LanguageServerStatus,
LanguageServerStatusRequest,
ServerShutDown
}
import org.enso.projectmanager.model.Project
import org.enso.projectmanager.service.LoggingServiceDescriptor
import org.enso.projectmanager.util.UnhandledLogging
Expand Down Expand Up @@ -93,6 +100,7 @@ class LanguageServerController(

override def preStart(): Unit = {
context.system.eventStream.subscribe(self, classOf[ClientDisconnected])
context.system.eventStream.subscribe(self, classOf[ClientConnected])
self ! Boot
}

Expand Down Expand Up @@ -160,12 +168,12 @@ class LanguageServerController(
private def supervising(
connectionInfo: LanguageServerConnectionInfo,
serverProcessManager: ActorRef,
clients: Set[UUID] = Set.empty,
scheduledShutdown: Option[Cancellable] = None
clients: Set[UUID] = Set.empty,
scheduledShutdown: Option[(Cancellable, Int)] = None
): Receive =
LoggingReceive.withLabel("supervising") {
case StartServer(clientId, _, requestedEngineVersion, _, _) =>
scheduledShutdown.foreach(_.cancel())
scheduledShutdown.foreach(_._1.cancel())
if (requestedEngineVersion != engineVersion) {
sender() ! ServerBootFailed(
new IllegalStateException(
Expand All @@ -192,7 +200,7 @@ class LanguageServerController(
)
}
case Terminated(_) =>
scheduledShutdown.foreach(_.cancel())
scheduledShutdown.foreach(_._1.cancel())
logger.debug("Bootloader for {} terminated.", project)

case StopServer(clientId, _) =>
Expand All @@ -202,28 +210,49 @@ class LanguageServerController(
clients,
clientId,
Some(sender()),
explicitShutdownRequested = true,
None,
scheduledShutdown
)

case ScheduledShutdown(requester) =>
shutDownServer(requester)

case LanguageServerStatusRequest =>
sender() ! LanguageServerStatus(project.id, scheduledShutdown.isDefined)

case ShutDownServer =>
scheduledShutdown.foreach(_.cancel())
scheduledShutdown.foreach(_._1.cancel())
shutDownServer(None)

case ClientDisconnected(clientId) =>
case ClientDisconnected(clientId, port) =>
removeClient(
connectionInfo,
serverProcessManager,
clients,
clientId,
None,
explicitShutdownRequested = false,
atPort = Some(port),
scheduledShutdown
)
case ClientConnected(clientId, clientPort) =>
scheduledShutdown match {
case Some((cancellable, port)) if clientPort == port =>
cancellable.cancel()
context.become(
supervising(
connectionInfo,
serverProcessManager,
clients ++ Set(clientId),
None
)
)
case _ =>
}

case RenameProject(_, namespace, oldName, newName) =>
scheduledShutdown.foreach(_.cancel())
scheduledShutdown.foreach(_._1.cancel())
val socket = Socket(connectionInfo.interface, connectionInfo.rpcPort)
context.actorOf(
ProjectRenameAction
Expand All @@ -241,7 +270,7 @@ class LanguageServerController(
)

case ServerDied =>
scheduledShutdown.foreach(_.cancel())
scheduledShutdown.foreach(_._1.cancel())
logger.error("Language server died [{}].", connectionInfo)
context.stop(self)

Expand All @@ -253,30 +282,39 @@ class LanguageServerController(
clients: Set[UUID],
clientId: UUID,
maybeRequester: Option[ActorRef],
shutdownTimeout: Option[Cancellable]
explicitShutdownRequested: Boolean,
atPort: Option[Int],
shutdownTimeout: Option[(Cancellable, Int)]
): Unit = {
val updatedClients = clients - clientId
if (updatedClients.isEmpty) {
logger.debug("Delaying shutdown for project {}.", project.id)
val scheduledShutdown =
shutdownTimeout.orElse(
Some(
context.system.scheduler
.scheduleOnce(
timeoutConfig.delayedShutdownTimeout,
self,
ScheduledShutdown(maybeRequester)
if (!explicitShutdownRequested) {
logger.debug("Delaying shutdown for project {}.", project.id)
val scheduledShutdown: Option[(Cancellable, Int)] =
shutdownTimeout.orElse(
Some(
(
context.system.scheduler.scheduleOnce(
timeoutConfig.delayedShutdownTimeout,
self,
ScheduledShutdown(maybeRequester)
),
atPort.getOrElse(0)
)
)
)
context.become(
supervising(
connectionInfo,
serverProcessManager,
Set.empty,
scheduledShutdown
)
)
context.become(
supervising(
connectionInfo,
serverProcessManager,
Set.empty,
scheduledShutdown
)
)
} else {
shutdownTimeout.foreach(_._1.cancel())
shutDownServer(maybeRequester)
}
} else {
sender() ! CannotDisconnectOtherClients
context.become(
Expand Down Expand Up @@ -329,7 +367,7 @@ class LanguageServerController(
maybeRequester.foreach(_ ! ServerShutdownTimedOut)
stop()

case ClientDisconnected(clientId) =>
case ClientDisconnected(clientId, _) =>
logger.debug(
s"Received client ($clientId) disconnect request during shutdown. Ignoring."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ trait LanguageServerGateway[F[+_, +_]] {
* @param projectId a project id
* @return true if project is open
*/
def isRunning(projectId: UUID): F[CheckTimeout.type, Boolean]
def isRunning(projectId: UUID): F[CheckTimeout.type, (Boolean, Boolean)]

/** Request a language server to rename project.
*
Expand Down
Loading

0 comments on commit 9607278

Please sign in to comment.