diff --git a/lib/websocket-server.js b/lib/websocket-server.js index d0a29783f..059b0a572 100644 --- a/lib/websocket-server.js +++ b/lib/websocket-server.js @@ -230,21 +230,34 @@ class WebSocketServer extends EventEmitter { handleUpgrade(req, socket, head, cb) { socket.on('error', socketOnError); - const key = - req.headers['sec-websocket-key'] !== undefined - ? req.headers['sec-websocket-key'] - : false; + const key = req.headers['sec-websocket-key']; const version = +req.headers['sec-websocket-version']; - if ( - req.method !== 'GET' || - req.headers.upgrade.toLowerCase() !== 'websocket' || - !key || - !keyRegex.test(key) || - (version !== 8 && version !== 13) || - !this.shouldHandle(req) - ) { - return abortHandshake(socket, 400); + if (req.method !== 'GET') { + abortHandshake(socket, 405, 'The HTTP method is invalid'); + return; + } + + if (req.headers.upgrade.toLowerCase() !== 'websocket') { + abortHandshake(socket, 400, 'The Upgrade header is invalid'); + return; + } + + if (!key || !keyRegex.test(key)) { + const message = 'The Sec-WebSocket-Key header is missing or invalid'; + abortHandshake(socket, 400, message); + return; + } + + if (version !== 8 && version !== 13) { + const message = 'The Sec-WebSocket-Version header is missing or invalid'; + abortHandshake(socket, 400, message); + return; + } + + if (!this.shouldHandle(req)) { + abortHandshake(socket, 400); + return; } const secWebSocketProtocol = req.headers['sec-websocket-protocol']; @@ -254,7 +267,9 @@ class WebSocketServer extends EventEmitter { try { protocols = subprotocol.parse(secWebSocketProtocol); } catch (err) { - return abortHandshake(socket, 400); + const message = 'The Sec-WebSocket-Protocol header is invalid'; + abortHandshake(socket, 400, message); + return; } } @@ -279,7 +294,10 @@ class WebSocketServer extends EventEmitter { extensions[PerMessageDeflate.extensionName] = perMessageDeflate; } } catch (err) { - return abortHandshake(socket, 400); + const message = + 'The Sec-WebSocket-Extensions header is invalid or not acceptable'; + abortHandshake(socket, 400, message); + return; } } @@ -446,7 +464,7 @@ function emitClose(server) { } /** - * Handle premature socket errors. + * Handle socket errors. * * @private */ @@ -464,25 +482,30 @@ function socketOnError() { * @private */ function abortHandshake(socket, code, message, headers) { - if (socket.writable) { - message = message || http.STATUS_CODES[code]; - headers = { - Connection: 'close', - 'Content-Type': 'text/html', - 'Content-Length': Buffer.byteLength(message), - ...headers - }; + // + // The socket is writable unless the user destroyed or ended it before calling + // `server.handleUpgrade()` or in the `verifyClient` function, which is a user + // error. Handling this does not make much sense as the worst that can happen + // is that some of the data written by the user might be discarded due to the + // call to `socket.end()` below, which triggers an `'error'` event that in + // turn causes the socket to be destroyed. + // + message = message || http.STATUS_CODES[code]; + headers = { + Connection: 'close', + 'Content-Type': 'text/html', + 'Content-Length': Buffer.byteLength(message), + ...headers + }; - socket.write( - `HTTP/1.1 ${code} ${http.STATUS_CODES[code]}\r\n` + - Object.keys(headers) - .map((h) => `${h}: ${headers[h]}`) - .join('\r\n') + - '\r\n\r\n' + - message - ); - } + socket.once('finish', socket.destroy); - socket.removeListener('error', socketOnError); - socket.destroy(); + socket.end( + `HTTP/1.1 ${code} ${http.STATUS_CODES[code]}\r\n` + + Object.keys(headers) + .map((h) => `${h}: ${headers[h]}`) + .join('\r\n') + + '\r\n\r\n' + + message + ); } diff --git a/test/websocket-server.test.js b/test/websocket-server.test.js index fd494059f..5590d6c99 100644 --- a/test/websocket-server.test.js +++ b/test/websocket-server.test.js @@ -470,7 +470,9 @@ describe('WebSocketServer', () => { port: wss.address().port, headers: { Connection: 'Upgrade', - Upgrade: 'websocket' + Upgrade: 'websocket', + 'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==', + 'Sec-WebSocket-Version': 13 } }); @@ -496,7 +498,20 @@ describe('WebSocketServer', () => { req.on('response', (res) => { assert.strictEqual(res.statusCode, 400); - wss.close(done); + + const chunks = []; + + res.on('data', (chunk) => { + chunks.push(chunk); + }); + + res.on('end', () => { + assert.strictEqual( + Buffer.concat(chunks).toString(), + 'The Sec-WebSocket-Key header is missing or invalid' + ); + wss.close(done); + }); }); }); }); @@ -539,6 +554,77 @@ describe('WebSocketServer', () => { }); describe('Connection establishing', () => { + it('fails if the HTTP method is not GET', (done) => { + const wss = new WebSocket.Server({ port: 0 }, () => { + const req = http.request({ + method: 'POST', + port: wss.address().port, + headers: { + Connection: 'Upgrade', + Upgrade: 'websocket' + } + }); + + req.on('response', (res) => { + assert.strictEqual(res.statusCode, 405); + + const chunks = []; + + res.on('data', (chunk) => { + chunks.push(chunk); + }); + + res.on('end', () => { + assert.strictEqual( + Buffer.concat(chunks).toString(), + 'The HTTP method is invalid' + ); + wss.close(done); + }); + }); + + req.end(); + }); + + wss.on('connection', () => { + done(new Error("Unexpected 'connection' event")); + }); + }); + + it('fails if the Upgrade header field value is not "websocket"', (done) => { + const wss = new WebSocket.Server({ port: 0 }, () => { + const req = http.get({ + port: wss.address().port, + headers: { + Connection: 'Upgrade', + Upgrade: 'foo' + } + }); + + req.on('response', (res) => { + assert.strictEqual(res.statusCode, 400); + + const chunks = []; + + res.on('data', (chunk) => { + chunks.push(chunk); + }); + + res.on('end', () => { + assert.strictEqual( + Buffer.concat(chunks).toString(), + 'The Upgrade header is invalid' + ); + wss.close(done); + }); + }); + }); + + wss.on('connection', () => { + done(new Error("Unexpected 'connection' event")); + }); + }); + it('fails if the Sec-WebSocket-Key header is invalid (1/2)', (done) => { const wss = new WebSocket.Server({ port: 0 }, () => { const req = http.get({ @@ -551,7 +637,20 @@ describe('WebSocketServer', () => { req.on('response', (res) => { assert.strictEqual(res.statusCode, 400); - wss.close(done); + + const chunks = []; + + res.on('data', (chunk) => { + chunks.push(chunk); + }); + + res.on('end', () => { + assert.strictEqual( + Buffer.concat(chunks).toString(), + 'The Sec-WebSocket-Key header is missing or invalid' + ); + wss.close(done); + }); }); }); @@ -573,7 +672,20 @@ describe('WebSocketServer', () => { req.on('response', (res) => { assert.strictEqual(res.statusCode, 400); - wss.close(done); + + const chunks = []; + + res.on('data', (chunk) => { + chunks.push(chunk); + }); + + res.on('end', () => { + assert.strictEqual( + Buffer.concat(chunks).toString(), + 'The Sec-WebSocket-Key header is missing or invalid' + ); + wss.close(done); + }); }); }); @@ -595,7 +707,20 @@ describe('WebSocketServer', () => { req.on('response', (res) => { assert.strictEqual(res.statusCode, 400); - wss.close(done); + + const chunks = []; + + res.on('data', (chunk) => { + chunks.push(chunk); + }); + + res.on('end', () => { + assert.strictEqual( + Buffer.concat(chunks).toString(), + 'The Sec-WebSocket-Version header is missing or invalid' + ); + wss.close(done); + }); }); }); @@ -618,7 +743,20 @@ describe('WebSocketServer', () => { req.on('response', (res) => { assert.strictEqual(res.statusCode, 400); - wss.close(done); + + const chunks = []; + + res.on('data', (chunk) => { + chunks.push(chunk); + }); + + res.on('end', () => { + assert.strictEqual( + Buffer.concat(chunks).toString(), + 'The Sec-WebSocket-Version header is missing or invalid' + ); + wss.close(done); + }); }); }); @@ -642,7 +780,20 @@ describe('WebSocketServer', () => { req.on('response', (res) => { assert.strictEqual(res.statusCode, 400); - wss.close(done); + + const chunks = []; + + res.on('data', (chunk) => { + chunks.push(chunk); + }); + + res.on('end', () => { + assert.strictEqual( + Buffer.concat(chunks).toString(), + 'The Sec-WebSocket-Protocol header is invalid' + ); + wss.close(done); + }); }); }); @@ -672,7 +823,21 @@ describe('WebSocketServer', () => { req.on('response', (res) => { assert.strictEqual(res.statusCode, 400); - wss.close(done); + + const chunks = []; + + res.on('data', (chunk) => { + chunks.push(chunk); + }); + + res.on('end', () => { + assert.strictEqual( + Buffer.concat(chunks).toString(), + 'The Sec-WebSocket-Extensions header is invalid or not ' + + 'acceptable' + ); + wss.close(done); + }); }); } );