diff --git a/config/config-example.js b/config/config-example.js index 7d8858676213..3d4dfa273b69 100644 --- a/config/config-example.js +++ b/config/config-example.js @@ -25,6 +25,14 @@ exports.workers = 1; // TODO: allow SSL to actually be possible to use for third-party servers at // some point. +// golang - toggle using Go instead of Node for sockets workers +// Node workers are more unstable at handling connections because of bugs in +// sockjs-node, but sending/receiving messages over connections on Go workers +// is slightly slower due to the extra work involved in performing IPC with +// them safely. This should be left set to false unless you know what you are +// doing. +exports.golang = false; + // proxyip - proxy IPs with trusted X-Forwarded-For headers // This can be either false (meaning not to trust any proxies) or an array // of strings. Each string should be either an IP address or a subnet given @@ -32,7 +40,7 @@ exports.workers = 1; // know what you are doing. exports.proxyip = false; -// ofe - write heapdumps if sockets.js workers run out of memory. +// ofe - write heapdumps if Node sockets workers run out of memory // If you wish to enable this, you will need to install ofe, as it is not a // installed by default: // $ npm install --no-save ofe diff --git a/dev-tools/sockets.js b/dev-tools/sockets.js new file mode 100644 index 000000000000..a4e79d77e12c --- /dev/null +++ b/dev-tools/sockets.js @@ -0,0 +1,35 @@ +'use strict'; + +const {Session, SockJSConnection} = require('sockjs/lib/transport'); + +const chars = 'abcdefghijklmnopqrstuvwxyz1234567890-'; +let sessionidCount = 0; + +/** + * @return string + */ +function generateSessionid() { + let ret = ''; + let idx = sessionidCount; + for (let i = 0; i < 8; i++) { + ret = chars[idx % chars.length] + ret; + idx = idx / chars.length | 0; + } + sessionidCount++; + return ret; +} + +/** + * @param {string} sessionid + * @param {{options: {{}}} config + * @return SockJSConnection + */ +exports.createSocket = function (sessionid = generateSessionid(), config = {options: {}}) { + let session = new Session(sessionid, config); + let socket = new SockJSConnection(session); + socket.remoteAddress = '127.0.0.1'; + socket.protocol = 'websocket'; + return socket; +}; + +// TODO: move worker mocks here, use require('../sockets-workers').Multiplexer to stub IPC diff --git a/package-lock.json b/package-lock.json index fa4850361ed5..cd867c4ac88f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -4,42 +4,79 @@ "lockfileVersion": 1, "requires": true, "dependencies": { + "@types/cloud-env": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/@types/cloud-env/-/cloud-env-0.2.0.tgz", + "integrity": "sha512-18AOYo8HyJYEmKcTt/wD4e6HGEInmKLEOWrE6qL8ran0DfyCbmxzR3R1sJKv4XjPHJOyyiXV4bH6tEnuwSgSBQ==", + "dev": true + }, + "@types/mime": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/@types/mime/-/mime-1.3.1.tgz", + "integrity": "sha512-rek8twk9C58gHYqIrUlJsx8NQMhlxqHzln9Z9ODqiNgv3/s+ZwIrfr+djqzsnVM12xe9hL98iJ20lj2RvCBv6A==", + "dev": true + }, "@types/node": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/@types/node/-/node-8.0.1.tgz", - "integrity": "sha512-bys2VRs6H7HP8S26aHgPWSiSX7q81TToe5HSSvl5bQjoSElQ2SwbGw2p6/DSDb7Vr0oKhewFao9ZuTn8DSag9Q==", + "version": "8.0.28", + "resolved": "https://registry.npmjs.org/@types/node/-/node-8.0.28.tgz", + "integrity": "sha512-HupkFXEv3O3KSzcr3Ylfajg0kaerBg1DyaZzRBBQfrU3NN1mTBRE7sCveqHwXLS5Yrjvww8qFzkzYQQakG9FuQ==", "dev": true }, + "@types/node-static": { + "version": "0.7.0", + "resolved": "https://registry.npmjs.org/@types/node-static/-/node-static-0.7.0.tgz", + "integrity": "sha512-4SImtzapcVt+rQEAKVtbT0eh2D895DKnyrRkDgcSpw+LNnol9zlJPcU6yDvjWrEV/6nBSPQqzY0AP69v5v2iEQ==", + "dev": true, + "requires": { + "@types/mime": "1.3.1", + "@types/node": "8.0.28" + } + }, "@types/nodemailer": { "version": "1.3.33", "resolved": "https://registry.npmjs.org/@types/nodemailer/-/nodemailer-1.3.33.tgz", "integrity": "sha512-PONEJf/LwNcqgU/GpMIAquSBFdq+kCdpYI9TdoeGcTfLCsXzWunKzv4bUQs8zfKGz97CLymgoL0fMLYpOu+/1A==", "dev": true, "requires": { - "@types/node": "8.0.1", - "@types/nodemailer-direct-transport": "1.0.29", - "@types/nodemailer-smtp-transport": "2.7.2" + "@types/node": "8.0.28", + "@types/nodemailer-direct-transport": "1.0.30", + "@types/nodemailer-smtp-transport": "2.7.3" } }, "@types/nodemailer-direct-transport": { - "version": "1.0.29", - "resolved": "https://registry.npmjs.org/@types/nodemailer-direct-transport/-/nodemailer-direct-transport-1.0.29.tgz", - "integrity": "sha1-Ake7RzT4u/k5gkGxa0ECLhMD3fE=", + "version": "1.0.30", + "resolved": "https://registry.npmjs.org/@types/nodemailer-direct-transport/-/nodemailer-direct-transport-1.0.30.tgz", + "integrity": "sha512-gH49BNkXM8EZb/UgI4hUwWwTW3izRx5L+0VyohKkbVijvfUIhn7RALSpBjCUyXzEj0XZSNmQMFVc97Lj0z8UIw==", "dev": true, "requires": { "@types/nodemailer": "1.3.33" } }, "@types/nodemailer-smtp-transport": { - "version": "2.7.2", - "resolved": "https://registry.npmjs.org/@types/nodemailer-smtp-transport/-/nodemailer-smtp-transport-2.7.2.tgz", - "integrity": "sha512-sFeTdk87Xk4ADTs7HY32Thr/sD06HJHPbmjuCOGtqVhSnMeVPf3OQAU3d4oW9bhccRLpjCWUds00CrbuU/iLzw==", + "version": "2.7.3", + "resolved": "https://registry.npmjs.org/@types/nodemailer-smtp-transport/-/nodemailer-smtp-transport-2.7.3.tgz", + "integrity": "sha512-HxKPBErWelYVIWiKkUl06IaG4ojEMDtH6cAlojKgjsqwF8UQun4QeahYCWLCkA8/vKOX0G6VV1Vu2Z4x4ovqLQ==", "dev": true, "requires": { - "@types/node": "8.0.1", + "@types/node": "8.0.28", "@types/nodemailer": "1.3.33" } }, + "@types/ofe": { + "version": "0.5.0", + "resolved": "https://registry.npmjs.org/@types/ofe/-/ofe-0.5.0.tgz", + "integrity": "sha512-d/yCVOHDKVLQxzXLjU3cXol59ctzrfKkjzcWrGfy3fnacZUUbs8fTBWjqRcY+dr4v/yVNOy+jLL3q2/OM1LOCQ==", + "dev": true + }, + "@types/sockjs": { + "version": "0.3.31", + "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.31.tgz", + "integrity": "sha512-6d+6cH187jHWoUP07fzDQkC1fATu7TXMNJaCKwMwpcjkRvAK4T7bDw8sukL3rE8A/ZEPmo94YghsgCkI2/V8kw==", + "dev": true, + "requires": { + "@types/node": "8.0.28" + } + }, "acorn-jsx": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-3.0.1.tgz", @@ -719,7 +756,7 @@ "globals": { "version": "9.18.0", "resolved": "https://registry.npmjs.org/globals/-/globals-9.18.0.tgz", - "integrity": "sha512-S0nG3CLEQiY/ILxqtztTWH/3iRRdyBLw6KMDxnKMchrtbj2OFmehVh0WUCfW3DUrIgx/qFrJPICrq4Z4sTR9UQ==", + "integrity": "sha1-qjiWs+abSH8X4x7SFD1pqOMMLYo=", "dev": true }, "globby": { diff --git a/package.json b/package.json index 5b0714568537..8b53145887ce 100644 --- a/package.json +++ b/package.json @@ -47,10 +47,14 @@ "private": true, "license": "MIT", "devDependencies": { + "@types/cloud-env": "^0.2.0", + "@types/node": "^8.0.28", + "@types/node-static": "^0.7.0", + "@types/nodemailer": "^1.3.33", + "@types/ofe": "^0.5.0", + "@types/sockjs": "^0.3.31", "eslint": "^4.0.0", "mocha": "^3.0.0", - "@types/node": "^8.0.1", - "@types/nodemailer": "^1.3.33", "typescript": "^2.5.0-dev.20170622" } } diff --git a/pokemon-showdown b/pokemon-showdown index 5d4fd0e9b4c8..1c98398de1b5 100755 --- a/pokemon-showdown +++ b/pokemon-showdown @@ -39,27 +39,119 @@ try { ); } -if (!process.argv[2] || /^[0-9]+$/.test(process.argv[2])) { - // Start the server. We manually load app.js so it can be configured to run as - // the main module, rather than this file being considered the main module. - // This ensures any dependencies that were just installed can be found when - // running on Windows and avoids any other potential side effects of the main - // module not being app.js like it is assumed to be. - // - // The port the server should host on can be passed using the second argument - // when launching with this file the same way app.js normally allows, e.g. to - // host on port 9000: - // $ ./pokemon-showdown 9000 - - require('module')._load('./app', module, true); -} else switch (process.argv[2]) { +// ALlow arguments passed to the launch script to be evaluated as commands. +let [, , arg2, arg3, arg4] = process.argv; +if (arg2 && /^[0-9]$/.test(arg2)) { + switch (arg2) { case 'generate-team': const Dex = require('./sim/dex'); global.toId = Dex.getId; - const seed = process.argv[4] ? process.argv[4].split(',').map(Number) : undefined; - console.log(Dex.packTeam(Dex.generateTeam(process.argv[3], seed))); - break; + const seed = arg4 ? arg4.split(',').map(Number) : undefined; + console.log(Dex.packTeam(Dex.generateTeam(arg3, seed))); + process.exit(0); default: - console.error('Unrecognized command: ' + process.argv[2]); + console.error(`Unrecognized command: ${arg2}`); process.exit(1); + } } + +// If evaluating commands wasn't the point of running this script, let's launch +// the server. + +// Check if the server is configured to use Go, and ensure the required +// environment variables and dependencies are available if that is the case + +let config; +try { + config = require('./config/config'); +} catch (e) {} + +if (config && config.golang) { + // GOPATH and GOROOT are optional to a degree, but we need them in order + // to be able to handle Go dependencies. Since Go only cares about the + // first path in the list, so will we. + const GOPATH = child_process.execSync('go env GOPATH', {stdio: null, encoding: 'utf8'}) + .trim() + .split(path.delimiter)[0] + .replace(/^"(.*)"$/, '$1'); + if (!GOPATH) { + // Should never happen, but it does on Bash on Ubuntu on Windows. + console.error('There is no $GOPATH environment variable set. Run:'); + console.error('$ go help GOPATH'); + console.error('For more information on how to configure it.'); + process.exit(0); + } + + const dependencies = ['github.com/gorilla/mux', 'github.com/igm/sockjs-go/sockjs']; + let packages = child_process.execSync('go list all', {stdio: null, encoding: 'utf8'}); + for (let dep of dependencies) { + if (!packages.includes(dep)) { + console.log(`Dependency ${dep} is not installed. Fetching...`); + child_process.execSync(`go get ${dep}`, {stdio: 'inherit'}); + } + } + + let stat; + let needsSrcDir = false; + try { + stat = fs.lstatSync(path.resolve(GOPATH, 'src/github.com/Zarel')); + } catch (e) { + needsSrcDir = true; + } finally { + if (stat && !stat.isDirectory()) { + needsSrcDir = true; + } + } + + let srcPath = path.resolve(process.cwd(), 'sockets'); + let tarPath = path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown/sockets'); + if (needsSrcDir) { + try { + fs.mkdirSync(path.resolve(GOPATH, 'src/github.com/Zarel')); + fs.mkdirSync(path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown')); + } catch (e) { + console.error(e); + console.error(`Cannot make go source directory for the sockets library files! Symlink them manually from ${srcPath} to ${tarPath}`); + process.exit(0); + } + } + + try { + stat = fs.lstatSync(path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown/sockets')); + } catch (e) {} + + if (!stat || !stat.isSymbolicLink()) { + // Windows requires administrator privileges to make symlinks, so we + // make junctions instead. For our purposes they're compatible enough + // with symlinks on UNIX-like OSes. + let symlinkType = (process.platform === 'win32') ? 'junction' : 'dir'; + try { + fs.symlinkSync(srcPath, tarPath, symlinkType); + } catch (e) { + console.error(`Cannot make go source directory for the sockets library files! Symlink them manually from ${srcPath} to ${tarPath}`); + process.exit(0); + } + } + + console.log('Building Go source libs...'); + try { + child_process.execSync('go install github.com/Zarel/Pokemon-Showdown/sockets', {stdio: 'inherit'}); + } catch (e) { + // Go will show the errors that caused compiling Go's files to fail, so + // there's no reason to bother logging anything of our own. + process.exit(0); + } +} + +// Start the server. We manually load app.js so it can be configured to run as +// the main module, rather than this file being considered the main module. +// This ensures any dependencies that were just installed can be found when +// running on Windows and avoids any other potential side effects of the main +// module not being app.js like it is assumed to be. +// +// The port the server should host on can be passed using the second argument +// when launching with this file the same way app.js normally allows, e.g. to +// host on port 9000: +// $ ./pokemon-showdown 9000 + +require('module')._load('./app', module, true); diff --git a/sockets-workers.js b/sockets-workers.js new file mode 100644 index 000000000000..1f1de9f61b4c --- /dev/null +++ b/sockets-workers.js @@ -0,0 +1,599 @@ +/** + * Connections + * Pokemon Showdown - http://pokemonshowdown.com/ + * + * Abstraction layer for multi-process SockJS connections. + * + * This file handles all the communications between the users' browsers and + * the main process. + * + * @license MIT license + */ + +'use strict'; + +const cluster = require('cluster'); +const fs = require('fs'); +const path = require('path'); + +// IPC command tokens +const EVAL = '$'; +const SOCKET_CONNECT = '*'; +const SOCKET_DISCONNECT = '!'; +const SOCKET_RECEIVE = '<'; +const SOCKET_SEND = '>'; +const CHANNEL_ADD = '+'; +const CHANNEL_REMOVE = '-'; +const CHANNEL_BROADCAST = '#'; +const SUBCHANNEL_MOVE = '.'; +const SUBCHANNEL_BROADCAST = ':'; + +// Subchannel IDs +const DEFAULT_SUBCHANNEL_ID = '0'; +const P1_SUBCHANNEL_ID = '1'; +const P2_SUBCHANNEL_ID = '2'; + +// Regex for splitting subchannel broadcasts between subchannels. +const SUBCHANNEL_MESSAGE_REGEX = /\|split\n([^\n]*)\n([^\n]*)\n([^\n]*)\n[^\n]*/g; + +/** + * Manages the worker's state for sockets, channels, and + * subchannels. This is responsible for parsing all outgoing and incoming + * messages. + */ +class Multiplexer { + constructor() { + /** @type {number} */ + this.socketCounter = 0; + /** @type {Map} */ + this.sockets = new Map(); + /** @type {Map>} */ + this.channels = new Map(); + /** @type {?NodeJS.Timer} */ + this.cleanupInterval = setInterval(() => this.sweepClosedSockets(), 10 * 60 * 1000); + } + + /** + * Mitigates a potential bug in SockJS or Faye-Websocket where + * sockets fail to emit a 'close' event after having disconnected. + */ + sweepClosedSockets() { + this.sockets.forEach(socket => { + if (socket.protocol === 'xhr-streaming' && + socket._session && + socket._session.recv) { + socket._session.recv.didClose(); + } + + // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while + // it is an object for normal users. Under normal circumstances, those properties should only be + // `null` when the timeout has already been called, but somehow it's not happening for some connections. + // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually + // on those connections kills those connections. For a bit of background, this timeout is the timeout + // that sockjs sets to wait for users to reconnect within that time to continue their session. + if (socket._session && + socket._session.to_tref && + !socket._session.to_tref._idlePrev) { + socket._session.timeout_cb(); + } + }); + + // Don't bother deleting the sockets from our map; their close event + // handler will deal with it. + } + + /** + * Sends an IPC message to the parent process. + * + * @param {string} token + * @param {string[]} params + */ + sendUpstream(token, ...params) { + let message = `${token}${params.join('\n')}`; + if (process.send) process.send(message); + } + + /** + * Parses the params in a downstream message sent as a + * command. + * + * @param {string} params + * @param {number} count + * @return {string[]} + */ + parseParams(params, count) { + let i = 0; + let idx = 0; + let ret = []; + while (i++ < count) { + let newIdx = params.indexOf('\n', idx); + if (newIdx < 0) { + // No remaining newlines; just use the rest of the string as + // the last parametre. + ret.push(params.slice(idx)); + break; + } + + let param = params.slice(idx, newIdx); + if (i === count) { + // We reached the number of parametres needed, but there is + // still some remaining string left. Glue it to the last one. + param += `\n${params.slice(newIdx + 1)}`; + } else { + idx = newIdx + 1; + } + + ret.push(param); + } + + return ret; + } + + /** + * Parses downstream messages. + * + * @param {string} data + * @return {boolean} + */ + receiveDownstream(data) { + // console.log(`worker received: ${data}`); + let token = data.charAt(0); + let params = data.substr(1); + switch (token) { + case EVAL: + return this.onEval(params); + case SOCKET_DISCONNECT: + return this.onSocketDisconnect(params); + case SOCKET_SEND: + // @ts-ignore + return this.onSocketSend(...this.parseParams(params, 2)); + case CHANNEL_ADD: + // @ts-ignore + return this.onChannelAdd(...this.parseParams(params, 2)); + case CHANNEL_REMOVE: + // @ts-ignore + return this.onChannelRemove(...this.parseParams(params, 2)); + case CHANNEL_BROADCAST: + // @ts-ignore + return this.onChannelBroadcast(...this.parseParams(params, 2)); + case SUBCHANNEL_MOVE: + // @ts-ignore + return this.onSubchannelMove(...this.parseParams(params, 3)); + case SUBCHANNEL_BROADCAST: + // @ts-ignore + return this.onSubchannelBroadcast(...this.parseParams(params, 2)); + default: + console.error(`Sockets: attempted to send unknown IPC message with token ${token}: ${params}`); + return false; + } + } + + /** + * Safely tries to destroy a socket's connection. + * + * @param {any} socket + */ + tryDestroySocket(socket) { + try { + socket.end(); + socket.destroy(); + } catch (e) {} + } + + /** + * Eval handler for downstream messages. + * + * @param {string} expr + * @return {boolean} + */ + onEval(expr) { + try { + eval(expr); + return true; + } catch (e) {} + return false; + } + + /** + * Sockets.socketConnect message handler. + * + * @param {any} socket + * @return {boolean} + */ + onSocketConnect(socket) { + if (!socket) return false; + if (!socket.remoteAddress) { + this.tryDestroySocket(socket); + return false; + } + + let socketid = '' + this.socketCounter++; + let ip = socket.remoteAddress; + let ips = socket.headers['x-forwarded-for'] || ''; + this.sockets.set(socketid, socket); + this.sendUpstream(SOCKET_CONNECT, socketid, ip, ips, socket.protocol); + + socket.on('data', /** @param {string} message */ message => { + this.onSocketReceive(socketid, message); + }); + + socket.on('close', () => { + this.sendUpstream(SOCKET_DISCONNECT, socketid); + this.sockets.delete(socketid); + this.channels.forEach((channel, channelid) => { + if (!channel || !channel.has(socketid)) return; + channel.delete(socketid); + if (!channel.size) this.channels.delete(channelid); + }); + }); + + return true; + } + + /** + * Sockets.socketDisconnect message handler. + * @param {string} socketid + * @return {boolean} + */ + onSocketDisconnect(socketid) { + let socket = this.sockets.get(socketid); + if (!socket) return false; + + this.tryDestroySocket(socket); + return true; + } + + /** + * Sockets.socketSend message handler. + * + * @param {string} socketid + * @param {string} message + * @return {boolean} + */ + onSocketSend(socketid, message) { + let socket = this.sockets.get(socketid); + if (!socket) return false; + + socket.write(message); + return true; + } + + /** + * onmessage event handler for sockets. Passes the message + * upstream. + * + * @param {string} socketid + * @param {string} message + * @return {boolean} + */ + onSocketReceive(socketid, message) { + // Drop empty messages (DDOS?). + if (!message) return false; + + // Drop >100KB messages. + if (message.length > 100 * 1024) { + console.log(`Dropping client message ${message.length / 1024}KB...`); + console.log(message.slice(0, 160)); + return false; + } + + // Drop legacy JSON messages. + if ((typeof message !== 'string') || message.startsWith('{')) return false; + + // Drop invalid messages (again, DDOS?). + if (message.endsWith('|') || !message.includes('|')) return false; + + this.sendUpstream(SOCKET_RECEIVE, socketid, message); + return true; + } + + /** + * Sockets.channelAdd message handler. + * + * @param {string} channelid + * @param {string} socketid + * @return {boolean} + */ + onChannelAdd(channelid, socketid) { + if (!this.sockets.has(socketid)) return false; + + if (this.channels.has(channelid)) { + let channel = this.channels.get(channelid); + if (!channel || channel.has(socketid)) return false; + channel.set(socketid, DEFAULT_SUBCHANNEL_ID); + } else { + let channel = new Map([[socketid, DEFAULT_SUBCHANNEL_ID]]); + this.channels.set(channelid, channel); + } + + return true; + } + + /** + * Sockets.channelRemove message handler. + * + * @param {string} channelid + * @param {string} socketid + * @return {boolean} + */ + onChannelRemove(channelid, socketid) { + let channel = this.channels.get(channelid); + if (!channel || !channel.has(socketid)) return false; + + channel.delete(socketid); + if (!channel.size) this.channels.delete(channelid); + return true; + } + + /** + * Sockets.channelSend and Sockets.channelBroadcast message + * handler. + * + * @param {string} channelid + * @param {string} message + * @return {boolean} + */ + onChannelBroadcast(channelid, message) { + let channel = this.channels.get(channelid); + if (!channel) return false; + + channel.forEach( + /** + * @param {string} subchannelid + * @param {string} socketid + */ + (subchannelid, socketid) => { + let socket = this.sockets.get(socketid); + socket.write(message); + } + ); + + return true; + } + + /** + * Sockets.subchannelMove message handler. + * + * @param {string} channelid + * @param {string} subchannelid + * @param {string} socketid + * @return {boolean} + */ + onSubchannelMove(channelid, subchannelid, socketid) { + if (!this.sockets.has(socketid)) return false; + + if (this.channels.has(channelid)) { + let channel = this.channels.get(channelid); + if (channel) channel.set(socketid, subchannelid); + } else { + let channel = new Map([[socketid, subchannelid]]); + this.channels.set(channelid, channel); + } + + return true; + } + + /** + * Sockets.subchannelBroadcast message handler. + * + * @param {string} channelid + * @param {string} message + * @return {boolean} + */ + onSubchannelBroadcast(channelid, message) { + let channel = this.channels.get(channelid); + if (!channel) return false; + + let msgs = {}; + channel.forEach( + /** + * @param {string} subchannelid + * @param {string} socketid + */ + (subchannelid, socketid) => { + let socket = this.sockets.get(socketid); + if (!socket) return; + + if (!(subchannelid in msgs)) { + switch (subchannelid) { + case DEFAULT_SUBCHANNEL_ID: + msgs[subchannelid] = message.replace(SUBCHANNEL_MESSAGE_REGEX, '$1'); + break; + case P1_SUBCHANNEL_ID: + msgs[subchannelid] = message.replace(SUBCHANNEL_MESSAGE_REGEX, '$2'); + break; + case P2_SUBCHANNEL_ID: + msgs[subchannelid] = message.replace(SUBCHANNEL_MESSAGE_REGEX, '$3'); + break; + } + } + + socket.write(msgs[subchannelid]); + } + ); + + return true; + } + + /** + * Cleans up the properties of the multiplexer once an internal message + * from the parent process dictates that the worker disconnect. We can't + * use the 'disconnect' handler for this because at that point the worker + * is already disconnected. + */ + destroy() { + // @ts-ignore + clearInterval(this.cleanupInterval); + this.cleanupInterval = null; + this.sockets.forEach(socket => this.tryDestroySocket(socket)); + this.sockets.clear(); + this.channels.clear(); + } +} + +if (cluster.isWorker) { + // @ts-ignore + global.Config = require('./config/config'); + + // @ts-ignore + if (process.env.PSPORT) Config.port = +process.env.PSPORT; + // @ts-ignore + if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; + // @ts-ignore + if (+process.env.PSNOSSL) Config.ssl = null; + + if (Config.ofe) { + try { + require.resolve('ofe'); + } catch (e) { + if (e.code !== 'MODULE_NOT_FOUND') throw e; // should never happen + throw new Error( + 'ofe is not installed, but it is a required dependency if Config.ofe is set to true! ' + + 'Run npm install ofe and restart the server.' + ); + } + + // Create a heapdump if the process runs out of memory. + require('ofe').call(); + } + + // Graceful crash. + process.on('uncaughtException', err => { + if (Config.crashguard) require('./crashlogger')(err, `Socket process ${cluster.worker.id} (${process.pid})`, true); + }); + + let app = require('http').createServer(); + /** @type {?NodeJS.Server} */ + let appssl = null; + if (Config.ssl) { + let key; + try { + key = path.resolve(__dirname, Config.ssl.options.key); + if (!fs.lstatSync(key).isFile()) throw new Error(); + try { + key = fs.readFileSync(key); + } catch (e) { + require('./crashlogger')(new Error(`Failed to read the configured SSL private key PEM file:\n${e.stack}`), `Socket process ${cluster.worker.id} (${process.pid})`, true); + } + } catch (e) { + console.warn('SSL private key config values will not support HTTPS server option values in the future. Please set it to use the absolute path of its PEM file.'); + key = Config.ssl.options.key; + } + + let cert; + try { + cert = path.resolve(__dirname, Config.ssl.options.cert); + if (!fs.lstatSync(cert).isFile()) throw new Error(); + try { + cert = fs.readFileSync(cert); + } catch (e) { + require('./crashlogger')(new Error(`Failed to read the configured SSL certificate PEM file:\n${e.stack}`), `Socket process ${cluster.worker.id} (${process.pid})`, true); + } + } catch (e) { + console.warn('SSL certificate config values will not support HTTPS server option values in the future. Please set it to use the absolute path of its PEM file.'); + cert = Config.ssl.options.cert; + } + + if (key && cert) { + try { + // In case there are additional SSL config settings besides the key and cert... + appssl = require('https').createServer(Object.assign({}, Config.ssl.options, {key, cert})); + } catch (e) { + require('./crashlogger')(new Error(`The SSL settings are misconfigured:\n${e.stack}`), `Socket process ${cluster.worker.id} (${process.pid})`, true); + } + } + } + + const StaticServer = require('node-static').Server; + const roomidRegex = /^\/[A-Za-z0-9][A-Za-z0-9-]*\/?$/; + const cssServer = new StaticServer('./config'); + const avatarServer = new StaticServer('./config/avatars'); + const staticServer = new StaticServer('./static'); + /** + * @param {any} req + * @param {any} res + */ + const staticRequestHandler = (req, res) => { + // console.log(`static rq: ${req.socket.remoteAddress}:${req.socket.remotePort} -> ${req.socket.localAddress}:${req.socket.localPort} - ${req.method} ${req.url} ${request.httpVersion} - ${req.rawHeaders.join('|')}`); + req.resume(); + req.addListener('end', () => { + if (Config.customhttpresponse && + Config.customhttpresponse(req, res)) { + return; + } + + let server = staticServer; + if (req.url === '/custom.css') { + server = cssServer; + } else if (req.url.startsWith('/avatars/')) { + req.url = req.url.substr(8); + server = avatarServer; + } else if (roomidRegex.test(req.url)) { + req.url = '/'; + } + + server.serve(req, res, e => { + // @ts-ignore + if (e && e.status === 404) { + staticServer.serveFile('404.html', 404, {}, req, res); + } + }); + }); + }; + + app.on('request', staticRequestHandler); + if (appssl) appssl.on('request', staticRequestHandler); + + // Launch the SockJS server. + const sockjs = require('sockjs'); + const server = sockjs.createServer({ + sockjs_url: '//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js', + log(severity, message) { + if (severity === 'error') console.error(`Sockets worker SockJS error: ${message}`); + }, + prefix: '/showdown', + }); + + // Instantiate SockJS' multiplexer. This takes messages received downstream + // from the parent process and distributes them across the sockets they are + // targeting, as well as handling user disconnects and passing user + // messages upstream. + const multiplexer = new Multiplexer(); + + process.on('message', data => { + multiplexer.receiveDownstream(data); + }); + + // Clean up any remaining connections on disconnect. If this isn't done, + // the process will not exit until any remaining connections have been destroyed. + // Afterwards, the worker process will die on its own. + process.once('disconnect', () => { + multiplexer.destroy(); + app.close(); + /** @type {?NodeJS.Server} */ + if (appssl) appssl.close(); + }); + + server.on('connection', /** @param {any} socket */ socket => { + multiplexer.onSocketConnect(socket); + }); + + server.installHandlers(app, {}); + app.listen(Config.port, Config.bindaddress); + if (appssl) { + // @ts-ignore + server.installHandlers(appssl, {}); + appssl.listen(Config.ssl.port, Config.bindaddress); + } + + require('./repl').start( + `sockets-${cluster.worker.id}-${process.pid}`, + /** @param {string} cmd */ + cmd => eval(cmd) + ); +} + +module.exports = { + SUBCHANNEL_MESSAGE_REGEX, + Multiplexer, +}; \ No newline at end of file diff --git a/sockets.js b/sockets.js index 99a8712c693f..4bd3e272173b 100644 --- a/sockets.js +++ b/sockets.js @@ -4,568 +4,703 @@ * * Abstraction layer for multi-process SockJS connections. * - * This file handles all the communications between the users' - * browsers, the networking processes, and users.js in the - * main process. + * This file handles all the communications between the networking processes + * and users.js. * * @license MIT license */ 'use strict'; +const child_process = require('child_process'); const cluster = require('cluster'); -const fs = require('fs'); +const EventEmitter = require('events'); +const path = require('path'); if (cluster.isMaster) { cluster.setupMaster({ - exec: require('path').resolve(__dirname, 'sockets'), + exec: path.resolve(__dirname, 'sockets-workers'), }); +} - const workers = exports.workers = new Map(); - - const spawnWorker = exports.spawnWorker = function () { - let worker = cluster.fork({PSPORT: Config.port, PSBINDADDR: Config.bindaddress || '0.0.0.0', PSNOSSL: Config.ssl ? 0 : 1}); - let id = worker.id; - workers.set(id, worker); - worker.on('message', data => { - // console.log('master received: ' + data); - switch (data.charAt(0)) { - case '*': { - // *socketid, ip, protocol - // connect - let nlPos = data.indexOf('\n'); - let nlPos2 = data.indexOf('\n', nlPos + 1); - Users.socketConnect(worker, id, data.slice(1, nlPos), data.slice(nlPos + 1, nlPos2), data.slice(nlPos2 + 1)); - break; - } +/** + * IPC delimiter byte. This byte must stringify as a hexadeimal + * escape code when stringified as JSON to prevent messages from being able to + * contain the byte itself. + * + * @type {string} + */ +const DELIM = '\x03'; - case '!': { - // !socketid - // disconnect - Users.socketDisconnect(worker, id, data.substr(1)); - break; - } +/** + * Map of worker IDs to worker wrappers. + * + * @type {Map} + */ +const workers = new Map(); - case '<': { - // this.onListen()); + worker.on('message', /** @param {string} data */ data => this.onMessage(data)); + worker.once('error', /** @param {?Error} err */ err => this.onError(err)); + worker.once('exit', + /** + * @param {any} worker + * @param {?number} code + * @param {?string} status + */ + (worker, code, status) => this.onExit(worker, code, status) + ); + } + + /** + * Worker process getter + * + * @return {any} + */ + get process() { + return this.worker.process; + } - default: - // unhandled + /** + * Worker exitedAfterDisconnect getter + * + * @return {boolean | void} + */ + get exitedAfterDisconnect() { + return this.worker.exitedAfterDisconnect; + } + + /** + * Worker suicide getter + * + * @return {boolean | void} + */ + get suicide() { + return this.worker.exitedAfterDisconnect; + } + + /** + * Worker#disconnect wrapper + * + */ + disconnect() { + return this.worker.disconnect(); + } + + /** + * Worker#kill wrapper + * + * @param {string=} signal + */ + kill(signal) { + return this.worker.kill(signal); + } + + /** + * Worker#destroy wrapper + * + * @param {string=} signal + */ + destroy(signal) { + return this.worker.kill(signal); + } + + /** + * Worker#send wrapper + * + * @param {string} message + * @return {boolean} + */ + send(message) { + return this.worker.send(message); + } + + /** + * Worker#isConnected wrapper + * + * @return {boolean} + */ + isConnected() { + return this.worker.isConnected(); + } + + /** + * Worker#isDead wrapper + * + * @return {boolean} + */ + isDead() { + return this.worker.isDead(); + } + + /** + * 'listening' event handler for the worker. Logs which + * hostname and worker ID is listening to console. + */ + onListen() { + console.log(`Worker ${this.id} now listening on ${Config.bindaddress}:${Config.port}`); + if (Config.ssl) console.log(`Worker ${this.id} now listening for SSL on port ${Config.ssl.port}`); + console.log(`Test your server at http://${Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress}:${Config.port}`); + } + + /** + * 'message' event handler for the worker. Parses which type + * of command the incoming IPC message is calling, then passes its + * parametres to the appropriate method to handle. + * + * @param {string} data + */ + onMessage(data) { + // console.log(`master received: ${data}`); + let token = data.charAt(0); + let params = data.substr(1); + switch (token) { + case '*': + this.onSocketConnect(params); + break; + case '!': + this.onSocketDisconnect(params); + break; + case '<': + this.onSocketReceive(params); + break; + default: + console.error(`Sockets: received unknown IPC message with token ${token}: ${params}`); + break; + } + } + + /** + * Socket connection message handler. + * + * @param {string} params + */ + onSocketConnect(params) { + let [socketid, ip, header, protocol] = params.split('\n'); + + if (this.isTrustedProxyIp(ip)) { + let ips = header.split(','); + for (let i = ips.length; i--;) { + let proxy = ips[i].trim(); + if (proxy && !this.isTrustedProxyIp(proxy)) { + ip = proxy; + break; + } } - }); + } - return worker; - }; + Users.socketConnect(this, this.id, socketid, ip, protocol); + } - cluster.on('exit', (worker, code, signal) => { + /** + * Socket disconnect handler. + * + * @param {string} socketid + */ + onSocketDisconnect(socketid) { + Users.socketDisconnect(this, this.id, socketid); + } + + /** + * Socket message receive handler. + * + * @param {string} params + */ + onSocketReceive(params) { + let idx = params.indexOf('\n'); + let socketid = params.substr(0, idx); + let message = params.substr(idx + 1); + Users.socketReceive(this, this.id, socketid, message); + } + + /** + * Worker 'error' event handler. + * + * @param {?Error} err + */ + onError(err) { + this.error = err; + } + + /** + * Worker 'exit' event handler. + * + * @param {any} worker + * @param {?number} code + * @param {?string} signal + */ + onExit(worker, code, signal) { if (code === null && signal !== null) { - // Worker was killed by Sockets.killWorker or Sockets.killPid, probably. - console.log(`Worker ${worker.id} was forcibly killed with status ${signal}.`); + // Worker was killed by Sockets.killWorker or Sockets.killPid. + console.warn(`Worker ${this.id} was forcibly killed with the signal ${signal}`); workers.delete(worker.id); } else if (code === 0 && signal === null) { - // Happens when killing PS with ^C - console.log(`Worker ${worker.id} died, but returned a successful exit code.`); + console.warn(`Worker ${this.id} died, but returned a successful exit code.`); workers.delete(worker.id); - } else if (code > 0) { - // Worker was killed abnormally, likely because of a crash. - require('./crashlogger')(new Error(`Worker ${worker.id} abruptly died with code ${code} and signal ${signal}`), "The main process"); - // Don't delete the worker so it can be inspected if need be. + } else if (code !== null && code > 0) { + // Worker crashed. + if (this.error) { + require('./crashlogger')(new Error(`Worker ${this.id} abruptly died with the following stack trace: ${this.error.stack}`), 'The main process'); + } else { + require('./crashlogger')(new Error(`Worker ${this.id} abruptly died`), 'The main process'); + } + // Don't delete the worker - keep it for inspection. } - if (worker.isConnected()) worker.disconnect(); + if (this.isConnected()) this.disconnect(); // FIXME: this is a bad hack to get around a race condition in - // Connection#onDisconnect sending room deinit messages after already + // Connection#onDiscconnect sending room deinit messages after already // having removed the sockets from their channels. - worker.send = () => {}; + // @ts-ignore + this.send = () => {}; - let count = 0; - Users.connections.forEach(connection => { - if (connection.worker === worker) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; - } - }); + let count = Users.socketDisconnectAll(this); console.log(`${count} connections were lost.`); - // Attempt to recover. spawnWorker(); - }); - - exports.listen = function (port, bindAddress, workerCount) { - if (port !== undefined && !isNaN(port)) { - Config.port = port; - Config.ssl = null; - } else { - port = Config.port; - - // Autoconfigure when running in cloud environments. - try { - const cloudenv = require('cloud-env'); - bindAddress = cloudenv.get('IP', bindAddress); - port = cloudenv.get('PORT', port); - } catch (e) {} - } - if (bindAddress !== undefined) { - Config.bindaddress = bindAddress; - } - if (workerCount === undefined) { - workerCount = (Config.workers !== undefined ? Config.workers : 1); - } - for (let i = 0; i < workerCount; i++) { - spawnWorker(); - } - }; - - exports.killWorker = function (worker) { - let count = 0; - Users.connections.forEach(connection => { - if (connection.worker === worker) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; - } - }); - console.log(`${count} connections were lost.`); - - try { - worker.kill('SIGTERM'); - } catch (e) {} - - return count; - }; + } +} - exports.killPid = function (pid) { - for (let [workerid, worker] of workers) { // eslint-disable-line no-unused-vars - if (pid === worker.process.pid) { - return this.killWorker(worker); - } - } - return false; - }; - - exports.socketSend = function (worker, socketid, message) { - worker.send(`>${socketid}\n${message}`); - }; - exports.socketDisconnect = function (worker, socketid) { - worker.send(`!${socketid}`); - }; - - exports.channelBroadcast = function (channelid, message) { - workers.forEach(worker => { - worker.send(`#${channelid}\n${message}`); - }); - }; - exports.channelSend = function (worker, channelid, message) { - worker.send(`#${channelid}\n${message}`); - }; - exports.channelAdd = function (worker, channelid, socketid) { - worker.send(`+${channelid}\n${socketid}`); - }; - exports.channelRemove = function (worker, channelid, socketid) { - worker.send(`-${channelid}\n${socketid}`); - }; - - exports.subchannelBroadcast = function (channelid, message) { - workers.forEach(worker => { - worker.send(`:${channelid}\n${message}`); - }); - }; - exports.subchannelMove = function (worker, channelid, subchannelid, socketid) { - worker.send(`.${channelid}\n${subchannelid}\n${socketid}`); - }; -} else { - // is worker - global.Config = require('./config/config'); - - if (process.env.PSPORT) Config.port = +process.env.PSPORT; - if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; - if (+process.env.PSNOSSL) Config.ssl = null; - - if (Config.ofe) { - try { - require.resolve('ofe'); - } catch (e) { - if (e.code !== 'MODULE_NOT_FOUND') throw e; // should never happen - throw new Error( - 'ofe is not installed, but it is a required dependency if Config.ofe is set to true! ' + - 'Run npm install ofe and restart the server.' - ); - } +/** + * A mock Worker class for Go child processes. Similarly to + * Node.js workers, it uses a TCP net server to perform IPC. After launching + * the server, it will spawn the Go child process and wait for it to make a + * connection to the worker's server before performing IPC with it. + */ +class GoWorker extends EventEmitter { + /** + * @param {number} id + */ + constructor(id) { + super(); + + /** @type {number} */ + this.id = id; + /** @type {boolean | void} */ + this.exitedAfterDisconnect = undefined; + + /** @type {string} */ + this.ibuf = ''; + /** @type {?Error} */ + this.error = null; + + /** @type {any} */ + this.process = null; + /** @type {any} */ + this.connection = null; + /** @type {any} */ + this.server = require('net').createServer(); + this.server.once('connection', /** @param {any} connection */ connection => this.onChildConnect(connection)); + this.server.on('error', () => {}); + this.server.listen(() => process.nextTick(() => this.spawnChild())); + } - // Create a heapdump if the process runs out of memory. - require('ofe').call(); + /** + * Worker#disconnect mock + */ + disconnect() { + if (this.isConnected()) this.connection.destroy(); } - // Static HTTP server + /** + * Worker#kill mock + * + * @param {string} [signal = 'SIGTERM'] + */ + kill(signal = 'SIGTERM') { + if (this.process) this.process.kill(signal); + } - // This handles the custom CSS and custom avatar features, and also - // redirects yourserver:8001 to yourserver-8001.psim.us + /** + * Worker#destroy mock + * + * @param {string=} signal + */ + destroy(signal) { + return this.kill(signal); + } - // It's optional if you don't need these features. + /** + * Worker#send mock + * + * @param {string} message + * @return {boolean} + */ + send(message) { + if (!this.isConnected()) return false; + return this.connection.write(JSON.stringify(message) + DELIM); + } - global.Dnsbl = require('./dnsbl'); + /** + * Worker#isConnected mock + * + * @return {boolean} + */ + isConnected() { + return this.connection && !this.connection.destroyed; + } - if (Config.crashguard) { - // graceful crash - process.on('uncaughtException', err => { - require('./crashlogger')(err, `Socket process ${cluster.worker.id} (${process.pid})`, true); - }); + /** + * Worker#isDead mock + * + * @return {boolean} + */ + isDead() { + return this.process && (this.process.exitCode !== null || this.process.statusCode !== null); } - let app = require('http').createServer(); - let appssl = null; - if (Config.ssl) { - let key; - try { - key = require('path').resolve(__dirname, Config.ssl.options.key); - if (!fs.lstatSync(key).isFile()) throw new Error(); - try { - key = fs.readFileSync(key); - } catch (e) { - require('./crashlogger')(new Error(`Failed to read the configured SSL private key PEM file:\n${e.stack}`), `Socket process ${cluster.worker.id} (${process.pid})`, true); + /** + * Spawns the Go child process. Once the process has started, it will make + * a connection to the worker's TCP server. + */ + spawnChild() { + const GOPATH = child_process.execSync('go env GOPATH', {stdio: null, encoding: 'utf8'}) + .trim() + .split(path.delimiter)[0] + .replace(/^"(.*)"$/, '$1'); + + this.process = child_process.spawn( + path.resolve(GOPATH, 'bin/sockets'), [], { + env: { + PS_IPC_PORT: `:${this.server.address().port}`, + PS_CONFIG: JSON.stringify({ + workers: Config.workers || 1, + port: `:${Config.port || 8000}`, + bindAddress: Config.bindaddress || '0.0.0.0', + ssl: Config.ssl ? Object.assign({}, Config.ssl, {port: `:${Config.ssl.port}`}) : null, + }), + }, + stdio: ['inherit', 'inherit', 'pipe'], } - } catch (e) { - console.warn('SSL private key config values will not support HTTPS server option values in the future. Please set it to use the absolute path of its PEM file.'); - key = Config.ssl.options.key; - } + ); + + this.process.once('exit', /** @param {any[]} args */ (...args) => { + // Clean up the IPC server. + this.server.close(() => { + // @ts-ignore + if (this.server._eventsCount <= 2) { + // The child process died before ever opening the IPC + // connection and sending any messages over it. Let's avoid + // getting trapped in an endless loop of respawns and crashes + // if it crashed. + if (this.error) throw this.error; + } - let cert; - try { - cert = require('path').resolve(__dirname, Config.ssl.options.cert); - if (!fs.lstatSync(cert).isFile()) throw new Error(); - try { - cert = fs.readFileSync(cert); - } catch (e) { - require('./crashlogger')(new Error(`Failed to read the configured SSL certificate PEM file:\n${e.stack}`), `Socket process ${cluster.worker.id} (${process.pid})`, true); - } - } catch (e) { - console.warn('SSL certificate config values will not support HTTPS server option values in the future. Please set it to use the absolute path of its PEM file.'); - cert = Config.ssl.options.cert; - } + this.emit('exit', this, ...args); + }); + }); - if (key && cert) { - try { - // In case there are additional SSL config settings besides the key and cert... - appssl = require('https').createServer(Object.assign({}, Config.ssl.options, {key, cert})); - } catch (e) { - require('./crashlogger')(`The SSL settings are misconfigured:\n${e.stack}`, `Socket process ${cluster.worker.id} (${process.pid})`, true); - } - } + this.process.stderr.setEncoding('utf8'); + this.process.stderr.once('data', /** @param {string} data */ data => { + this.error = new Error(data); + this.emit('error', this.error); + }); } - // Static server - const StaticServer = require('node-static').Server; - const roomidRegex = /^\/(?:[A-Za-z0-9][A-Za-z0-9-]*)\/?$/; - const cssServer = new StaticServer('./config'); - const avatarServer = new StaticServer('./config/avatars'); - const staticServer = new StaticServer('./static'); - const staticRequestHandler = (req, res) => { - // console.log(`static rq: ${req.socket.remoteAddress}:${req.socket.remotePort} -> ${req.socket.localAddress}:${req.socket.localPort} - ${req.method} ${req.url} ${req.httpVersion} - ${req.rawHeaders.join('|')}`); - req.resume(); - req.addListener('end', () => { - if (Config.customhttpresponse && - Config.customhttpresponse(req, res)) { + /** + * 'connection' event handler for the TCP server. Begins the parsing of + * incoming IPC messages. + * @param {any} connection + */ + onChildConnect(connection) { + this.connection = connection; + this.connection.setEncoding('utf8'); + this.connection.on('data', /** @param {string} data */ data => { + let idx = data.lastIndexOf(DELIM); + if (idx < 0) { + this.ibuf += data; return; } - let server = staticServer; - if (req.url === '/custom.css') { - server = cssServer; - } else if (req.url.startsWith('/avatars/')) { - req.url = req.url.substr(8); - server = avatarServer; - } else if (roomidRegex.test(req.url)) { - req.url = '/'; + let messages = ''; + if (this.ibuf) { + messages += this.ibuf.slice(0); + this.ibuf = ''; } - server.serve(req, res, e => { - if (e && (e.status === 404)) { - staticServer.serveFile('404.html', 404, {}, req, res); - } - }); - }); - }; - - app.on('request', staticRequestHandler); - if (appssl) appssl.on('request', staticRequestHandler); - - // SockJS server - - // This is the main server that handles users connecting to our server - // and doing things on our server. - - const sockjs = require('sockjs'); - const server = sockjs.createServer({ - sockjs_url: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", - log: (severity, message) => { - if (severity === 'error') console.log('ERROR: ' + message); - }, - prefix: '/showdown', - }); - - const sockets = new Map(); - const channels = new Map(); - const subchannels = new Map(); - - // Deal with phantom connections. - const sweepSocketInterval = setInterval(() => { - sockets.forEach(socket => { - if (socket.protocol === 'xhr-streaming' && - socket._session && - socket._session.recv) { - socket._session.recv.didClose(); + if (idx === data.length - 1) { + messages += data.slice(0, -1); + } else { + messages += data.slice(0, idx); + this.ibuf += data.slice(idx + 1); } - // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while - // it is an object for normal users. Under normal circumstances, those properties should only be - // `null` when the timeout has already been called, but somehow it's not happening for some connections. - // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually - // on those connections kills those connections. For a bit of background, this timeout is the timeout - // that sockjs sets to wait for users to reconnect within that time to continue their session. - if (socket._session && - socket._session.to_tref && - !socket._session.to_tref._idlePrev) { - socket._session.timeout_cb(); + for (let message of messages.split(DELIM)) { + this.emit('message', JSON.parse(message)); } }); - }, 1000 * 60 * 10); - - process.on('message', data => { - // console.log('worker received: ' + data); - let socket = null; - let socketid = ''; - let channel = null; - let channelid = ''; - let subchannel = null; - let subchannelid = ''; - let nlLoc = -1; - let message = ''; - - switch (data.charAt(0)) { - case '$': // $code - eval(data.substr(1)); - break; + this.connection.on('error', () => {}); - case '!': // !socketid - // destroy - socketid = data.substr(1); - socket = sockets.get(socketid); - if (!socket) return; - socket.destroy(); - sockets.delete(socketid); - channels.forEach(channel => channel.delete(socketid)); - break; + process.nextTick(() => this.emit('listening')); + } +} - case '>': - // >socketid, message - // message - nlLoc = data.indexOf('\n'); - socketid = data.substr(1, nlLoc - 1); - socket = sockets.get(socketid); - if (!socket) return; - message = data.substr(nlLoc + 1); - socket.write(message); - break; +/** + * Worker ID counter. We don't use cluster's internal counter so + * Config.golang can be freely changed while the server is still running. + * + * @type {number} + */ +let nextWorkerid = 1; - case '#': - // #channelid, message - // message to channel - nlLoc = data.indexOf('\n'); - channelid = data.substr(1, nlLoc - 1); - channel = channels.get(channelid); - if (!channel) return; - message = data.substr(nlLoc + 1); - channel.forEach(socket => socket.write(message)); - break; +/** + * Config.golang cache. Checked when spawning new workers to + * ensure that Node and Go workers will not try to run at the same time. + * + * @type {boolean} + */ +let golangCache = !!Config.golang; - case '+': - // +channelid, socketid - // add to channel - nlLoc = data.indexOf('\n'); - socketid = data.substr(nlLoc + 1); - socket = sockets.get(socketid); - if (!socket) return; - channelid = data.substr(1, nlLoc - 1); - channel = channels.get(channelid); - if (!channel) { - channel = new Map(); - channels.set(channelid, channel); +/** + * Spawns a new worker. + * + * @return {WorkerWrapper} + */ +function spawnWorker() { + if (golangCache === !Config.golang) { + // Config settings were changed. Make sure none of the wrong kind of + // worker is already listening. + let workerType = Config.golang ? GoWorker : cluster.Worker; + for (let [workerid, worker] of workers) { + if (worker.isConnected() && !(worker.worker instanceof workerType)) { + let oldType = golangCache ? 'Go' : 'Node'; + let newType = Config.golang ? 'Go' : 'Node'; + throw new Error( + `Sockets: worker of ID ${workerid} is a ${oldType} worker, but config was changed to spawn ${newType} ones! + Set Config.golang back to ${golangCache} or kill all active workers before attempting to spawn more.` + ); } - channel.set(socketid, socket); - break; - - case '-': - // -channelid, socketid - // remove from channel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels.get(channelid); - if (!channel) return; - socketid = data.slice(nlLoc + 1); - channel.delete(socketid); - subchannel = subchannels.get(channelid); - if (subchannel) subchannel.delete(socketid); - if (!channel.size) { - channels.delete(channelid); - if (subchannel) subchannels.delete(channelid); + } + golangCache = !!Config.golang; + } else if (golangCache) { + // Prevent spawning multiple Go child processes by accident. + for (let [workerid, worker] of workers) { // eslint-disable-line no-unused-vars + if (worker.isConnected() && worker.worker instanceof GoWorker) { + throw new Error('Sockets: multiple Go child processes cannot be spawned!'); } - break; + } + } - case '.': - // .channelid, subchannelid, socketid - // move subchannel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - let nlLoc2 = data.indexOf('\n', nlLoc + 1); - subchannelid = data.slice(nlLoc + 1, nlLoc2); - socketid = data.slice(nlLoc2 + 1); - - subchannel = subchannels.get(channelid); - if (!subchannel) { - subchannel = new Map(); - subchannels.set(channelid, subchannel); - } - if (subchannelid === '0') { - subchannel.delete(socketid); - } else { - subchannel.set(socketid, subchannelid); - } - break; + let worker; + if (golangCache) { + worker = new GoWorker(nextWorkerid); + } else { + worker = cluster.fork({ + PSPORT: Config.port, + PSBINDADDR: Config.bindaddress || '0.0.0.0', + PSNOSSL: Config.ssl ? 0 : 1, + }); + } - case ':': - // :channelid, message - // message to subchannel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels.get(channelid); - if (!channel) return; - - let messages = [null, null, null]; - message = data.substr(nlLoc + 1); - subchannel = subchannels.get(channelid); - channel.forEach((socket, socketid) => { - switch (subchannel ? subchannel.get(socketid) : '0') { - case '1': - if (!messages[1]) { - messages[1] = message.replace(/\n\|split\n[^\n]*\n([^\n]*)\n[^\n]*\n[^\n]*/g, '\n$1'); - } - socket.write(messages[1]); - break; - case '2': - if (!messages[2]) { - messages[2] = message.replace(/\n\|split\n[^\n]*\n[^\n]*\n([^\n]*)\n[^\n]*/g, '\n$1'); - } - socket.write(messages[2]); - break; - default: - if (!messages[0]) { - messages[0] = message.replace(/\n\|split\n([^\n]*)\n[^\n]*\n[^\n]*\n[^\n]*/g, '\n$1'); - } - socket.write(messages[0]); - break; - } - }); - break; - } - }); + let wrapper = new WorkerWrapper(worker, nextWorkerid++); + workers.set(wrapper.id, wrapper); + return wrapper; +} - // Clean up any remaining connections on disconnect. If this isn't done, - // the process will not exit until any remaining connections have been destroyed. - // Afterwards, the worker process will die on its own. - process.once('disconnect', () => { - clearInterval(sweepSocketInterval); +/** + * Initializes the configured number of worker processes. + * + * @param {any} port + * @param {any} bindAddress + * @param {any} workerCount + */ +function listen(port, bindAddress, workerCount) { + if (port !== undefined && !isNaN(port)) { + Config.port = port; + Config.ssl = null; + } else { + port = Config.port; + // Autoconfigure the app when running in cloud hosting environments: + try { + const cloudenv = require('cloud-env'); + bindAddress = cloudenv.get('IP', bindAddress); + port = cloudenv.get('PORT', port); + } catch (e) {} + } + if (bindAddress !== undefined) { + Config.bindaddress = bindAddress; + } - sockets.forEach(socket => { - try { - socket.destroy(); - } catch (e) {} - }); - sockets.clear(); - channels.clear(); - subchannels.clear(); + // Go only uses one child process since it does not share FD handles for + // serving like Node.js workers do. Workers are instead used to limit the + // number of concurrent requests that can be handled at once in the child + // process. + if (golangCache) { + spawnWorker(); + return; + } - app.close(); - if (appssl) appssl.close(); + if (workerCount === undefined) { + workerCount = (Config.workers !== undefined ? Config.workers : 1); + } + for (let i = 0; i < workerCount; i++) { + spawnWorker(); + } +} - // Let the server(s) finish closing. - setImmediate(() => process.exit(0)); - }); +/** + * Kills a worker process using the given worker object. + * + * @param {WorkerWrapper} worker + * @return {number} + */ +function killWorker(worker) { + let count = Users.socketDisconnectAll(worker); + console.log(`${count} connections were lost.`); + try { + worker.kill('SIGTERM'); + } catch (e) {} + workers.delete(worker.id); + return count; +} - // this is global so it can be hotpatched if necessary - let isTrustedProxyIp = Dnsbl.checker(Config.proxyip); - let socketCounter = 0; - server.on('connection', socket => { - // For reasons that are not entirely clear, SockJS sometimes triggers - // this event with a null `socket` argument. - if (!socket) return; - - if (!socket.remoteAddress) { - // SockJS sometimes fails to be able to cache the IP, port, and - // address from connection request headers. - try { - socket.destroy(); - } catch (e) {} - return; +/** + * Kills a worker process using the given worker PID. + * + * @param {number} pid + * @return {number | false} + */ +function killPid(pid) { + for (let [workerid, worker] of workers) { // eslint-disable-line no-unused-vars + if (pid === worker.process.pid) { + return killWorker(worker); } + } + return false; +} - let socketid = '' + (++socketCounter); - sockets.set(socketid, socket); - - let socketip = socket.remoteAddress; - if (isTrustedProxyIp(socketip)) { - let ips = (socket.headers['x-forwarded-for'] || '') - .split(',') - .reverse(); - for (let ip of ips) { - let proxy = ip.trim(); - if (!isTrustedProxyIp(proxy)) { - socketip = proxy; - break; - } - } - } +/** + * Sends a message to a socket in a given worker by ID. + * + * @param {WorkerWrapper} worker + * @param {string} socketid + * @param {string} message + */ +function socketSend(worker, socketid, message) { + worker.send(`>${socketid}\n${message}`); +} - process.send(`*${socketid}\n${socketip}\n${socket.protocol}`); +/** + * Forcefully disconnects a socket in a given worker by ID. + * + * @param {WorkerWrapper} worker + * @param {string} socketid + */ +function socketDisconnect(worker, socketid) { + worker.send(`!${socketid}`); +} - socket.on('data', message => { - // drop empty messages (DDoS?) - if (!message) return; - // drop messages over 100KB - if (message.length > (100 * 1024)) { - console.log(`Dropping client message ${message.length / 1024} KB...`); - console.log(message.slice(0, 160)); - return; - } - // drop legacy JSON messages - if (typeof message !== 'string' || message.startsWith('{')) return; - // drop blank messages (DDoS?) - let pipeIndex = message.indexOf('|'); - if (pipeIndex < 0 || pipeIndex === message.length - 1) return; +/** + * Broadcasts a message to all sockets in a given channel across + * all workers. + * + * @param {string} channelid + * @param {string} message + */ +function channelBroadcast(channelid, message) { + workers.forEach(worker => { + worker.send(`#${channelid}\n${message}`); + }); +} - process.send(`<${socketid}\n${message}`); - }); +/** + * Broadcasts a message to all sockets in a given channel and a + * given worker. + * + * @param {WorkerWrapper} worker + * @param {string} channelid + * @param {string} message + */ +function channelSend(worker, channelid, message) { + worker.send(`#${channelid}\n${message}`); +} - socket.once('close', () => { - process.send(`!${socketid}`); - sockets.delete(socketid); - channels.forEach(channel => channel.delete(socketid)); - }); - }); - server.installHandlers(app, {}); - app.listen(Config.port, Config.bindaddress); - console.log(`Worker ${cluster.worker.id} now listening on ${Config.bindaddress}:${Config.port}`); +/** + * Adds a socket to a given channel in a given worker by ID. + * + * @param {WorkerWrapper} worker + * @param {string} channelid + * @param {string} socketid + */ +function channelAdd(worker, channelid, socketid) { + worker.send(`+${channelid}\n${socketid}`); +} - if (appssl) { - server.installHandlers(appssl, {}); - appssl.listen(Config.ssl.port, Config.bindaddress); - console.log(`Worker ${cluster.worker.id} now listening for SSL on port ${Config.ssl.port}`); - } +/** + * Removes a socket from a given channel in a given worker by ID. + * + * @param {WorkerWrapper} worker + * @param {string} channelid + * @param {string} socketid + */ +function channelRemove(worker, channelid, socketid) { + worker.send(`-${channelid}\n${socketid}`); +} - console.log(`Test your server at http://${Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress}:${Config.port}`); +/** + * Broadcasts a message to be demuxed into three separate messages + * across three subchannels in a given channel across all workers. + * + * @param {string} channelid + * @param {string} message + */ +function subchannelBroadcast(channelid, message) { + workers.forEach(worker => { + worker.send(`:${channelid}\n${message}`); + }); +} - require('./repl').start(`sockets-${cluster.worker.id}-${process.pid}`, cmd => eval(cmd)); +/** + * Moves a given socket to a different subchannel in a channel by + * ID in the given worker. + * + * @param {WorkerWrapper} worker + * @param {string} channelid + * @param {string} subchannelid + * @param {string} socketid + */ +function subchannelMove(worker, channelid, subchannelid, socketid) { + worker.send(`.${channelid}\n${subchannelid}\n${socketid}`); } + +module.exports = { + WorkerWrapper, + GoWorker, + + workers, + spawnWorker, + listen, + killWorker, + killPid, + + socketSend, + socketDisconnect, + channelBroadcast, + channelSend, + channelAdd, + channelRemove, + subchannelBroadcast, + subchannelMove, +}; \ No newline at end of file diff --git a/sockets/lib/commands.go b/sockets/lib/commands.go new file mode 100644 index 000000000000..fc48c24c1dbb --- /dev/null +++ b/sockets/lib/commands.go @@ -0,0 +1,87 @@ +/** + * Commands + * https://pokemonshowdown.com/ + * + * Commands are an abstraction over IPC messages sent to and received from the + * parent process. Each message follows a specific syntax: a one character + * token, followed by any number of parametres separated by newlines. Commands + * give the multiplexer and IPC connection a simple way to determine which + * struct it's meant to be handled by, before enqueueing it to be distributed + * to workers to finally process their payload concurrently. + */ + +package sockets + +import "strings" + +// IPC message types +const ( + SOCKET_CONNECT byte = '*' + SOCKET_DISCONNECT byte = '!' + SOCKET_RECEIVE byte = '<' + SOCKET_SEND byte = '>' + CHANNEL_ADD byte = '+' + CHANNEL_REMOVE byte = '-' + CHANNEL_BROADCAST byte = '#' + SUBCHANNEL_MOVE byte = '.' + SUBCHANNEL_BROADCAST byte = ':' +) + +var PARAM_COUNTS = map[byte]int{ + SOCKET_CONNECT: 4, + SOCKET_DISCONNECT: 1, + SOCKET_RECEIVE: 2, + SOCKET_SEND: 2, + CHANNEL_ADD: 2, + CHANNEL_REMOVE: 2, + CHANNEL_BROADCAST: 2, + SUBCHANNEL_MOVE: 3, + SUBCHANNEL_BROADCAST: 2, +} + +type Command struct { + token byte // Token designating the type of command. + params []string // The command parametre list, parsed. + target CommandIO // The target to process this command. +} + +// The multiplexer and the IPC connection both implement this interface. Its +// purpose is solely to allow the two structs to be used in Command. +type CommandIO interface { + Process(*Command) error // Invokes one of its methods using the command's token and parametres. +} + +func NewCommand(msg string, target CommandIO) *Command { + token := msg[0] + count := PARAM_COUNTS[token] + params := strings.SplitN(msg[1:], "\n", count) + return &Command{ + token: token, + params: params, + target: target, + } +} + +func BuildCommand(target CommandIO, token byte, params ...string) *Command { + return &Command{ + token: token, + params: params, + target: target, + } +} + +func (c *Command) Token() byte { + return c.token +} + +func (c *Command) Params() []string { + return c.params +} + +func (c *Command) Message() string { + return string(c.token) + strings.Join(c.params, "\n") +} + +func (c *Command) Process() error { + return c.target.Process(c) +} diff --git a/sockets/lib/commands_test.go b/sockets/lib/commands_test.go new file mode 100644 index 000000000000..cb85b37fb9e3 --- /dev/null +++ b/sockets/lib/commands_test.go @@ -0,0 +1,26 @@ +package sockets + +import "testing" + +type testTarget struct { + CommandIO +} + +func TestCommands(t *testing.T) { + tokens := []byte{ + SOCKET_CONNECT, + SOCKET_DISCONNECT, + SOCKET_RECEIVE, + SOCKET_SEND, + CHANNEL_ADD, + CHANNEL_REMOVE, + CHANNEL_BROADCAST, + SUBCHANNEL_MOVE, + SUBCHANNEL_BROADCAST, + } + + cmds := make([]*Command, len(tokens)) + for i, token := range tokens { + cmds[i] = NewCommand(string(token)+"1\n2\n3\n4", testTarget{}) + } +} diff --git a/sockets/lib/config.go b/sockets/lib/config.go new file mode 100644 index 000000000000..98bb907507ab --- /dev/null +++ b/sockets/lib/config.go @@ -0,0 +1,38 @@ +/** + * Config + * https://pokemonshowdown.com/ + * + * Config is a struct representing the config settings the parent process + * passed to us by stringifying pertinent settings as JSON and assigning it to + * the $PS_CONFIG environment variable. + */ + +package sockets + +import ( + "encoding/json" + "os" +) + +type sslcert struct { + Cert string `json:"cert"` // Path to the SSL certificate. + Key string `json:"key"` // Path to the SSL key. +} + +type sslconf struct { + Port string `json:"port"` // HTTPS server port. + Options sslcert `json:"options,omitempty"` // SSL config settings. +} + +type config struct { + Workers int `json:"workers"` // Number of workers for the master to spawn. + Port string `json:"port"` // HTTP server port. + BindAddress string `json:"bindAddress"` // HTTP/HTTPS server(s) hostname. + SSL sslconf `json:"ssl,omitempty"` // HTTPS config settings. +} + +func NewConfig(envVar string) (c config, err error) { + configEnv := os.Getenv(envVar) + err = json.Unmarshal([]byte(configEnv), &c) + return +} diff --git a/sockets/lib/config_test.go b/sockets/lib/config_test.go new file mode 100644 index 000000000000..34495d055f39 --- /dev/null +++ b/sockets/lib/config_test.go @@ -0,0 +1,36 @@ +package sockets + +import ( + "encoding/json" + "testing" +) + +func TestConfig(t *testing.T) { + var c config + cj := []byte(`{"workers": 1, "port": ":8000", "bindAddress": "0.0.0.0", "ssl": null}`) + err := json.Unmarshal(cj, &c) + if err != nil { + t.Errorf("Sockets: failed to parse config JSON with SSL being null: %v", err) + } + if c.SSL.Port != "" || c.SSL.Options.Cert != "" || c.SSL.Options.Key != "" { + t.Errorf("Sockets: config failed to omit null SSL config") + } + + c.SSL = sslconf{ + Port: ":443", + Options: sslcert{ + Cert: "", + Key: "", + }, + } + + cj, _ = json.Marshal(c) + if err != nil { + t.Errorf("Sockets: failed to stringify config JSON: %v", err) + } + + err = json.Unmarshal(cj, &c) + if err != nil { + t.Errorf("Sockets: failed to parse config JSON containing SSL config") + } +} diff --git a/sockets/lib/ipc.go b/sockets/lib/ipc.go new file mode 100644 index 000000000000..36c123ce1011 --- /dev/null +++ b/sockets/lib/ipc.go @@ -0,0 +1,125 @@ +/** + * IPC - Inter-Process Communication + * https://pokemonshowdown.com/ + * + * This handles all communication between us and the parent process. The parent + * process creates a local TCP server using a random port. The port is passed + * down to us through the $PS_IPC_PORT environment variable. A TCP connection + * to the parent's server is created, allowing us to send messages back and + * forth. + */ + +package sockets + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "os" + "time" +) + +// This must be a byte that stringifies to either a hexadecimal escape code. +// Otherwise, it would be possible for someone to send a message with the +// delimiter and break up messages. +const DELIM byte = '\x03' + +type Connection struct { + addr *net.TCPAddr // Parent process' TCP server address. + conn *net.TCPConn // Connection to the parent process' TCP server. + mux *Multiplexer // Target for commands originating from here. + listening bool // Whether or not this is connected and listening for IPC messages. +} + +func NewConnection(envVar string) (*Connection, error) { + port := os.Getenv(envVar) + addr, err := net.ResolveTCPAddr("tcp", "localhost"+port) + if err != nil { + return nil, fmt.Errorf("Sockets: failed to parse TCP address to connect to the parent process with: %v", err) + } + + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return nil, fmt.Errorf("Sockets: failed to connect to TCP server: %v", err) + } + + c := &Connection{ + addr: addr, + conn: conn, + listening: false, + } + + return c, nil +} + +func (c *Connection) Listening() bool { + return c.listening +} + +func (c *Connection) Listen(mux *Multiplexer) { + if c.listening { + return + } + + c.mux = mux + c.listening = true + + go func() { + reader := bufio.NewReader(c.conn) + for { + token, err := reader.ReadBytes(DELIM) + if len(token) == 0 || err != nil { + continue + } + + var msg string + err = json.Unmarshal(token[:len(token)-1], &msg) + if err != nil { + continue + } + + go func() { + cmd := NewCommand(msg, c.mux) + CmdQueue <- cmd + }() + + time.Sleep(1 * time.Nanosecond) + } + }() +} + +// Final step in evaluating commands targeted at the IPC connection. +func (c *Connection) Process(cmd *Command) error { + // fmt.Printf("Sockets => IPC: %v\n", cmd.Message()) + if !c.listening { + return fmt.Errorf("Sockets: can't process connection commands when the connection isn't listening yet") + } + + msg := cmd.Message() + _, err := c.write(msg) + return err +} + +func (c *Connection) Close() error { + if !c.listening { + return nil + } + + return c.conn.Close() +} + +func (c *Connection) write(msg string) (int, error) { + if !c.listening { + return 0, fmt.Errorf("Sockets: can't write messages over a connection that isn't listening yet...") + } + + data, err := json.Marshal(msg) + if err != nil { + return 0, fmt.Errorf("Sockets: failed to parse upstream IPC message: %v", err) + } + + // The max allowed length for a message that Multiplexer.socketReceive will + // not drop is short enough for us not to need to buffer here. + return c.conn.Write(append(data, DELIM)) +} \ No newline at end of file diff --git a/sockets/lib/ipc_test.go b/sockets/lib/ipc_test.go new file mode 100644 index 000000000000..4981aaddbe9a --- /dev/null +++ b/sockets/lib/ipc_test.go @@ -0,0 +1,38 @@ +package sockets + +import ( + "net" + "os" + "testing" +) + +func TestConnection(t *testing.T) { + port := ":3000" + ln, err := net.Listen("tcp", "localhost"+port) + defer ln.Close() + if err != nil { + t.Errorf("Sockets: failed to launch TCP server on port %v: %v", port, err) + } + + envVar := "PS_IPC_PORT" + err = os.Setenv(envVar, port) + if err != nil { + t.Errorf("Sockets: failed to set %v environment variable: %v", envVar, port) + } + + conn, err := NewConnection(envVar) + defer conn.Close() + if err != nil { + t.Errorf("%v", err) + } + + mux := NewMultiplexer() + mux.Listen(conn) + conn.Listen(mux) + + cmd := BuildCommand(mux, SOCKET_SEND, "0", "|ayy lmao") + err = conn.Process(cmd) + if err != nil { + t.Errorf("%v", err) + } +} diff --git a/sockets/lib/master.go b/sockets/lib/master.go new file mode 100644 index 000000000000..4b5217067f5c --- /dev/null +++ b/sockets/lib/master.go @@ -0,0 +1,82 @@ +/** + * Master - Master/Worker pattern implementation + * https://pokemonshowdown.com/ + * + * This makes it possible to parse messages from sockets and the parent process + * concurrently. A command queue stores commands created by the multiplexer and + * IPC connection. The master contains a pool of command channels belonging to + * workers. Once a command is available in the queue, the master takes a + * worker's command channel from its pool and enqueues it. The worker takes the + * command and processes it before enqueueing its command channel back into the + * master's pool. The workers are distributed round-robin, much like Node's + * cluster module (when not using Windows). + */ + +package sockets + +// A global command channel for the multiplexer and IPC connection to enqueue +// their new commands to be processed by the workers. +var CmdQueue = make(chan *Command) + +type master struct { + wpool chan chan *Command // Pool of worker command queues. + count int // Number of workers. +} + +func NewMaster(count int) *master { + wpool := make(chan chan *Command, count) + return &master{ + wpool: wpool, + count: count, + } +} + +// Create the initial set of workers and make them listen before the master. +func (m *master) Spawn() { + for i := 0; i < m.count; i++ { + w := newWorker(m.wpool) + w.listen() + } +} + +// Listen for new commands to remove from the command queue and pass to the +// first available worker. +func (m *master) Listen() { + for { + cmd := <-CmdQueue + cmdch := <-m.wpool + cmdch <- cmd + } +} + +type worker struct { + wpool chan chan *Command // The master's pool of worker command queues. + cmdch chan *Command // Queue for incoming commands from CmdQueue. + quit chan bool // Channel used to kill the worker when needed. +} + +func newWorker(wpool chan chan *Command) *worker { + cmdch := make(chan *Command) + quit := make(chan bool) + return &worker{ + wpool: wpool, + cmdch: cmdch, + quit: quit, + } +} + +func (w *worker) listen() { + go func() { + for { + w.wpool <- w.cmdch + select { + case cmd := <-w.cmdch: + // Invokes *Multiplexer.Process or *Connection.Process, where + // the command is finally handled and used to update state. + cmd.target.Process(cmd) + case <-w.quit: + return + } + } + }() +} diff --git a/sockets/lib/master_test.go b/sockets/lib/master_test.go new file mode 100644 index 000000000000..7fdcb163ef5b --- /dev/null +++ b/sockets/lib/master_test.go @@ -0,0 +1,85 @@ +package sockets + +import ( + "net" + "net/http" + "strconv" + "testing" + + "github.com/igm/sockjs-go/sockjs" +) + +type testSocket struct { + sockjs.Session +} + +func (ts testSocket) Send(msg string) error { + return nil +} + +func (ts testSocket) Close(code uint32, signal string) error { + return nil +} + +func (ts testSocket) Request() *http.Request { + return &http.Request{} +} + +func TestMasterListen(t *testing.T) { + t.Parallel() + ln, _ := net.Listen("tcp", "localhost:3000") + defer ln.Close() + + conn, _ := NewConnection("PS_IPC_PORT") + defer conn.Close() + + mux := NewMultiplexer() + mux.Listen(conn) + conn.Listen(mux) + + m := NewMaster(4) + m.Spawn() + go m.Listen() + + for i := 0; i < m.count*250; i++ { + id := strconv.Itoa(i) + t.Run("Worker/Multiplexer command #"+id, func(t *testing.T) { + go func(id string, mux *Multiplexer, conn *Connection) { + mux.smux.Lock() + sid := strconv.FormatUint(mux.nsid, 10) + mux.sockets[sid] = testSocket{} + mux.nsid++ + mux.smux.Unlock() + + cmd := BuildCommand(mux, SOCKET_DISCONNECT, sid) + cmd.Process() + if len(CmdQueue) != 0 { + t.Error("Sockets: master failed to pass command struct from worker to multiplexer") + } + + mux.socketRemove(sid, true) + }(id, mux, conn) + }) + t.Run("Worker/Connection command #"+id, func(t *testing.T) { + go func(id string, mux *Multiplexer, conn *Connection) { + mux.smux.Lock() + sid := strconv.FormatUint(mux.nsid, 10) + mux.sockets[sid] = testSocket{} + mux.nsid++ + mux.smux.Unlock() + + cmd := BuildCommand(conn, SOCKET_CONNECT, sid, "0.0.0.0", "", "websocket") + cmd.Process() + if len(CmdQueue) != 0 { + t.Error("Sockets: master failed to pass command struct from worker to connection") + } + + mux.socketRemove(sid, true) + }(id, mux, conn) + }) + } + + for len(m.wpool) > 0 { + <-m.wpool + } +} diff --git a/sockets/lib/multiplexer.go b/sockets/lib/multiplexer.go new file mode 100644 index 000000000000..265ca6b5821f --- /dev/null +++ b/sockets/lib/multiplexer.go @@ -0,0 +1,371 @@ +/** + * Multiplexer - Socket/Channel/Subchannel state machine + * https://pokemonshowdown.com/ + * + * This keeps track of the sockets that connect to the SockJS server. Sockets + * are stored in the multiplexer to allow the parent process to manipulate them + * as it pleases. Channels represent rooms in the parent process; subchannels + * split battle rooms into three groups: side 1, side 2, and spectators. + * Certain messages will display differently depending on which subchannel the + * user's socket is in. + */ + +package sockets + +import ( + "fmt" + "net" + "path" + "regexp" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/igm/sockjs-go/sockjs" +) + +// Subchannel IDs +const ( + DEFAULT_SUBCHANNEL_ID byte = '0' + P1_SUBCHANNEL_ID byte = '1' + P2_SUBCHANNEL_ID byte = '2' +) + +// Map of socket IDs to subchannel IDs. +type Channel map[string]byte + +type Multiplexer struct { + nsid uint64 // Socket ID counter. + sockets map[string]sockjs.Session // Map of socket IDs to sockets. + smux sync.RWMutex // nsid and sockets mutex. + channels map[string]Channel // Map of channel (i.e. room) IDs to channels. + cmux sync.RWMutex // channels mutex. + scre *regexp.Regexp // Regex for splitting subchannel broadcasts into their three messages. + conn *Connection // Target for commands originating from here. +} + +func NewMultiplexer() *Multiplexer { + sockets := make(map[string]sockjs.Session) + channels := make(map[string]Channel) + scre := regexp.MustCompile(`\|split\n([^\n]*)\n([^\n]*)\n([^\n]*)\n[^\n]*`) + return &Multiplexer{ + sockets: sockets, + channels: channels, + scre: scre, + } +} + +func (m *Multiplexer) Listen(conn *Connection) { + m.conn = conn +} + +func (m *Multiplexer) Process(cmd *Command) (err error) { + // fmt.Printf("IPC => Sockets: %v\n", cmd.Message()) + params := cmd.Params() + + // Parse the command's params and call the appropriate method. + switch token := cmd.Token(); token { + case SOCKET_DISCONNECT: + sid := params[0] + err = m.socketRemove(sid, true) + case SOCKET_SEND: + sid := params[0] + msg := params[1] + err = m.socketSend(sid, msg) + case SOCKET_RECEIVE: + sid := params[0] + msg := params[1] + err = m.socketReceive(sid, msg) + case CHANNEL_ADD: + cid := params[0] + sid := params[1] + err = m.channelAdd(cid, sid) + case CHANNEL_REMOVE: + cid := params[0] + sid := params[1] + err = m.channelRemove(cid, sid) + case CHANNEL_BROADCAST: + cid := params[0] + msg := params[1] + err = m.channelBroadcast(cid, msg) + case SUBCHANNEL_MOVE: + cid := params[0] + scid := params[1][0] + sid := params[2] + err = m.subchannelMove(cid, scid, sid) + case SUBCHANNEL_BROADCAST: + cid := params[0] + msg := params[1] + err = m.subchannelBroadcast(cid, msg) + default: + err = fmt.Errorf("Sockets: received unknown message of type %v: %v", cmd.Token(), cmd.Message()) + } + + if err != nil { + // Something went wrong somewhere, but it's likely a timing issue from + // the parent process. Let's just log the error instead of crashing. + fmt.Printf("%v\n", err) + } + + return +} + +func (m *Multiplexer) socketAdd(s sockjs.Session) (sid string) { + nsid := atomic.LoadUint64(&m.nsid) + sid = strconv.FormatUint(nsid, 10) + atomic.AddUint64(&m.nsid, 1) + + m.smux.Lock() + m.sockets[sid] = s + m.smux.Unlock() + + if m.conn.Listening() { + req := s.Request() + ip, _, _ := net.SplitHostPort(req.RemoteAddr) + ips := req.Header.Get("X-Forwarded-For") + protocol := path.Base(req.URL.Path) + + go func() { + cmd := BuildCommand(m.conn, SOCKET_CONNECT, sid, ip, ips, protocol) + CmdQueue <- cmd + }() + + time.Sleep(1 * time.Nanosecond) + } + + return +} + +func (m *Multiplexer) socketRemove(sid string, forced bool) error { + m.cmux.Lock() + for cid, c := range m.channels { + if _, ok := c[sid]; ok { + delete(c, sid) + if len(c) == 0 { + delete((*m).channels, cid) + } + } + } + m.cmux.Unlock() + + m.smux.Lock() + defer m.smux.Unlock() + + s, ok := m.sockets[sid] + if ok { + delete((*m).sockets, sid) + } else { + return fmt.Errorf("Sockets: attempted to remove non-existent socket of ID %v", sid) + } + + if forced { + s.Close(1000, "Normal closure") + } else { + // User-initiated disconnect. Poke the parent process to clean up. + if m.conn.Listening() { + go func() { + cmd := BuildCommand(m.conn, SOCKET_DISCONNECT, sid) + CmdQueue <- cmd + }() + + time.Sleep(1 * time.Nanosecond) + } + } + + return nil +} + +func (m *Multiplexer) socketReceive(sid string, msg string) error { + m.smux.RLock() + defer m.smux.RUnlock() + + if _, ok := m.sockets[sid]; ok { + // Drop empty messages (DDOS?). + if len(msg) == 0 { + return nil + } + + // Drop >100KB messages. + if len(msg) > 100*1024 { + fmt.Printf("Dropping client message %vKB...\n%v\n", len(msg)/1024, msg[:160]) + return nil + } + + // Drop legacy JSON messages. + if strings.HasPrefix(msg, "{") { + return nil + } + + // Drop invalid messages (again, DDOS?). + if strings.HasSuffix(msg, "|") || !strings.Contains(msg, "|") { + return nil + } + + if m.conn.Listening() { + go func() { + cmd := BuildCommand(m.conn, SOCKET_RECEIVE, sid, msg) + CmdQueue <- cmd + }() + + time.Sleep(1 * time.Nanosecond) + } + + return nil + } + + // This should never happen. If it does, it's likely a SockJS bug. + return fmt.Errorf("Sockets: received message for a non-existent socket of ID %v: %v", sid, msg) +} + +func (m *Multiplexer) socketSend(sid string, msg string) error { + m.smux.RLock() + defer m.smux.RUnlock() + + if s, ok := m.sockets[sid]; ok { + s.Send(msg) + return nil + } + + return fmt.Errorf("Sockets: attempted to send to non-existent socket of ID %v: %v", sid, msg) +} + +func (m *Multiplexer) channelAdd(cid string, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + c = make(Channel) + m.channels[cid] = c + } + + c[sid] = DEFAULT_SUBCHANNEL_ID + + return nil +} + +func (m *Multiplexer) channelRemove(cid string, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if ok { + if _, ok = c[sid]; !ok { + return fmt.Errorf("Sockets: failed to remove non-existent socket of ID %v from channel %v", sid, cid) + } + } else { + // This happens on user-initiated disconnect. Mitigate until this race + // condition is fixed. + return nil + } + + delete(c, sid) + if len(c) == 0 { + delete((*m).channels, cid) + } + + return nil +} + +func (m *Multiplexer) channelBroadcast(cid string, msg string) error { + m.cmux.RLock() + defer m.cmux.RUnlock() + + c, ok := m.channels[cid] + if !ok { + // This happens occasionally when the sole user in a room leaves. + // Mitigate until this race condition is fixed. + return nil + } + + m.smux.RLock() + defer m.smux.RUnlock() + + for sid := range c { + var s sockjs.Session + if s, ok = m.sockets[sid]; ok { + s.Send(msg) + } else { + return fmt.Errorf("Sockets: attempted to broadcast to non-existent socket of ID %v in channel %v: %v", sid, cid, msg) + } + } + + return nil +} + +func (m *Multiplexer) subchannelMove(cid string, scid byte, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + return fmt.Errorf("Sockets: attempted to move socket of ID %v in non-existent channel %v to subchannel %v", sid, cid, scid) + } + + c[sid] = scid + return nil +} + +func (m *Multiplexer) subchannelBroadcast(cid string, msg string) error { + m.cmux.RLock() + defer m.cmux.RUnlock() + + c, ok := m.channels[cid] + if !ok { + return fmt.Errorf("Sockets: attempted to broadcast to subchannels in channel %v, which doesn't exist: %v", cid, msg) + } + + m.smux.RLock() + defer m.smux.RUnlock() + + msgs := make(map[byte]string) + for sid, scid := range c { + s, ok := m.sockets[sid] + if !ok { + return fmt.Errorf("Sockets: attempted to broadcast to subchannels in channel %v, but socket of ID %v doesn't exist: %v", cid, sid, msg) + } + + if _, ok := msgs[scid]; !ok { + switch scid { + case DEFAULT_SUBCHANNEL_ID: + msgs[scid] = m.scre.ReplaceAllString(msg, "$1") + case P1_SUBCHANNEL_ID: + msgs[scid] = m.scre.ReplaceAllString(msg, "$2") + case P2_SUBCHANNEL_ID: + msgs[scid] = m.scre.ReplaceAllString(msg, "$3") + } + } + + s.Send(msgs[scid]) + } + + return nil +} + +// This is the HTTP handler for the SockJS server. This is where new sockets +// arrive for us to use. +func (m *Multiplexer) Handler(s sockjs.Session) { + sid := m.socketAdd(s) + for { + msg, err := s.Recv() + if err != nil { + if err == sockjs.ErrSessionNotOpen { + // User disconnected. + } else { + fmt.Printf("Sockets: SockJS error on message receive for socket of ID %v: %v\n", sid, err) + } + break + } + + if err = m.socketReceive(sid, msg); err != nil { + fmt.Printf("%v\n", err) + break + } + } + + if err := m.socketRemove(sid, false); err != nil { + fmt.Printf("%v\n", err) + } +} diff --git a/sockets/lib/multiplexer_test.go b/sockets/lib/multiplexer_test.go new file mode 100644 index 000000000000..66ba1558cdb6 --- /dev/null +++ b/sockets/lib/multiplexer_test.go @@ -0,0 +1,131 @@ +package sockets + +import ( + "fmt" + "net" + "testing" +) + +func TestMultiplexer(t *testing.T) { + port := ":3000" + ln, _ := net.Listen("tcp", "localhost"+port) + defer ln.Close() + + conn, _ := NewConnection("PS_IPC_PORT") + defer conn.Close() + mux := NewMultiplexer() + mux.Listen(conn) + // Do not make the connection listen. + + ts := testSocket{} + t.Run("socketAdd", func(t *testing.T) { + sid := mux.socketAdd(ts) + if len(mux.sockets) != 1 { + t.Errorf("Sockets: expected sockets length to be %v, but is actually %v", 1, len(mux.sockets)) + } + delete((*mux).sockets, sid) + }) + t.Run("socketRemove", func(t *testing.T) { + sid := mux.socketAdd(ts) + if err := mux.socketRemove(sid, true); err != nil { + t.Errorf("%v", err) + } + if len(mux.sockets) != 0 { + t.Errorf("Sockets: expected sockets length to be %v, but is actually %v", 0, len(mux.sockets)) + } + if err := mux.socketRemove(sid, true); err == nil { + t.Errorf("Sockets: did not remove socket of ID %v on socket remove", sid) + } + + sid = mux.socketAdd(ts) + if err := mux.socketRemove(sid, false); err != nil { + t.Errorf("%v", err) + } + if err := mux.socketRemove(sid, false); err == nil { + t.Errorf("Sockets: did not remove socket of ID %v on socket remove", sid) + } + }) + t.Run("socketSend", func(t *testing.T) { + sid := mux.socketAdd(ts) + if err := mux.socketSend(sid, ">global\n|ayy lmao"); err != nil { + t.Errorf("%v", err) + } + mux.socketRemove(sid, true) + }) + t.Run("channelAdd", func(t *testing.T) { + sid := mux.socketAdd(ts) + cid := "global" + if err := mux.channelAdd(cid, sid); err != nil { + t.Errorf("%v", err) + } + if len(mux.channels) != 1 { + t.Errorf("Sockets: expected channels length to be %v, but is actually %v", 1, len(mux.channels)) + } + if err := mux.channelAdd(cid, sid); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove(cid, sid) + mux.socketRemove(sid, true) + }) + t.Run("channelRemove", func(t *testing.T) { + sid := mux.socketAdd(ts) + cid := "global" + mux.channelAdd(cid, sid) + if err := mux.channelRemove(cid, sid); err != nil { + t.Errorf("%v", err) + } + if len(mux.channels) != 0 { + t.Errorf("Sockets: expected channels length to be %v, but is actually %v", 0, len(mux.channels)) + } + if err := mux.channelRemove(cid, sid); err != nil { + t.Errorf("%v", err) + } + mux.socketRemove(sid, true) + }) + t.Run("channelBroadcast", func(t *testing.T) { + sid := mux.socketAdd(ts) + cid := "global" + mux.channelAdd(cid, sid) + if err := mux.channelBroadcast(cid, "|raw|ayy lmao"); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove(cid, sid) + if err := mux.channelBroadcast(cid, "|raw|ayy lmao"); err != nil { + t.Errorf("%v", err) + } + mux.socketRemove(sid, true) + }) + t.Run("subchannelMove", func(t *testing.T) { + sid := mux.socketAdd(ts) + cid := "global" + mux.channelAdd(cid, sid) + if err := mux.subchannelMove(cid, P1_SUBCHANNEL_ID, sid); err != nil { + t.Errorf("%v", err) + } + if scid := mux.channels[cid][sid]; scid != P1_SUBCHANNEL_ID { + t.Errorf("Sockets: expected subchannel for socket of ID %v in channel %v to be %v, but is actually %v", sid, cid, P1_SUBCHANNEL_ID, scid) + } + mux.channelRemove(cid, sid) + mux.socketRemove(sid, true) + }) + t.Run("subchannelBroadcast", func(t *testing.T) { + msg := "|split\n0\n1\n2\n|\n|split\n3\n4\n5\n|" + scids := []byte{DEFAULT_SUBCHANNEL_ID, P1_SUBCHANNEL_ID, P2_SUBCHANNEL_ID} + for idx, scid := range scids { + amsg := mux.scre.ReplaceAllString(msg, fmt.Sprintf("$%v", idx+1)) + if emsg := fmt.Sprintf("%v\n%v", idx, idx+3); emsg != amsg { + t.Errorf("Sockets: expected broadcast to subchannel of ID %v to be %v, but is actually %v", string(scid), emsg, amsg) + } + } + + sid := mux.socketAdd(ts) + cid := "global" + mux.channelAdd(cid, sid) + mux.subchannelMove(cid, P1_SUBCHANNEL_ID, sid) + if err := mux.subchannelBroadcast(cid, msg); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove(cid, sid) + mux.socketRemove(sid, true) + }) +} diff --git a/sockets/main.go b/sockets/main.go new file mode 100644 index 000000000000..63a64de1f8f9 --- /dev/null +++ b/sockets/main.go @@ -0,0 +1,125 @@ +package main + +import ( + "crypto/tls" + "io/ioutil" + "log" + "net" + "net/http" + "path/filepath" + + "github.com/Zarel/Pokemon-Showdown/sockets/lib" + + "github.com/gorilla/mux" + "github.com/igm/sockjs-go/sockjs" +) + +func main() { + // Parse our config settings passed through the $PS_CONFIG environment + // variable by the parent process. + config, err := sockets.NewConfig("PS_CONFIG") + if err != nil { + log.Fatalf("Sockets: failed to read parent's config settings from environment: %v") + } + + // Instantiate the socket multiplexer and IPC struct. + smux := sockets.NewMultiplexer() + conn, err := sockets.NewConnection("PS_IPC_PORT") + if err != nil { + log.Fatalf("%v", err) + } + defer conn.Close() + + // Begin listening for incoming messages from sockets and the TCP + // connection to the parent process. For now, they'll just get enqueued + // for workers to manage later. + smux.Listen(conn) + conn.Listen(smux) + + // Set up routing. + r := mux.NewRouter() + + opts := sockjs.Options{ + SockJSURL: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", + Websocket: true, + HeartbeatDelay: sockjs.DefaultOptions.HeartbeatDelay, + DisconnectDelay: sockjs.DefaultOptions.DisconnectDelay, + JSessionID: sockjs.DefaultOptions.JSessionID, + } + + r.PathPrefix("/showdown"). + Handler(sockjs.NewHandler("/showdown", opts, smux.Handler)) + + customCssDir, _ := filepath.Abs("./config") + r.Handle("/custom.css", http.FileServer(http.Dir(customCssDir))) + + avatarDir, _ := filepath.Abs("./config/avatars") + r.PathPrefix("/avatars/"). + Handler(http.StripPrefix("/avatars/", http.FileServer(http.Dir(avatarDir)))) + + indexPath, _ := filepath.Abs("./static/index.html") + r.PathPrefix("/{roomid:[A-Za-z0-9][A-Za-z0-9-]*}"). + HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, indexPath) + }) + + notFoundPath, _ := filepath.Abs("./static/404.html") + notFoundPage, _ := ioutil.ReadFile(notFoundPath) + r.NotFoundHandler = + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write(notFoundPage) + }) + + staticDir, _ := filepath.Abs("./static") + r.Handle("/", http.FileServer(http.Dir(staticDir))) + + // Begin serving over HTTP. + go func(ba string, port string) { + addr, err := net.ResolveTCPAddr("tcp4", ba+port) + if err != nil { + log.Fatalf("Sockets: failed to resolve the TCP address of the parent's server: %v", err) + } + + ln, err := net.ListenTCP("tcp4", addr) + defer ln.Close() + if err != nil { + log.Fatalf("Sockets: failed to listen over HTTP: %v", err) + } + + err = http.Serve(ln, r) + log.Fatalf("Sockets: HTTP server failed: %v", err) + }(config.BindAddress, config.Port) + + // Begin serving over HTTPS if configured to do so. + if config.SSL.Options.Cert != "" && config.SSL.Options.Key != "" { + go func(ba string, port string, cert string, key string) { + certs, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + log.Fatalf("Sockets: failed to load certificate and key files for TLS: %v", err) + } + + srv := &http.Server{ + Handler: r, + Addr: ba + port, + TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, + } + + var ln net.Listener + ln, err = tls.Listen("tcp4", srv.Addr, srv.TLSConfig) + if err != nil { + log.Fatalf("Sockets: failed to listen over HTTPS: %v", err) + } + + defer ln.Close() + err = http.Serve(ln, r) + log.Fatalf("Sockets: HTTPS server failed: %v", err) + }(config.BindAddress, config.SSL.Port, config.SSL.Options.Cert, config.SSL.Options.Key) + } + + // Finally, spawn workers.to pipe messages received at the multiplexer or + // IPC connection to each other concurrently. + master := sockets.NewMaster(config.Workers) + master.Spawn() + master.Listen() +} diff --git a/test/application/sockets.js b/test/application/sockets.js index 7fc963ebe9bf..7c3182e55cbd 100644 --- a/test/application/sockets.js +++ b/test/application/sockets.js @@ -1,212 +1,96 @@ 'use strict'; const assert = require('assert'); -const cluster = require('cluster'); -describe.skip('Sockets', function () { - const spawnWorker = () => ( - new Promise(resolve => { - let worker = Sockets.spawnWorker(); - worker.removeAllListeners('message'); - resolve(worker); - }) - ); +let sockets; +describe('Sockets workers', function () { before(function () { - cluster.settings.silent = true; - cluster.removeAllListeners('disconnect'); + sockets = require('../../sockets-workers'); + + this.mux = new sockets.Multiplexer(); + clearInterval(this.mux.cleanupInterval); + this.mux.cleanupInterval = null; + + this.socket = require('../../dev-tools/sockets').createSocket(); }); afterEach(function () { - Sockets.workers.forEach((worker, workerid) => { - worker.kill(); - Sockets.workers.delete(workerid); - }); + this.mux.socketCounter = 0; + this.mux.sockets.clear(); + this.mux.channels.clear(); + }); + + after(function () { + this.mux.tryDestroySocket(this.socket); + this.socket = null; + this.mux = null; + }); + + it('should parse more than two params', function () { + let params = '1\n1\n0\n'; + let ret = this.mux.parseParams(params, 4); + assert.deepStrictEqual(ret, ['1', '1', '0', '']); + }); + + it('should parse params with multiple newlines', function () { + let params = '0\n|1\n|2'; + let ret = this.mux.parseParams(params, 2); + assert.deepStrictEqual(ret, ['0', '|1\n|2']); + }); + + it('should add sockets on connect', function () { + let res = this.mux.onSocketConnect(this.socket); + assert.ok(res); + }); + + it('should remove sockets on disconnect', function () { + this.mux.onSocketConnect(this.socket); + let res = this.mux.onSocketDisconnect('0', this.socket); + assert.ok(res); + }); + + it('should add sockets to channels', function () { + this.mux.onSocketConnect(this.socket); + let res = this.mux.onChannelAdd('global', '0'); + assert.ok(res); + res = this.mux.onChannelAdd('global', '0'); + assert.ok(!res); + this.mux.channels.set('lobby', new Map()); + res = this.mux.onChannelAdd('lobby', '0'); + assert.ok(res); + }); + + it('should remove sockets from channels', function () { + this.mux.onSocketConnect(this.socket); + this.mux.onChannelAdd('global', '0'); + let res = this.mux.onChannelRemove('global', '0'); + assert.ok(res); + res = this.mux.onChannelRemove('global', '0'); + assert.ok(!res); }); - describe('master', function () { - it('should be able to spawn workers', function () { - Sockets.spawnWorker(); - assert.strictEqual(Sockets.workers.size, 1); - }); - - it('should be able to spawn workers on listen', function () { - Sockets.listen(0, '127.0.0.1', 1); - assert.strictEqual(Sockets.workers.size, 1); - }); - - it('should be able to kill workers', function () { - return spawnWorker().then(worker => { - Sockets.killWorker(worker); - assert.strictEqual(Sockets.workers.size, 0); - }); - }); - - it('should be able to kill workers by PID', function () { - return spawnWorker().then(worker => { - Sockets.killPid(worker.process.pid); - assert.strictEqual(Sockets.workers.size, 0); - }); - }); + it('should move sockets to subchannels', function () { + this.mux.onSocketConnect(this.socket); + this.mux.onChannelAdd('global', '0'); + let res = this.mux.onSubchannelMove('global', '1', '0'); + assert.ok(res); }); - describe('workers', function () { - // This composes a sequence of HOFs that send a message to a worker, - // wait for its response, then return the worker for the next function - // to use. - const chain = (eventHandler, msg) => worker => { - worker.once('message', eventHandler(worker)); - msg = msg || `$ - const {Session} = require('sockjs/lib/transport'); - const socket = new Session('aaaaaaaa', server); - socket.remoteAddress = '127.0.0.1'; - if (!('headers' in socket)) socket.headers = {}; - socket.headers['x-forwarded-for'] = ''; - socket.protocol = 'websocket'; - socket.write = msg => process.send(msg); - server.emit('connection', socket);`; - worker.send(msg); - return worker; - }; - - const spawnSocket = eventHandler => spawnWorker().then(chain(eventHandler)); - - it('should allow sockets to connect', function () { - return spawnSocket(worker => data => { - let cmd = data.charAt(0); - let [sid, ip, protocol] = data.substr(1).split('\n'); - assert.strictEqual(cmd, '*'); - assert.strictEqual(sid, '1'); - assert.strictEqual(ip, '127.0.0.1'); - assert.strictEqual(protocol, 'websocket'); - }); - }); - - it('should allow sockets to disconnect', function () { - let querySocket; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - querySocket = `$ - let socket = sockets.get(${sid}); - process.send(!socket);`; - Sockets.socketDisconnect(worker, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySocket)); - }); - - it('should allow sockets to send messages', function () { - let msg = 'ayy lmao'; - let socketSend; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - socketSend = `>${sid}\n${msg}`; - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, socketSend)); - }); - - it('should allow sockets to receive messages', function () { - let sid; - let msg; - let mockReceive; - return spawnSocket(worker => data => { - sid = data.substr(1, data.indexOf('\n')); - msg = '|/cmd rooms'; - mockReceive = `$ - let socket = sockets.get(${sid}); - socket.emit('data', ${msg});`; - }).then(chain(worker => data => { - let cmd = data.charAt(0); - let params = data.substr(1).split('\n'); - assert.strictEqual(cmd, '<'); - assert.strictEqual(sid, params[0]); - assert.strictEqual(msg, params[1]); - }, mockReceive)); - }); - - it('should create a channel for the first socket to get added to it', function () { - let queryChannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'global'; - queryChannel = `$ - let channel = channels.get(${cid}); - process.send(channel && channel.has(${sid}));`; - Sockets.channelAdd(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, queryChannel)); - }); - - it('should remove a channel if the last socket gets removed from it', function () { - let queryChannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'global'; - queryChannel = `$ - process.send(!sockets.has(${sid}) && !channels.has(${cid}));`; - Sockets.channelAdd(worker, cid, sid); - Sockets.channelRemove(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, queryChannel)); - }); - - it('should send to all sockets in a channel', function () { - let msg = 'ayy lmao'; - let cid = 'global'; - let channelSend = `#${cid}\n${msg}`; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - Sockets.channelAdd(worker, cid, sid); - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, channelSend)); - }); - - it('should create a subchannel when moving a socket to it', function () { - let querySubchannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'battle-ou-1'; - let scid = '1'; - querySubchannel = `$ - let subchannel = subchannels[${cid}]; - process.send(!!subchannel && (subchannel.get(${sid}) === ${scid}));`; - Sockets.subchannelMove(worker, cid, scid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySubchannel)); - }); - - it('should remove a subchannel when removing its last socket', function () { - let querySubchannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'battle-ou-1'; - let scid = '1'; - querySubchannel = `$ - let subchannel = subchannels.get(${cid}); - process.send(!!subchannel && (subchannel.get(${sid}) === ${scid}));`; - Sockets.subchannelMove(worker, cid, scid, sid); - Sockets.channelRemove(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySubchannel)); - }); - - it('should send to sockets in a subchannel', function () { - let cid = 'battle-ou-1'; - let msg = 'ayy lmao'; - let subchannelSend = `.${cid}\n\n|split\n\n${msg}\n\n`; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let scid = '1'; - Sockets.subchannelMove(worker, cid, scid, sid); - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, subchannelSend)); - }); + it('should broadcast to subchannels', function () { + let messages = '|split\n0\n1\n2\n|\n|split\n3\n4\n5\n|'; + for (let i = 0; i < 3; i++) { + let message = messages.replace(sockets.SUBCHANNEL_MESSAGE_REGEX, `$${i + 1}`); + assert.strictEqual(message, `${i}\n${i + 3}`); + } + + this.mux.onSocketConnect(this.socket); + this.mux.onChannelAdd('global', '0'); + this.mux.onSubchannelMove('global', '1', '0'); + let res = this.mux.onSubchannelBroadcast('global', messages); + assert.ok(res); + this.mux.onChannelRemove('global', '0'); + res = this.mux.onSubchannelBroadcast('global', messages); + assert.ok(!res); }); }); diff --git a/tsconfig.json b/tsconfig.json index c18e73cfacaf..885538cca7c2 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -15,11 +15,13 @@ "./sim/prng.js", "./crashlogger.js", "./dnsbl.js", + "./fs.js", "./ladders-matchmaker.js", "./monitor.js", - "./repl.js", - "./fs.js", "./process-manager.js", + "./repl.js", + "./sockets.js", + "./sockets-workers.js", "./verifier.js" ] } diff --git a/users.js b/users.js index 61a886c57cb0..5114cc4c02ae 100644 --- a/users.js +++ b/users.js @@ -1545,23 +1545,32 @@ Users.pruneInactiveTimer = setInterval(() => { * Routing *********************************************************/ +/** + * Creates a user and connection object for a new socket and sends a challenge + * string to the user for authentication. + * @param {WorkerWrapper} worker + * @param {number} workerid + * @param {string} socketid + * @param {string} ip + * @param {string} protocol + */ Users.socketConnect = function (worker, workerid, socketid, ip, protocol) { - let id = '' + workerid + '-' + socketid; + let id = `${workerid}-${socketid}`; let connection = new Connection(id, worker, socketid, null, ip, protocol); connections.set(id, connection); let banned = Punishments.checkIpBanned(connection); - if (banned) { - return connection.destroy(); - } + if (banned) return connection.destroy(); + // Emergency mode connections logging if (Config.emergency) { - FS('logs/cons.emergency.log').append('[' + ip + ']\n'); + FS('logs/cons.emergency.log').append(`[${ip}]\n`); } let user = new User(connection); connection.user = user; Punishments.checkIp(user, connection); + // Generate 1024-bit challenge string. require('crypto').randomBytes(128, (err, buffer) => { if (err) { @@ -1582,17 +1591,29 @@ Users.socketConnect = function (worker, workerid, socketid, ip, protocol) { user.joinRoom('global', connection); }; +/** + * Forcefully disconnects a socket. + * @param {WorkerWrapper} worker + * @param {number} workerid + * @param {string} socketid + */ Users.socketDisconnect = function (worker, workerid, socketid) { - let id = '' + workerid + '-' + socketid; - + let id = `${workerid}-${socketid}`; let connection = connections.get(id); if (!connection) return; + connection.onDisconnect(); }; -Users.socketReceive = function (worker, workerid, socketid, message) { - let id = '' + workerid + '-' + socketid; - +/** + * Parses a chat message received by a socket. + * @param {WorkerWrapper} worker + * @param {number} workerid + * @param {string} socketid + * @param {string} data + */ +Users.socketReceive = function (worker, workerid, socketid, data) { + let id = `${workerid}-${socketid}`; let connection = connections.get(id); if (!connection) return; @@ -1601,36 +1622,31 @@ Users.socketReceive = function (worker, workerid, socketid, message) { // `data` event. To prevent this, we log exceptions and prevent them // from propagating out of this function. - // drop legacy JSON messages - if (message.charAt(0) === '{') return; - - // drop invalid messages without a pipe character - let pipeIndex = message.indexOf('|'); - if (pipeIndex < 0) return; - - const user = connection.user; + let {user} = connection; if (!user) return; // The client obviates the room id when sending messages to Lobby by default - const roomId = message.substr(0, pipeIndex) || (Rooms.lobby || Rooms.global).id; - message = message.slice(pipeIndex + 1); - - const room = Rooms(roomId); + let pipeIndex = data.indexOf('|'); + let roomid = data.substr(0, pipeIndex) || (Rooms.lobby || Rooms.global).id; + let message = data.slice(pipeIndex + 1); + let room = Rooms(roomid); if (!room) return; + if (Chat.multiLinePattern.test(message)) { user.chat(message, room, connection); return; } - const lines = message.split('\n'); + let lines = message.split('\n'); if (!lines[lines.length - 1]) lines.pop(); if (lines.length > (user.isStaff ? THROTTLE_MULTILINE_WARN_STAFF : THROTTLE_MULTILINE_WARN)) { connection.popup(`You're sending too many lines at once. Try using a paste service like [[Pastebin]].`); return; } + // Emergency logging if (Config.emergency) { - FS('logs/emergency.log').append(`[${user} (${connection.ip})] ${roomId}|${message}\n`); + FS('logs/emergency.log').append(`[${user} (${connection.ip})] ${data}\n`); } let startTime = Date.now(); @@ -1639,6 +1655,23 @@ Users.socketReceive = function (worker, workerid, socketid, message) { } let deltaTime = Date.now() - startTime; if (deltaTime > 1000) { - Monitor.warn(`[slow] ${deltaTime}ms - ${user.name} <${connection.ip}>: ${roomId}|${message}`); + Monitor.warn(`[slow] ${deltaTime}ms - ${user.name} <${connection.ip}>: ${data}`); } }; + +/** + * Clears all connections whose sockets were contained by a + * worker. Called after a worker's process crashes or gets killed. + * @param {WorkerWrapper} worker + * @return {number} + */ +Users.socketDisconnectAll = function (worker) { + let count = 0; + connections.forEach(connection => { + if (connection.worker === worker) { + Users.socketDisconnect(worker, worker.id, connection.socketid); + count++; + } + }); + return count; +};