Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support websocket compression #339

Merged
merged 7 commits into from
Jun 24, 2017
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions Source/Compression.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
//
// Compression.swift
// Starscream
//
// Created by Joseph Ross on 5/23/17.
// Copyright © 2017 Vluxe. All rights reserved.
//

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Link to the spec here so that later if there is a bug one can quickly find the reference to follow: https://tools.ietf.org/html/rfc7692

import Foundation

private let ZLIB_VERSION = Array("1.2.8".utf8CString)

private let Z_OK:CInt = 0
private let Z_BUF_ERROR:CInt = -5

private let Z_SYNC_FLUSH:CInt = 2

class Decompressor {
private var strm = z_stream()
private var buffer = [UInt8](repeating: 0, count: 0x2000)
private var inflateInitialized = false
private let windowBits:Int

init?(windowBits:Int) {
self.windowBits = windowBits
guard initInflate() else { return nil }
}

private func initInflate() -> Bool {
if Z_OK == inflateInit2(strm: &strm, windowBits: -CInt(windowBits),
version: ZLIB_VERSION, streamSize: CInt(MemoryLayout<z_stream>.size))
{
inflateInitialized = true
return true
}
return false
}

func reset() throws {
teardownInflate()
guard initInflate() else { throw NSError() }
}

func decompress(_ data: Data, finish: Bool) throws -> Data {
return try data.withUnsafeBytes { (bytes:UnsafePointer<UInt8>) -> Data in
return try decompress(bytes: bytes, count: data.count, finish: finish)
}
}

func decompress(bytes: UnsafePointer<UInt8>, count: Int, finish: Bool) throws -> Data {
var decompressed = Data()
try decompress(bytes: bytes, count: count, out: &decompressed)

if finish {
let tail:[UInt8] = [0x00, 0x00, 0xFF, 0xFF]
try decompress(bytes: tail, count: tail.count, out: &decompressed)
}

return decompressed

}

private func decompress(bytes: UnsafePointer<UInt8>, count: Int, out:inout Data) throws {
var res:CInt = 0
strm.next_in = bytes
strm.avail_in = CUnsignedInt(count)

repeat {
strm.next_out = UnsafeMutablePointer<UInt8>(&buffer)
strm.avail_out = CUnsignedInt(buffer.count)

res = inflate(strm: &strm, flush: 0)

let byteCount = buffer.count - Int(strm.avail_out)
out.append(buffer, count: byteCount)
} while res == Z_OK && strm.avail_out == 0

guard (res == Z_OK && strm.avail_out > 0)
|| (res == Z_BUF_ERROR && Int(strm.avail_out) == buffer.count)
else {
throw NSError(domain: WebSocket.ErrorDomain, code: Int(WebSocket.InternalErrorCode.compressionError.rawValue), userInfo: nil)
}
}

private func teardownInflate() {
if inflateInitialized, Z_OK == inflateEnd(strm: &strm) {
inflateInitialized = false
}
}

deinit {
teardownInflate()
}

@_silgen_name("inflateInit2_") private func inflateInit2(strm: UnsafeMutableRawPointer, windowBits: CInt,
version: UnsafePointer<CChar>, streamSize: CInt) -> CInt
@_silgen_name("inflate") private func inflate(strm: UnsafeMutableRawPointer, flush: CInt) -> CInt
@discardableResult
@_silgen_name("inflateEnd") private func inflateEnd(strm: UnsafeMutableRawPointer) -> CInt
}

class Compressor {
private var strm = z_stream()
private var buffer = [UInt8](repeating: 0, count: 0x2000)
private var deflateInitialized = false
private let windowBits:Int

init?(windowBits: Int) {
self.windowBits = windowBits
guard initDeflate() else { return nil }
}

private func initDeflate() -> Bool {
if Z_OK == deflateInit2(strm: &strm, level: Z_DEFAULT_COMPRESSION, method: Z_DEFLATED,
windowBits: -CInt(windowBits), memLevel: 8, strategy: Z_DEFAULT_STRATEGY,
version: ZLIB_VERSION, streamSize: CInt(MemoryLayout<z_stream>.size))
{
deflateInitialized = true
return true
}
return false
}

func reset() throws {
teardownDeflate()
guard initDeflate() else { throw NSError() }
}

func compress(_ data: Data) throws -> Data {
var compressed = Data()
var res:CInt = 0
data.withUnsafeBytes { (ptr:UnsafePointer<UInt8>) -> Void in
strm.next_in = ptr
strm.avail_in = CUnsignedInt(data.count)

repeat {
strm.next_out = UnsafeMutablePointer<UInt8>(&buffer)
strm.avail_out = CUnsignedInt(buffer.count)

res = deflate(strm: &strm, flush: Z_SYNC_FLUSH)

let byteCount = buffer.count - Int(strm.avail_out)
compressed.append(buffer, count: byteCount)
}
while res == Z_OK && strm.avail_out == 0

}

guard res == Z_OK && strm.avail_out > 0
|| (res == Z_BUF_ERROR && Int(strm.avail_out) == buffer.count)
else {
throw NSError(domain: WebSocket.ErrorDomain, code: Int(WebSocket.InternalErrorCode.compressionError.rawValue), userInfo: nil)
}

compressed.removeLast(4)
return compressed
}

private func teardownDeflate() {
if deflateInitialized, Z_OK == deflateEnd(strm: &strm) {
deflateInitialized = false
}
}

deinit {
teardownDeflate()
}

@_silgen_name("deflateInit2_") private func deflateInit2(strm: UnsafeMutableRawPointer, level: CInt, method: CInt,
windowBits: CInt, memLevel: CInt, strategy: CInt,
version: UnsafePointer<CChar>, streamSize: CInt) -> CInt
@_silgen_name("deflate") private func deflate(strm: UnsafeMutableRawPointer, flush: CInt) -> CInt
@discardableResult
@_silgen_name("deflateEnd") private func deflateEnd(strm: UnsafeMutableRawPointer) -> CInt

private let Z_DEFAULT_COMPRESSION:CInt = -1
private let Z_DEFLATED:CInt = 8
private let Z_DEFAULT_STRATEGY:CInt = 0
}

