diff --git a/CHANGELOG.md b/CHANGELOG.md index 1306c90e..3c50e8a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - CA file was not closed after MD5 calculation when using PIA patches. - Mitigated an issue with MTU in TCP mode during negotiation. [#39](https://github.com/keeshux/tunnelkit/issues/39) +- Handle server-initiated renegotiation. [#41](https://github.com/keeshux/tunnelkit/pull/41) ## 1.2.0 (2018-10-20) diff --git a/TunnelKit/Sources/AppExtension/TunnelKitProvider.swift b/TunnelKit/Sources/AppExtension/TunnelKitProvider.swift index 47b87d4f..35e73737 100644 --- a/TunnelKit/Sources/AppExtension/TunnelKitProvider.swift +++ b/TunnelKit/Sources/AppExtension/TunnelKitProvider.swift @@ -652,7 +652,7 @@ extension TunnelKitProvider { } } else if let se = error as? SessionError { switch se { - case .negotiationTimeout, .pingTimeout: + case .negotiationTimeout, .pingTimeout, .staleSession: return .timeout case .badCredentials: diff --git a/TunnelKit/Sources/Core/SessionError.swift b/TunnelKit/Sources/Core/SessionError.swift index 632fe284..2198138e 100644 --- a/TunnelKit/Sources/Core/SessionError.swift +++ b/TunnelKit/Sources/Core/SessionError.swift @@ -67,6 +67,9 @@ public enum SessionError: String, Error { /// The server couldn't ping back before timeout. case pingTimeout + + /// The session reached a stale state and can't be recovered. + case staleSession } extension Error { diff --git a/TunnelKit/Sources/Core/SessionProxy.swift b/TunnelKit/Sources/Core/SessionProxy.swift index 53a6d943..e554ff8c 100644 --- a/TunnelKit/Sources/Core/SessionProxy.swift +++ b/TunnelKit/Sources/Core/SessionProxy.swift @@ -326,12 +326,6 @@ public class SessionProxy { private func start() { loopLink() hardReset() - - guard !keys.isEmpty else { - fatalError("Main loop must follow hard reset, keys are empty!") - } - - loopNegotiation() } private func loopNegotiation() { @@ -466,6 +460,12 @@ public class SessionProxy { // deferStop(.shutdown, e) // return } + if (code == .hardResetServerV2) && (negotiationKey.state != .hardReset) { + deferStop(.shutdown, SessionError.staleSession) + return + } else if (code == .softResetV1) && (negotiationKey.state != .softReset) { + softReset(isServerInitiated: true) + } sendAck(for: controlPacket) @@ -557,6 +557,10 @@ public class SessionProxy { let payload = hardResetPayload() ?? Data() negotiationKey.state = .hardReset + guard !keys.isEmpty else { + fatalError("Main loop must follow hard reset, keys are empty!") + } + loopNegotiation() enqueueControlPackets(code: .hardResetClientV2, key: UInt8(negotiationKeyIdx), payload: payload) } @@ -574,8 +578,12 @@ public class SessionProxy { } // Ruby: soft_reset - private func softReset() { - log.debug("Send soft reset") + private func softReset(isServerInitiated: Bool) { + if isServerInitiated { + log.debug("Handle soft reset") + } else { + log.debug("Send soft reset") + } resetControlChannel(forNewSession: false) negotiationKeyIdx = max(1, (negotiationKeyIdx + 1) % ProtocolMacros.numberOfKeys) @@ -586,7 +594,9 @@ public class SessionProxy { negotiationKey.state = .softReset negotiationKey.softReset = true loopNegotiation() - enqueueControlPackets(code: .softResetV1, key: UInt8(negotiationKeyIdx), payload: Data()) + if !isServerInitiated { + enqueueControlPackets(code: .softResetV1, key: UInt8(negotiationKeyIdx), payload: Data()) + } } // Ruby: on_tls_connect @@ -667,7 +677,7 @@ public class SessionProxy { let elapsed = -negotiationKey.startTime.timeIntervalSinceNow if (elapsed > renegotiatesAfter) { log.debug("Renegotiating after \(elapsed) seconds") - softReset() + softReset(isServerInitiated: false) } } @@ -683,15 +693,16 @@ public class SessionProxy { // Ruby: handle_ctrl_pkt private func handleControlPacket(_ packet: ControlPacket) { - guard (packet.key == negotiationKey.id) else { + guard packet.key == negotiationKey.id else { log.error("Bad key in control packet (\(packet.key) != \(negotiationKey.id))") // deferStop(.shutdown, SessionError.badKey) return } - if (((packet.code == .hardResetServerV2) && (negotiationKey.state == .hardReset)) || - ((packet.code == .softResetV1) && (negotiationKey.state == .softReset))) { - + // start new TLS handshake + if ((packet.code == .hardResetServerV2) && (negotiationKey.state == .hardReset)) || + ((packet.code == .softResetV1) && (negotiationKey.state == .softReset)) { + if negotiationKey.state == .hardReset { controlChannel.remoteSessionId = packet.sessionId } @@ -738,6 +749,7 @@ public class SessionProxy { log.debug("TLS.connect: Pulled ciphertext (\(cipherTextOut.count) bytes)") enqueueControlPackets(code: .controlV1, key: negotiationKey.id, payload: cipherTextOut) } + // exchange TLS ciphertext else if ((packet.code == .controlV1) && (negotiationKey.state == .tls)) { guard let remoteSessionId = controlChannel.remoteSessionId else { log.error("No remote sessionId found in packet (control packets before server HARD_RESET)")