Skip to content

Commit

Permalink
Merge pull request #217 from IBM-Swift/ctx-bug
Browse files Browse the repository at this point in the history
Remove a race condition from the WebSocket upgrade
  • Loading branch information
Pushkar N Kulkarni authored Jul 26, 2019
2 parents ee75eeb + 09b88f8 commit a7be981
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Sources/KituraNet/HTTP/HTTPRequestHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle

switch request {
case .head(let header):
serverRequest = HTTPServerRequest(ctx: context, requestHead: header, enableSSL: enableSSLVerification)
serverRequest = HTTPServerRequest(channel: context.channel, requestHead: header, enableSSL: enableSSLVerification)
self.clientRequestedKeepAlive = header.isKeepAlive
case .body(var buffer):
guard let serverRequest = serverRequest else {
Expand Down
20 changes: 8 additions & 12 deletions Sources/KituraNet/HTTP/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ public class HTTPServer: Server {
/// The event loop group on which the HTTP handler runs
private let eventLoopGroup: MultiThreadedEventLoopGroup

private var ctx: ChannelHandlerContext?

/**
Creates an HTTP server object.

Expand Down Expand Up @@ -193,19 +191,18 @@ public class HTTPServer: Server {
}

/// Creates upgrade request and adds WebSocket handler to pipeline
private func upgradeHandler(webSocketHandlerFactory: ProtocolHandlerFactory, request: HTTPRequestHead) -> EventLoopFuture<Void> {
guard let ctx = self.ctx else { fatalError("The channel was probably closed during a protocol upgrade.") }
return ctx.eventLoop.submit {
let request = HTTPServerRequest(ctx: ctx, requestHead: request, enableSSL: false)
private func upgradeHandler(channel: Channel, webSocketHandlerFactory: ProtocolHandlerFactory, request: HTTPRequestHead) -> EventLoopFuture<Void> {
return channel.eventLoop.submit {
let request = HTTPServerRequest(channel: channel, requestHead: request, enableSSL: false)
return webSocketHandlerFactory.handler(for: request)
}.flatMap { (handler: ChannelHandler) -> EventLoopFuture<Void> in
return ctx.channel.pipeline.addHandler(handler).flatMap {
return channel.pipeline.addHandler(handler).flatMap {
if let _extensions = request.headers["Sec-WebSocket-Extensions"].first {
let handlers = webSocketHandlerFactory.extensionHandlers(header: _extensions)
return ctx.channel.pipeline.addHandlers(handlers, position: .before(handler))
return channel.pipeline.addHandlers(handlers, position: .before(handler))
} else {
// No extensions. We must return success.
return ctx.channel.eventLoop.makeSucceededFuture(())
return channel.eventLoop.makeSucceededFuture(())
}
}
}
Expand All @@ -222,7 +219,7 @@ public class HTTPServer: Server {

private func generateUpgradePipelineHandler(_ webSocketHandlerFactory: ProtocolHandlerFactory) -> UpgradePipelineHandlerFunction {
return { (channel: Channel, request: HTTPRequestHead) in
return self.upgradeHandler(webSocketHandlerFactory: webSocketHandlerFactory, request: request)
return self.upgradeHandler(channel: channel, webSocketHandlerFactory: webSocketHandlerFactory, request: request)
}
}

Expand Down Expand Up @@ -304,8 +301,7 @@ public class HTTPServer: Server {
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: allowPortReuse ? 1 : 0)
.childChannelInitializer { channel in
let httpHandler = HTTPRequestHandler(for: self)
let config: NIOHTTPServerUpgradeConfiguration = (upgraders: upgraders, completionHandler: { ctx in
self.ctx = ctx
let config: NIOHTTPServerUpgradeConfiguration = (upgraders: upgraders, completionHandler: { _ in
_ = channel.pipeline.removeHandler(httpHandler)
})
return channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: config, withErrorHandling: true).flatMap {
Expand Down
14 changes: 7 additions & 7 deletions Sources/KituraNet/HTTP/HTTPServerRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public class HTTPServerRequest: ServerRequest {
*/
public var method: String

private let ctx: ChannelHandlerContext
private let channel: Channel

private var enableSSL: Bool = false

Expand Down Expand Up @@ -208,20 +208,20 @@ public class HTTPServerRequest: ServerRequest {
}
}

init(ctx: ChannelHandlerContext, requestHead: HTTPRequestHead, enableSSL: Bool) {
init(channel: Channel, requestHead: HTTPRequestHead, enableSSL: Bool) {
// An HTTPServerRequest may be created only on the EventLoop assigned to handle
// the connection on which the HTTP request arrived.
assert(ctx.eventLoop.inEventLoop)
self.ctx = ctx
assert(channel.eventLoop.inEventLoop)
self.channel = channel
self.headers = HeadersContainer(with: requestHead.headers)
self.method = requestHead.method.rawValue
self.httpVersionMajor = UInt16(requestHead.version.major)
self.httpVersionMinor = UInt16(requestHead.version.minor)
self.rawURLString = requestHead.uri
self.enableSSL = enableSSL
self.localAddressHost = HTTPServerRequest.host(socketAddress: ctx.localAddress)
self.localAddressPort = ctx.localAddress?.port ?? 0
self.remoteAddress = HTTPServerRequest.host(socketAddress: ctx.remoteAddress)
self.localAddressHost = HTTPServerRequest.host(socketAddress: channel.localAddress)
self.localAddressPort = channel.localAddress?.port ?? 0
self.remoteAddress = HTTPServerRequest.host(socketAddress: channel.remoteAddress)
}

var buffer: BufferList?
Expand Down

0 comments on commit a7be981

Please sign in to comment.