private struct z_stream {
var next_in: UnsafePointer<UInt8>? = nil /* next input byte */
var avail_in: CUnsignedInt = 0 /* number of bytes available at next_in */
var total_in: CUnsignedLong = 0 /* total number of input bytes read so far */

var next_out: UnsafeMutablePointer<UInt8>? = nil /* next output byte should be put there */
var avail_out: CUnsignedInt = 0 /* remaining free space at next_out */
var total_out: CUnsignedLong = 0 /* total number of bytes output so far */

var msg: UnsafePointer<CChar>? = nil /* last error message, NULL if no error */
private var state: OpaquePointer? = nil /* not visible by applications */

private var zalloc: OpaquePointer? = nil /* used to allocate the internal state */
private var zfree: OpaquePointer? = nil /* used to free the internal state */
private var opaque: OpaquePointer? = nil /* private data object passed to zalloc and zfree */

var data_type: CInt = 0 /* best guess about the data type: binary or text */
var adler: CUnsignedLong = 0 /* adler32 value of the uncompressed data */
private var reserved: CUnsignedLong = 0 /* reserved for future use */
}

86 changes: 83 additions & 3 deletions Source/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ open class WebSocket : NSObject, StreamDelegate {
enum InternalErrorCode: UInt16 {
// 0-999 WebSocket status codes not used
case outputStreamWriteError = 1
case compressionError = 2
}

// Where the callback is executed. It defaults to the main UI thread queue.
Expand All @@ -86,13 +87,15 @@ open class WebSocket : NSObject, StreamDelegate {
let headerWSProtocolName = "Sec-WebSocket-Protocol"
let headerWSVersionName = "Sec-WebSocket-Version"
let headerWSVersionValue = "13"
let headerWSExtensionName = "Sec-WebSocket-Extensions"
let headerWSKeyName = "Sec-WebSocket-Key"
let headerOriginName = "Origin"
let headerWSAcceptName = "Sec-WebSocket-Accept"
let BUFFER_MAX = 4096
let FinMask: UInt8 = 0x80
let OpCodeMask: UInt8 = 0x0F
let RSVMask: UInt8 = 0x70
let RSV1Mask: UInt8 = 0x40
let MaskMask: UInt8 = 0x80
let PayloadLenMask: UInt8 = 0x7F
let MaxFrameSize: Int = 32
Expand Down Expand Up @@ -128,6 +131,7 @@ open class WebSocket : NSObject, StreamDelegate {
public var headers = [String: String]()
public var voipEnabled = false
public var disableSSLCertValidation = false
public var enableCompression = true
public var security: SSLTrustValidator?
public var enabledSSLCipherSuites: [SSLCipherSuite]?
public var origin: String?
Expand All @@ -139,12 +143,24 @@ open class WebSocket : NSObject, StreamDelegate {
public var currentURL: URL { return url }

// MARK: - Private

private struct CompressionState {
var supportsCompression = false
var messageNeedsDecompression = false
var serverMaxWindowBits = 15
var clientMaxWindowBits = 15
var clientNoContextTakeover = false
var serverNoContextTakeover = false
var decompressor:Decompressor? = nil
var compressor:Compressor? = nil
}

private var url: URL
private var inputStream: InputStream?
private var outputStream: OutputStream?
private var connected = false
private var isConnecting = false
private var compressionState = CompressionState()
private var writeQueue = OperationQueue()
private var readStack = [WSResponse]()
private var inputQueue = [Data]()
Expand Down Expand Up @@ -279,6 +295,10 @@ open class WebSocket : NSObject, StreamDelegate {
if let origin = origin {
addHeader(urlRequest, key: headerOriginName, val: origin)
}
if enableCompression {
let val = "permessage-deflate; client_max_window_bits; server_max_window_bits=15"
addHeader(urlRequest, key: headerWSExtensionName, val: val)
}
addHeader(urlRequest, key: headerWSHostName, val: "\(url.host!):\(port!)")
for (key, value) in headers {
addHeader(urlRequest, key: key, val: value)
Expand Down Expand Up @@ -577,6 +597,34 @@ open class WebSocket : NSObject, StreamDelegate {
}
if let cfHeaders = CFHTTPMessageCopyAllHeaderFields(response) {
let headers = cfHeaders.takeRetainedValue() as NSDictionary
if let extensionHeader = headers[headerWSExtensionName as NSString] as? NSString {
let parts = extensionHeader.components(separatedBy: ";")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for readability what about moving this out to another method. There are already some uber methods on this class personally my preference on this is to keep method size down where it makes logical sense.

for p in parts {
let part = p.trimmingCharacters(in: .whitespaces)
if part == "permessage-deflate" {
compressionState.supportsCompression = true
} else if part.hasPrefix("server_max_window_bits="){
let valString = part.components(separatedBy: "=")[1]
if let val = Int(valString.trimmingCharacters(in: .whitespaces)) {
compressionState.serverMaxWindowBits = val
}
} else if part.hasPrefix("client_max_window_bits="){
let valString = part.components(separatedBy: "=")[1]
if let val = Int(valString.trimmingCharacters(in: .whitespaces)) {
compressionState.clientMaxWindowBits = val
}
} else if part == "client_no_context_takeover"{
compressionState.clientNoContextTakeover = true
} else if part == "server_no_context_takeover"{
compressionState.serverNoContextTakeover = true
}
}
if compressionState.supportsCompression {
compressionState.decompressor = Decompressor(windowBits: compressionState.serverMaxWindowBits)
compressionState.compressor = Compressor(windowBits: compressionState.clientMaxWindowBits)
}
}

if let acceptKey = headers[headerWSAcceptName as NSString] as? NSString {
if acceptKey.length > 0 {
return 0
Expand Down Expand Up @@ -650,7 +698,10 @@ open class WebSocket : NSObject, StreamDelegate {
let isMasked = (MaskMask & baseAddress[1])
let payloadLen = (PayloadLenMask & baseAddress[1])
var offset = 2
if (isMasked > 0 || (RSVMask & baseAddress[0]) > 0) && receivedOpcode != .pong {
if compressionState.supportsCompression && receivedOpcode != .continueFrame {
compressionState.messageNeedsDecompression = (RSV1Mask & baseAddress[0]) > 0
}
if (isMasked > 0 || (RSVMask & baseAddress[0]) > 0) && receivedOpcode != .pong && !compressionState.messageNeedsDecompression {
let errCode = CloseCode.protocolError.rawValue
doDisconnect(errorWithDetail("masked and rsv data is not currently supported", code: errCode))
writeError(errCode)
Expand Down Expand Up @@ -710,7 +761,23 @@ open class WebSocket : NSObject, StreamDelegate {
offset += size
len -= UInt64(size)
}
let data = Data(bytes: baseAddress+offset, count: Int(len))
let data: Data
if compressionState.messageNeedsDecompression, let decompressor = compressionState.decompressor {
do {
data = try decompressor.decompress(bytes: baseAddress+offset, count: Int(len), finish: isFin > 0)
if isFin > 0 && compressionState.serverNoContextTakeover{
try decompressor.reset()
}
} catch {
let closeReason = "Decompression failed: \(error)"
let closeCode = CloseCode.encoding.rawValue
doDisconnect(errorWithDetail(closeReason, code: closeCode))
writeError(closeCode)
return emptyBuffer
}
} else {
data = Data(bytes: baseAddress+offset, count: Int(len))
}

if receivedOpcode == .connectionClose {
var closeReason = "connection closed by server"
Expand Down Expand Up @@ -864,10 +931,23 @@ open class WebSocket : NSObject, StreamDelegate {
guard let s = self else { return }
guard let sOperation = operation else { return }
var offset = 2
var firstByte:UInt8 = s.FinMask | code.rawValue
var data = data
if [.textFrame, .binaryFrame].contains(code), let compressor = s.compressionState.compressor {
do {
data = try compressor.compress(data)
if s.compressionState.clientNoContextTakeover {
try compressor.reset()
}
firstByte |= s.RSV1Mask
} catch {
// TODO: report error? We can just send the uncompressed frame.
}
}
let dataLength = data.count
let frame = NSMutableData(capacity: dataLength + s.MaxFrameSize)
let buffer = UnsafeMutableRawPointer(frame!.mutableBytes).assumingMemoryBound(to: UInt8.self)
buffer[0] = s.FinMask | code.rawValue
buffer[0] = firstByte
if dataLength < 126 {
buffer[1] = CUnsignedChar(dataLength)
} else if dataLength <= Int(UInt16.max) {
Expand Down
Loading