diff --git a/config/config-example.js b/config/config-example.js index 54080450f5bb1..faa93699cb46f 100644 --- a/config/config-example.js +++ b/config/config-example.js @@ -10,6 +10,13 @@ exports.port = 8000; // know what you are doing. exports.proxyip = false; +// Go language - whether or not to use Go instead of Node.js to host the static +// and SockJS servers. Go is more likely to be more performant than Node.js for +// this purpose, but this should be kept set to false unless you're capable of +// debugging any issues that may arise due to the additional complexity of +// the code needed for this to run. +exports.golang = false; + // Pokemon of the Day - put a pokemon's name here to make it Pokemon of the Day // The PotD will always be in the #2 slot (not #1 so it won't be a lead) // in every Random Battle team. diff --git a/dev-tools/sockets.js b/dev-tools/sockets.js new file mode 100644 index 0000000000000..a4e79d77e12c7 --- /dev/null +++ b/dev-tools/sockets.js @@ -0,0 +1,35 @@ +'use strict'; + +const {Session, SockJSConnection} = require('sockjs/lib/transport'); + +const chars = 'abcdefghijklmnopqrstuvwxyz1234567890-'; +let sessionidCount = 0; + +/** + * @return string + */ +function generateSessionid() { + let ret = ''; + let idx = sessionidCount; + for (let i = 0; i < 8; i++) { + ret = chars[idx % chars.length] + ret; + idx = idx / chars.length | 0; + } + sessionidCount++; + return ret; +} + +/** + * @param {string} sessionid + * @param {{options: {{}}} config + * @return SockJSConnection + */ +exports.createSocket = function (sessionid = generateSessionid(), config = {options: {}}) { + let session = new Session(sessionid, config); + let socket = new SockJSConnection(session); + socket.remoteAddress = '127.0.0.1'; + socket.protocol = 'websocket'; + return socket; +}; + +// TODO: move worker mocks here, use require('../sockets-workers').Multiplexer to stub IPC diff --git a/pokemon-showdown b/pokemon-showdown index 553b327e0c5d5..22d77cec4d28d 100755 --- a/pokemon-showdown +++ b/pokemon-showdown @@ -2,6 +2,8 @@ 'use strict'; const child_process = require('child_process'); +const fs = require('fs'); +const path = require('path'); // Make sure we're Node 6+ @@ -22,6 +24,78 @@ try { child_process.execSync('npm install --production', {stdio: 'inherit'}); } +// Check if the server is configured to use Go, and ensure the required +// environment variables and dependencies are available if that is the case + +let config; +try { + config = require('./config/config'); +} catch (e) {} + +if (config && config.golang) { + if (!process.env.GOPATH) { + console.log('The GOPATH environment variable is not set! It is required in order to run the server using Go.'); + process.exit(0); + } + if (!process.env.GOROOT) { + console.log('The GOROOT environment variable is not set! It is required in order to run the server using Go.'); + process.exit(0); + } + + const dependencies = ['github.com/gorilla/mux', 'github.com/igm/sockjs-go/sockjs']; + let packages = child_process.execSync('go list all', {stdio: null, encoding: 'utf8'}); + for (let dep of dependencies) { + if (!packages.includes(dep)) { + console.log(`Dependency ${dep} is not installed. Fetching...`); + child_process.execSync(`go get ${dep}`, {stdio: 'inherit'}); + } + } + + const {GOPATH} = process.env; + let stat; + let needsSrcDir = false; + try { + stat = fs.lstatSync(path.resolve(GOPATH, 'src/github.com/Zarel')); + } catch (e) { + needsSrcDir = true; + } finally { + if (stat && !stat.isDirectory()) { + needsSrcDir = true; + } + } + + if (needsSrcDir) { + try { + fs.mkdirSync(path.resolve(GOPATH, 'src/github.com/Zarel')); + } catch (e) { + console.error(`Cannot make go source directory for the sockets library files! Symlink them manually from ${__dirname} to ${path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown/')}`); + process.exit(0); + } + } + + try { + stat = fs.lstatSync(path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown')); + } catch (e) {} + + if (!stat || !stat.isSymbolicLink()) { + try { + // FIXME: does this even work on Windows? Check to see if `mklink /J` + // might be needed instead + fs.symlink(__dirname, path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown')); + } catch (e) { + console.error(`Cannot make go source directory for the sockets library files! Symlink them manually from ${__dirname} to ${path.resolve(GOPATH, './src/github.com/Zarel/Pokemon-Showdown/')}`); + process.exit(0); + } + } + + console.log('Building Go source libs...'); + try { + child_process.execSync('go install github.com/Zarel/Pokemon-Showdown/sockets', {stdio: 'inherit'}); + } catch (e) { + process.exit(0); + } +} + // Start the server. We manually load app.js so it can be configured to run as // the main module, rather than this file being considered the main module. // This ensures any dependencies that were just installed can be found when diff --git a/sockets-workers.js b/sockets-workers.js new file mode 100644 index 0000000000000..d2a559609b7d8 --- /dev/null +++ b/sockets-workers.js @@ -0,0 +1,555 @@ +/** + * Connections + * Pokemon Showdown - http://pokemonshowdown.com/ + * + * Abstraction layer for multi-process SockJS connections. + * + * This file handles all the communications between the users' browsers and + * the main process. + * + * @license MIT license + */ + +'use strict'; + +const cluster = require('cluster'); +const fs = require('fs'); +const sockjs = require('sockjs'); +const StaticServer = require('node-static').Server; + +if (!global.Config) global.Config = require('./config/config'); +if (!global.Dnsbl) global.Dnsbl = require('./dnsbl'); +if (!global.Monitor) global.Monitor = require('./monitor'); + +// IPC command tokens +const EVAL = '$'; +const SOCKET_CONNECT = '*'; +const SOCKET_DISCONNECT = '!'; +const SOCKET_RECEIVE = '<'; +const SOCKET_SEND = '>'; +const CHANNEL_ADD = '+'; +const CHANNEL_REMOVE = '-'; +const CHANNEL_BROADCAST = '#'; +const SUBCHANNEL_ID_MOVE = '.'; +const SUBCHANNEL_ID_BROADCAST = ':'; + +// Subchannel IDs +const DEFAULT_SUBCHANNEL_ID = '0'; +const P1_SUBCHANNEL_ID = '1'; +const P2_SUBCHANNEL_ID = '2'; + +// Regex for splitting subchannel broadcasts between subchannels. +const SUBCHANNEL_ID_MESSAGE_REGEX = /\n\/split(\n[^\n]*)(\n[^\n]*)(\n[^\n]*)\n[^\n]*/g; + +/* + * @typedef {Map} Channel + * @typedef {Map} Sockets + * @typedef {Map} Channels + */ + +/** + * @class Multiplexer + * @description Manages the worker's state for sockets, channels, and + * subchannels. This is responsible for parsing all outgoing and incoming + * messages. + */ +class Multiplexer { + /** + * @param {number} socketCounter + * @param {Sockets} sockets + * @param {Channels} channels + * @param {NodeJS.Timer | null} cleanupInterval + */ + constructor() { + this.socketCounter = 0; + this.sockets = new Map(); + this.channels = new Map(); + this.cleanupInterval = setInterval(() => this.sweepClosedSockets(), 10 * 60 * 1000); + } + + /** + * @description Mitigates a potential bug in SockJS or Faye-Websocket where + * sockets fail to emit a 'close' event after having disconnected. + * @returns {void} + */ + sweepClosedSockets() { + this.sockets.forEach(socket => { + if (socket.protocol === 'xhr-streaming' && + socket._session && + socket._session.recv) { + socket._session.recv.didClose(); + } + + // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while + // it is an object for normal users. Under normal circumstances, those properties should only be + // `null` when the timeout has already been called, but somehow it's not happening for some connections. + // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually + // on those connections kills those connections. For a bit of background, this timeout is the timeout + // that sockjs sets to wait for users to reconnect within that time to continue their session. + if (socket._session && + socket._session.to_tref && + !socket._session.to_tref._idlePrev) { + socket._session.timeout_cb(); + } + }); + + // Don't bother deleting the sockets from our map; their close event + // handler will deal with it. + } + + /** + * @description Sends an IPC message to the parent process. + * @param {string} token + * @param {string[]} params + * @returns {void} + */ + sendUpstream(token, ...params) { + let message = `${token}${params.join('\n')}`; + // @ts-ignore + process.send(message); + } + + /** + * @description Parses the params in a downstream message sent as a + * command. + * @param {string} params + * @param {number} count + * @returns {string[]} + */ + parseParams(params, count) { + let i = 0; + let idx = 0; + let ret = []; + while (i++ < count) { + let newIdx = params.indexOf('\n', idx); + if (newIdx < 0) { + // No remaining newlines; just use the rest of the string as + // the last parametre. + ret.push(params.slice(idx)); + break; + } + + let param = params.slice(idx, newIdx); + if (i === count) { + // We reached the number of parametres needed, but there is + // still some remaining string left. Glue it to the last one. + param += `\n${params.slice(newIdx + 1)}`; + } else { + idx = newIdx + 1; + } + + ret.push(param); + } + + return ret; + } + + /** + * @description Parses downstream messages. + * @param {string} data + * @returns {boolean} + */ + receiveDownstream(data) { + let command = data.charAt(0); + let params = data.substr(1); + let socketid; + let channelid; + let subchannelid; + let message; + switch (command) { + case EVAL: + return this.onEval(params); + case SOCKET_DISCONNECT: + return this.onSocketDisconnect(params); + case SOCKET_SEND: + [socketid, message] = this.parseParams(params, 2); + return this.onSocketSend(socketid, message); + case CHANNEL_ADD: + [channelid, socketid] = this.parseParams(params, 2); + return this.onChannelAdd(channelid, socketid); + case CHANNEL_REMOVE: + [channelid, socketid] = this.parseParams(params, 2); + return this.onChannelRemove(channelid, socketid); + case CHANNEL_BROADCAST: + [channelid, message] = this.parseParams(params, 2); + return this.onChannelBroadcast(channelid, message); + case SUBCHANNEL_ID_MOVE: + [channelid, subchannelid, socketid] = this.parseParams(params, 3); + return this.onSubchannelMove(channelid, subchannelid, socketid); + case SUBCHANNEL_ID_BROADCAST: + [channelid, message] = this.parseParams(params, 2); + return this.onSubchannelBroadcast(channelid, message); + default: + Monitor.debug(`Sockets worker IPC error: unknown command type in downstream message: ${data}`); + return false; + } + } + + /** + * @description Safely tries to destroy a socket's connection. + * @param {any} socket + * @returns {void} + */ + tryDestroySocket(socket) { + try { + socket.end(); + socket.destroy(); + } catch (e) {} + } + + /** + * @description Eval handler for downstream messages. + * @param {string} expr + * @returns {boolean} + */ + onEval(expr) { + try { + eval(expr); + return true; + } catch (e) {} + return false; + } + + /** + * @description Sockets.socketConnect message handler. + * @param {any} socket + * @returns {boolean} + */ + onSocketConnect(socket) { + if (!socket) return false; + if (!socket.remoteAddress) { + this.tryDestroySocket(socket); + return false; + } + + let socketid = '' + this.socketCounter++; + let ip = socket.remoteAddress; + let ips = socket.headers['x-forwarded-for'] || ''; + this.sockets.set(socketid, socket); + this.sendUpstream(SOCKET_CONNECT, socketid, ip, ips, socket.protocol); + + socket.on('data', /** @param {string} message */ message => { + this.onSocketReceive(socketid, message); + }); + + socket.on('close', () => { + this.sendUpstream(SOCKET_DISCONNECT, socketid); + this.sockets.delete(socketid); + this.channels.forEach((channel, channelid) => { + if (!channel.has(socketid)) return; + channel.delete(socketid); + if (!channel.size) this.channels.delete(channelid); + }); + }); + + return true; + } + + /** + * @description Sockets.socketDisconnect message handler. + * @param {string} socketid + * @returns {boolean} + */ + onSocketDisconnect(socketid) { + let socket = this.sockets.get(socketid); + if (!socket) return false; + + this.tryDestroySocket(socket); + return true; + } + + /** + * @description Sockets.socketSend message handler. + * @param {string} socketid + * @param {string} message + * @returns {boolean} + */ + onSocketSend(socketid, message) { + let socket = this.sockets.get(socketid); + if (!socket) return false; + + socket.write(message); + return true; + } + + /** + * @description onmessage event handler for sockets. Passes the message + * upstream. + * @param {string} socketid + * @param {string} message + * @returns {boolean} + */ + onSocketReceive(socketid, message) { + // Drop empty messages (DDOS?). + if (!message) return false; + + // Drop >100KB messages. + if (message.length > (1000 * 1024)) { + console.log(`Dropping client message ${message.length / 1024} KB...`); + console.log(message.slice(0, 160)); + return false; + } + + // Drop legacy JSON messages. + if ((typeof message !== 'string') || message.startsWith('{')) return false; + + // Drop invalid messages (again, DDOS?). + if (!message.includes('|') || message.endsWith('|')) return false; + + this.sendUpstream(SOCKET_RECEIVE, socketid, message); + return true; + } + + /** + * @description Sockets.channelAdd message handler. + * @param {string} channelid + * @param {string} socketid + * @returns {boolean} + */ + onChannelAdd(channelid, socketid) { + if (!this.sockets.has(socketid)) return false; + + if (this.channels.has(channelid)) { + let channel = this.channels.get(channelid); + if (channel.has(socketid)) return false; + channel.set(socketid, DEFAULT_SUBCHANNEL_ID); + } else { + let channel = new Map(); + channel.set(socketid, DEFAULT_SUBCHANNEL_ID); + this.channels.set(channelid, channel); + } + + return true; + } + + /** + * @description Sockets.channelRemove message handler. + * @param {string} channelid + * @param {string} socketid + * @returns {boolean} + */ + onChannelRemove(channelid, socketid) { + let channel = this.channels.get(channelid); + if (!channel) return false; + + channel.delete(socketid); + if (!channel.size) this.channels.delete(channelid); + + return true; + } + + /** + * @description Sockets.channelSend and Sockets.channelBroadcast message + * handler. + * @param {string} channelid + * @param {string} message + * @returns {boolean} + */ + onChannelBroadcast(channelid, message) { + let channel = this.channels.get(channelid); + if (!channel) return false; + + channel.forEach( + /** @param {string} subchannelid */ + /** @param {string} socketid */ + (subchannelid, socketid) => { + let socket = this.sockets.get(socketid); + socket.write(message); + } + ); + + return true; + } + + /** + * @description Sockets.subchannelMove message handler. + * @param {string} channelid + * @param {string} subchannelid + * @param {string} socketid + * @returns {boolean} + */ + onSubchannelMove(channelid, subchannelid, socketid) { + if (!this.sockets.has(socketid)) return false; + + if (this.channels.has(channelid)) { + let channel = new Map([[socketid, subchannelid]]); + this.channels.set(channelid, channel); + } else { + let channel = this.channels.get(channelid); + channel.set(socketid, subchannelid); + } + + return true; + } + + /** + * @description Sockets.subchannelBroadcast message handler. + * @param {string} channelid + * @param {string} message + * @returns {boolean} + */ + onSubchannelBroadcast(channelid, message) { + let channel = this.channels.get(channelid); + if (!channel) return false; + + /** @type {RegExpExecArray | null} */ + let matches = SUBCHANNEL_ID_MESSAGE_REGEX.exec(message); + if (!matches) return false; + + let [match, msg1, msg2, msg3] = matches.splice(0); + channel.forEach( + /** @param {string} subchannelid */ + /** @param {string} socketid */ + (subchannelid, socketid) => { + let socket = this.sockets.get(socketid); + if (!socket) return; + + switch (subchannelid) { + case DEFAULT_SUBCHANNEL_ID: + socket.write(msg1); + break; + case P1_SUBCHANNEL_ID: + socket.write(msg2); + break; + case P2_SUBCHANNEL_ID: + socket.write(msg3); + break; + default: + Monitor.debug(`Sockets worker ${cluster.worker.id} received a message targeted at an unknown subchannel: ${match}`); + break; + } + } + ); + + return true; + } +} + +exports.Multiplexer = Multiplexer; + +if (cluster.isWorker) { + if (process.env.PSPORT) Config.port = +process.env.PSPORT; + if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; + if (+process.env.PSNOSSL) Config.ssl = null; + if (Config.crashguard) { + // Graceful crash. + process.on('uncaughtException', /** @param {Error} err */ err => { + require('./crashlogger')(err, `Socket process ${cluster.worker.id} (${process.pid})`, true); + }); + } + + // This is optional. If ofe is installed, it will take a heapdump if the + // process runs out of memory. + try { + require('ofe').call(); + } catch (e) {} + + let app = require('http').createServer(); + let appssl = null; + if (Config.ssl) { + let key; + let cert; + try { + key = fs.readFileSync(Config.ssl.options.key); + cert = fs.readFileSync(Config.ssl.options.cert); + Config.ssl.options.key = key; + Config.ssl.options.cert = cert; + } catch (e) { + console.error('The configured SSL key and cert must be the filenames of their according files now in order for Go processes to be able to host over HTTPS.'); + } finally { + appssl = require('https').createServer(Config.ssl.options); + } + } + + // Launch the static server. + try { + let cssserver = new StaticServer('./config'); + let avatarserver = new StaticServer('./config/avatars'); + let staticserver = new StaticServer('./static'); + /** @param {any} request */ + /** @param {any} response */ + let staticRequestHandler = (request, response) => { + // console.log("static rq: " + request.socket.remoteAddress + ":" + request.socket.remotePort + " -> " + request.socket.localAddress + ":" + request.socket.localPort + " - " + request.method + " " + request.url + " " + request.httpVersion + " - " + request.rawHeaders.join('|')); + request.resume(); + request.addListener('end', () => { + if (Config.customhttpresponse && + Config.customhttpresponse(request, response)) { + return; + } + let server; + if (request.url === '/custom.css') { + server = cssserver; + } else if (request.url.substr(0, 9) === '/avatars/') { + request.url = request.url.substr(8); + server = avatarserver; + } else { + if (/^\/([A-Za-z0-9][A-Za-z0-9-]*)\/?$/.test(request.url)) { + request.url = '/'; + } + server = staticserver; + } + + server.serve(request, response, + /** @param {any} e */ + /** @param {any} res */ + (e, res) => { + if (e && (e.status === 404)) { + staticserver.serveFile('404.html', 404, {}, request, response); + } + } + ); + }); + }; + app.on('request', staticRequestHandler); + if (appssl) appssl.on('request', staticRequestHandler); + } catch (e) {} + + // Launch the SockJS server. + /** @type {any} */ + const server = sockjs.createServer({ + sockjs_url: '//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js', + /** @param {string} severity */ + /** @param {string} message */ + log: (severity, message) => { + if (severity === 'error') Monitor.debug(`Sockets worker SockJS error: ${message}`); + }, + prefix: '/showdown', + }); + + // Instantiate SockJS' multiplexer. This takes messages received downstream + // from the parent process and distributes them across the sockets they are + // targeting, as well as handling user disconnects and passing user + // messages upstream. + const multiplexer = new Multiplexer(); + + process.on('message', /** @param {string} data */ data => { + // console.log('worker received: ' + data); + let ret = multiplexer.receiveDownstream(data); + if (!ret) { + Monitor.debug(`Sockets worker IPC error: failed to parse downstream message: ${data}`); + } + }); + + process.on('disconnect', () => { + process.exit(0); + }); + + server.on('connection', /** @param {any} socket */ socket => { + multiplexer.onSocketConnect(socket); + }); + + server.installHandlers(app, {}); + if (!Config.bindaddress) Config.bindaddress = '0.0.0.0'; + app.listen(Config.port, Config.bindaddress); + console.log(`Worker ${cluster.worker.id} now listening on ${Config.bindaddress}:${Config.port}`); + + if (appssl) { + server.installHandlers(appssl, {}); + appssl.listen(Config.ssl.port, Config.bindaddress); + console.log(`Worker ${cluster.worker.id} now listening for SSL on port ${Config.ssl.port}`); + } + + console.log(`Test your server at http://${Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress}:${Config.port}`); + + require('./repl').start('sockets-', `${cluster.worker.id}-${process.pid}`, /** @param {string} cmd */ cmd => eval(cmd)); +} diff --git a/sockets.js b/sockets.js index 324da89dbb5ef..c135a1cd23f35 100644 --- a/sockets.js +++ b/sockets.js @@ -4,9 +4,8 @@ * * Abstraction layer for multi-process SockJS connections. * - * This file handles all the communications between the users' - * browsers, the networking processes, and users.js in the - * main process. + * This file handles all the communications between the networking processes + * and users.js. * * @license MIT license */ @@ -14,487 +13,612 @@ 'use strict'; const cluster = require('cluster'); -global.Config = require('./config/config'); +const EventEmitter = require('events'); + +if (!global.Config) global.Config = require('./config/config'); if (cluster.isMaster) { cluster.setupMaster({ - exec: require('path').resolve(__dirname, 'sockets'), + exec: require('path').resolve(__dirname, 'sockets-workers'), }); +} - const workers = exports.workers = new Map(); - - const spawnWorker = exports.spawnWorker = function () { - let worker = cluster.fork({PSPORT: Config.port, PSBINDADDR: Config.bindaddress || '0.0.0.0', PSNOSSL: Config.ssl ? 0 : 1}); - let id = worker.id; - workers.set(id, worker); - worker.on('message', data => { - // console.log('master received: ' + data); - switch (data.charAt(0)) { - case '*': { - // *socketid, ip, protocol - // connect - let nlPos = data.indexOf('\n'); - let nlPos2 = data.indexOf('\n', nlPos + 1); - Users.socketConnect(worker, id, data.slice(1, nlPos), data.slice(nlPos + 1, nlPos2), data.slice(nlPos2 + 1)); - break; - } +/** @typedef {any} NodeJSWorker */ +const {Worker} = cluster; // eslint-disable-line no-unused-vars +/** @typedef {any} Socket */ +const {Socket} = require('net'); // eslint-disable-line no-unused-vars +/** @typedef {NodeJSWorker | GoWorker} Worker */ - case '!': { - // !socketid - // disconnect - Users.socketDisconnect(worker, id, data.substr(1)); - break; - } +/** + * @description IPC delimiter byte. Required to parse messages sent to and from + * Go workers. + * @type {string} + */ +const DELIM = '\u0003'; - case '<': { - // boolean} isTrustedProxyIp + */ + constructor(worker) { + this.id = worker.id; + this.worker = worker; + this.process = worker.process; + this.exitedAfterDisconnect = worker.exitedAfterDisconnect; + this.isTrustedProxyIp = Dnsbl.checker(Config.proxyip); + + worker.on('message', + /** @param {string} data */ + data => this.onMessage(data) + ); + worker.on('error', () => { + // Ignore. Neither kind of child process ever print to stderr + // without throwing/panicking and emitting the diconnect/exit + // events. + }); + worker.once('disconnect', + /** @param {string} data */ + data => { + if (this.exitedAfterDisconnect !== undefined) return; + this.exitedAfterDisconnect = true; + process.nextTick(() => this.onDisconnect(data)); } - - default: - // unhandled + ); + worker.once('exit', + /** @param {number} code */ + /** @param {string} signal */ + (code, signal) => { + if (this.exitedAfterDisconnect !== undefined) return; + this.exitedAfterDisconnect = false; + process.nextTick(() => this.onExit(code, signal)); } - }); + ); + } - return worker; - }; + /** + * @description Worker#suicide getter wrapper + * @returns {boolean | undefined} + */ + get suicide() { + return this.exitedAfterDisconnect; + } - cluster.on('disconnect', worker => { - // worker crashed, try our best to clean up - require('./crashlogger')(new Error(`Worker ${worker.id} abruptly died`), "The main process"); + /** + * @description Worker#suicide setter wrapper + * @param {boolean} val + * @returns {void} + */ + set suicide(val) { + this.exitedAfterDisconnect = val; + } - // this could get called during cleanup; prevent it from crashing - // note: overwriting Worker#send is unnecessary in Node.js v7.0.0 and above - // see https://github.com/nodejs/node/commit/8c53d2fe9f102944cc1889c4efcac7a86224cf0a - worker.send = () => {}; + /** + * @description Worker#kill wrapper + * @param {string} signal + * @returns {void} + */ + kill(signal = 'SIGTERM') { + return this.worker.kill(signal); + } - let count = 0; - Users.connections.forEach(connection => { - if (connection.worker === worker) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; - } - }); - console.error(`${count} connections were lost.`); + /** + * @description Worker#destroy wrapper + * @param {string} signal + * @returns {void} + */ + destroy(signal) { + return this.kill(signal); + } - // don't delete the worker, so we can investigate it if necessary. + /** + * @description Worker#send wrapper + * @param {string} message + * @param {any?} sendHandle + * @returns {void} + */ + send(message, sendHandle) { + return this.worker.send(message, sendHandle); + } - // attempt to recover - spawnWorker(); - }); + /** + * @description Worker#isConnected wrapper + * @returns {boolean} + */ + isConnected() { + return this.worker.isConnected(); + } - exports.listen = function (port, bindAddress, workerCount) { - if (port !== undefined && !isNaN(port)) { - Config.port = port; - Config.ssl = null; - } else { - port = Config.port; - // Autoconfigure the app when running in cloud hosting environments: - try { - let cloudenv = require('cloud-env'); - bindAddress = cloudenv.get('IP', bindAddress); - port = cloudenv.get('PORT', port); - } catch (e) {} - } - if (bindAddress !== undefined) { - Config.bindaddress = bindAddress; - } - if (workerCount === undefined) { - workerCount = (Config.workers !== undefined ? Config.workers : 1); - } - for (let i = 0; i < workerCount; i++) { - spawnWorker(); - } - }; - - exports.killWorker = function (worker) { - let count = 0; - Users.connections.forEach(connection => { - if (connection.worker === worker) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; + /** + * @description Worker#isDead wrapper + * @returns {boolean} + */ + isDead() { + return this.worker.isDead(); + } + + /** + * @description Splits the parametres of incoming IPC messages from the + * worker's child process for the 'message' event handler. + * @param {string} params + * @param {number} count + * @returns {string[]} + */ + parseParams(params, count) { + let i = 0; + let idx = 0; + let ret = []; + while (i++ < count) { + let newIdx = params.indexOf('\n', idx); + if (newIdx < 0) { + // No remaining newlines; just use the rest of the string as + // the last parametre. + ret.push(params.slice(idx)); + break; } - }); - try { - worker.kill(); - } catch (e) {} - workers.delete(worker.id); - return count; - }; - - exports.killPid = function (pid) { - pid = '' + pid; - for (let [workerid, worker] of workers) { // eslint-disable-line no-unused-vars - if (pid === '' + worker.process.pid) { - return this.killWorker(worker); + + let param = params.slice(idx, newIdx); + if (i === count) { + // We reached the number of parametres needed, but there is + // still some remaining string left. Glue it to the last one. + param += `\n${params.slice(newIdx + 1)}`; + } else { + idx = newIdx + 1; } + + ret.push(param); } - return false; - }; - - exports.socketSend = function (worker, socketid, message) { - worker.send(`>${socketid}\n${message}`); - }; - exports.socketDisconnect = function (worker, socketid) { - worker.send(`!${socketid}`); - }; - - exports.channelBroadcast = function (channelid, message) { - workers.forEach(worker => { - worker.send(`#${channelid}\n${message}`); - }); - }; - exports.channelSend = function (worker, channelid, message) { - worker.send(`#${channelid}\n${message}`); - }; - exports.channelAdd = function (worker, channelid, socketid) { - worker.send(`+${channelid}\n${socketid}`); - }; - exports.channelRemove = function (worker, channelid, socketid) { - worker.send(`-${channelid}\n${socketid}`); - }; - - exports.subchannelBroadcast = function (channelid, message) { - workers.forEach(worker => { - worker.send(`:${channelid}\n${message}`); - }); - }; - exports.subchannelMove = function (worker, channelid, subchannelid, socketid) { - worker.send(`.${channelid}\n${subchannelid}\n${socketid}`); - }; -} else { - // is worker - - if (process.env.PSPORT) Config.port = +process.env.PSPORT; - if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; - if (+process.env.PSNOSSL) Config.ssl = null; - - // ofe is optional - // if installed, it will heap dump if the process runs out of memory - try { - require('ofe').call(); - } catch (e) {} - // Static HTTP server + return ret; + } - // This handles the custom CSS and custom avatar features, and also - // redirects yourserver:8001 to yourserver-8001.psim.us + /** + * @description 'message' event handler for the worker. Parses which type + * of command the incoming IPC message uses, then parses its parametres and + * calls the appropriate Users method. + * @param {string} data + * @returns {boolean} + */ + onMessage(data) { + // console.log('master received: ' + data); + let command = data.charAt(0); + let params = data.substr(1); + switch (command) { + case '*': + let [socketid, ip, header, protocol] = this.parseParams(params, 4); + let ips; + if (this.isTrustedProxyIp(ip)) { + ips = (header || '').split(','); + for (let i = ips.length; i--;) { + ip = ips[i].trim() || ip; + if (!this.isTrustedProxyIp(ip)) break; + } + } + Users.socketConnect(this.worker, this.id, socketid, ip, protocol); + break; + case '!': + Users.socketDisconnect(this.worker, this.id, params); + break; + case '<': + Users.socketReceive(this.worker, this.id, ...this.parseParams(params, 2)); + break; + default: + Monitor.debug(`Sockets: master received unknown IPC command type: ${data}`); + break; + } + } - // It's optional if you don't need these features. + /** + * @description 'disconnect' event handler for the worker. Cleans up any + * remaining users whose sockets were contained by the worker's child + * process, then attempts to respawn it.. + * @param {string} data + * @returns {void} + */ + onDisconnect(data) { + require('./crashlogger')(new Error(`Worker ${this.id} abruptly died with the following stack trace: ${data}`), 'The main process'); + console.error(`${Users.socketDisconnectAll(this.worker)} connections were lost.`); + spawnWorker(); + } - global.Dnsbl = require('./dnsbl'); + /** + * @description 'exit' event handler for the worker. Only used by GoWorker + * instances, since the 'disconnect' event is only available for Node.js + * workers. + * @param {number} code + * @param {string?} signal + * @returns {void} + */ + onExit(code, signal) { + require('./crashlogger')(new Error(`Worker ${this.id} abruptly died with code ${code} and signal ${signal}`), 'The main process'); + console.error(`${Users.socketDisconnectAll(this.worker)} connections were lost.`); + spawnWorker(); + } +} - if (Config.crashguard) { - // graceful crash - process.on('uncaughtException', err => { - require('./crashlogger')(err, `Socket process ${cluster.worker.id} (${process.pid})`, true); - }); +exports.WorkerWrapper = WorkerWrapper; + +/** + * @class GoWorker + * @extends NodeJS.EventEmitter + * @description A mock Worker class for Go child processes. Similarly to + * Node.js workers, it uses a TCP net server to perform IPC. After launching + * the server, it will spawn the Go child process and wait for it to make a + * connection to the worker's server before performing IPC with it. + */ +class GoWorker extends EventEmitter { + /** + * @param {number} id + * @prop {number} id + * @prop {NodeJS.ChildProcess | null} process + * @prop {boolean | undefined} exitedAfterDisconnect + * @prop {NodeJS.net.Server | null} server + * @prop {NodeJS.net.NodeSocket | null} connection + * @prop {string[]} buffer + */ + constructor(id) { + super(); + + this.id = id; + this.process = null; + this.exitedAfterDisconnect = undefined; + + this.server = null; + this.connection = null; + /** @type {string[]} */ + this.buffer = []; + + process.nextTick(() => this.spawnServer()); } - let app = require('http').createServer(); - let appssl; - if (Config.ssl) { - appssl = require('https').createServer(Config.ssl.options); + /** + * @description Worker#kill mock + * @param {string} signal + * @returns {void} + */ + kill(signal = 'SIGTERM') { + if (this.isConnected()) this.connection.end(); + if (!this.isDead() && this.process) this.process.kill(signal); + if (this.server) this.server.close(); + this.exitedAfterDisconnect = false; } - try { - let nodestatic = require('node-static'); - let cssserver = new nodestatic.Server('./config'); - let avatarserver = new nodestatic.Server('./config/avatars'); - let staticserver = new nodestatic.Server('./static'); - let staticRequestHandler = (request, response) => { - // console.log("static rq: " + request.socket.remoteAddress + ":" + request.socket.remotePort + " -> " + request.socket.localAddress + ":" + request.socket.localPort + " - " + request.method + " " + request.url + " " + request.httpVersion + " - " + request.rawHeaders.join('|')); - request.resume(); - request.addListener('end', () => { - if (Config.customhttpresponse && - Config.customhttpresponse(request, response)) { - return; - } - let server; - if (request.url === '/custom.css') { - server = cssserver; - } else if (request.url.substr(0, 9) === '/avatars/') { - request.url = request.url.substr(8); - server = avatarserver; - } else { - if (/^\/([A-Za-z0-9][A-Za-z0-9-]*)\/?$/.test(request.url)) { - request.url = '/'; - } - server = staticserver; - } - server.serve(request, response, (e, res) => { - if (e && (e.status === 404)) { - staticserver.serveFile('404.html', 404, {}, request, response); - } - }); + + /** + * @description Worker#destroy mock + * @param {string=} signal + * @returns {void} + */ + destroy(signal) { + return this.kill(signal); + } + + /** + * @description Worker#send mock + * @param {string} message + * @param {any?} sendHandle + * @returns {void} + */ + send(message, sendHandle) { // eslint-disable-line no-unused-vars + if (!this.isConnected()) { + this.buffer.push(message); + return; + } + + if (this.buffer.length) { + this.buffer.splice(0).forEach(msg => { + this.connection.write(JSON.stringify(msg) + DELIM); }); - }; - app.on('request', staticRequestHandler); - if (appssl) { - appssl.on('request', staticRequestHandler); } - } catch (e) { - console.log('Could not start node-static - try `npm install` if you want to use it'); - } - // SockJS server + return this.connection.write(JSON.stringify(message) + DELIM); + } - // This is the main server that handles users connecting to our server - // and doing things on our server. + /** + * @description Worker#isConnected mock + * @returns {boolean} + */ + isConnected() { + return this.connection && !this.connection.destroyed; + } - const sockjs = require('sockjs'); + /** + * @description Worker#isDead mock + * @returns {boolean} + */ + isDead() { + return !this.process || this.connection.exitCode !== null || this.connection.statusCode !== null; + } - const server = sockjs.createServer({ - sockjs_url: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", - log: (severity, message) => { - if (severity === 'error') console.log('ERROR: ' + message); - }, - prefix: '/showdown', - }); + /** + * @description Spawns the TCP server through which IPC with the child + * process is handled. + * @returns {boolean} + */ + spawnServer() { + if (!this.isDead()) return false; + + this.server = require('net').createServer(); + this.server.on('error', console.error); + this.server.once('listening', () => { + // Spawn the child process after the TCP server has finished + // launching to allow it to connect to it for IPC. + process.nextTick(() => this.spawnChild()); + }); + // When the child process finally connects to the TCP server we can + // begin communicating with it using a random port. + this.server.listen(() => { + if (!this.server) return; + this.server.once('connection', connection => { + process.nextTick(() => this.bootstrapChild(connection)); + }); + }); + } - const sockets = new Map(); - const channels = new Map(); - const subchannels = new Map(); - - // Deal with phantom connections. - const sweepClosedSockets = () => { - sockets.forEach(socket => { - if (socket.protocol === 'xhr-streaming' && - socket._session && - socket._session.recv) { - socket._session.recv.didClose(); + /** + * @description Spawns the Go child process. Once the process has started, + * it will make a connection to the worker's TCP server. + * @returns {void} + */ + spawnChild() { + if (!this.server) return this.spawnServer(); + this.process = require('child_process').spawn( + `${process.env.GOPATH}/bin/sockets`, [], { + env: { + GOPATH: process.env.GOPATH || '', + GOROOT: process.env.GOROOT || '', + PS_IPC_PORT: `:${this.server.address().port}`, + PS_CONFIG: JSON.stringify({ + workers: Config.workers || 1, + port: `:${Config.port || 8000}`, + bindAddress: Config.bindaddress || '0.0.0.0', + ssl: Config.ssl || null, + }), + }, + stdio: ['inherit', 'inherit', 'pipe'], + shell: true, } + ); + + this.process.once('exit', (code, signal) => { + process.nextTick(() => this.emit('exit', code, signal)); + }); - // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while - // it is an object for normal users. Under normal circumstances, those properties should only be - // `null` when the timeout has already been called, but somehow it's not happening for some connections. - // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually - // on those connections kills those connections. For a bit of background, this timeout is the timeout - // that sockjs sets to wait for users to reconnect within that time to continue their session. - if (socket._session && - socket._session.to_tref && - !socket._session.to_tref._idlePrev) { - socket._session.timeout_cb(); + this.process.stderr.setEncoding('utf8'); + this.process.stderr.once('data', data => { + process.nextTick(() => this.emit('error', data)); + }); + } + + /** + * @description 'connection' event handler for the TCP server. Begins + * the parsing of incoming IPC messages. + * @param {Socket} connection + * @returns {void} + */ + bootstrapChild(connection) { + this.connection = connection; + this.connection.setEncoding('utf8'); + this.connection.on('data', + /** @param {string} data */ + data => { + let messages = data.slice(0, -1).split(DELIM); + for (let message of messages) { + this.emit('message', JSON.parse(message)); + } } + ); + + // Leave the error handling to the process, not the connection. + this.connection.on('error', () => {}); + } +} + +exports.GoWorker = GoWorker; + +/** + * @description Map of worker IDs to worker processes. + * @type {Map} + */ +const workers = exports.workers = new Map(); + +/** + * @description Worker ID counter used for Go workers. + * @type {number} + */ +let nextWorkerid = 0; + +/** + * @description Spawns a new worker process. + * @returns {Worker} + */ +function spawnWorker() { + let worker; + if (Config.golang) { + worker = new GoWorker(nextWorkerid); + } else { + worker = cluster.fork({ + PSPORT: Config.port, + PSBINDADDR: Config.bindaddress || '0.0.0.0', + PSNOSSL: Config.ssl ? 0 : 1, }); - }; - const interval = setInterval(sweepClosedSockets, 1000 * 60 * 10); // eslint-disable-line no-unused-vars - - process.on('message', data => { - // console.log('worker received: ' + data); - let socket = null; - let socketid = ''; - let channel = null; - let channelid = ''; - let subchannel = null; - let subchannelid = ''; - let nlLoc = -1; - let message = ''; - - switch (data.charAt(0)) { - case '$': // $code - eval(data.substr(1)); - break; + } - case '!': // !socketid - // destroy - socketid = data.substr(1); - socket = sockets.get(socketid); - if (!socket) return; - socket.end(); - // After sending the FIN packet, we make sure the I/O is totally blocked for this socket - socket.destroy(); - sockets.delete(socketid); - channels.forEach(channel => channel.delete(socketid)); - break; + let wrapper = new WorkerWrapper(worker); + workers.set(wrapper.id, wrapper); + nextWorkerid++; + return wrapper; +} - case '>': - // >socketid, message - // message - nlLoc = data.indexOf('\n'); - socketid = data.substr(1, nlLoc - 1); - socket = sockets.get(socketid); - if (!socket) return; - message = data.substr(nlLoc + 1); - socket.write(message); - break; +exports.spawnWorker = spawnWorker; - case '#': - // #channelid, message - // message to channel - nlLoc = data.indexOf('\n'); - channelid = data.substr(1, nlLoc - 1); - channel = channels.get(channelid); - if (!channel) return; - message = data.substr(nlLoc + 1); - channel.forEach(socket => socket.write(message)); - break; +/** + * @description Initializes the configured number of worker processes. + * @param {any} port + * @param {any} bindAddress + * @param {any} workerCount + * @returns {void} + */ +exports.listen = function (port, bindAddress, workerCount) { + if (port !== undefined && !isNaN(port)) { + Config.port = port; + Config.ssl = null; + } else { + port = Config.port; + // Autoconfigure the app when running in cloud hosting environments: + try { + let cloudenv = require('cloud-env'); + bindAddress = cloudenv.get('IP', bindAddress); + port = cloudenv.get('PORT', port); + } catch (e) {} + } + if (bindAddress !== undefined) { + Config.bindaddress = bindAddress; + } - case '+': - // +channelid, socketid - // add to channel - nlLoc = data.indexOf('\n'); - socketid = data.substr(nlLoc + 1); - socket = sockets.get(socketid); - if (!socket) return; - channelid = data.substr(1, nlLoc - 1); - channel = channels.get(channelid); - if (!channel) { - channel = new Map(); - channels.set(channelid, channel); - } - channel.set(socketid, socket); - break; + // Go only uses one child process since it does not share FD handles for + // serving like Node.js workers do. Workers are instead used to limit the + // number of concurrent requests that can be handled at once in the child + // process. + if (Config.golang) { + spawnWorker(); + return; + } - case '-': - // -channelid, socketid - // remove from channel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels.get(channelid); - if (!channel) return; - socketid = data.slice(nlLoc + 1); - channel.delete(socketid); - subchannel = subchannels.get(channelid); - if (subchannel) subchannel.delete(socketid); - if (!channel.size) { - channels.delete(channelid); - if (subchannel) subchannels.delete(channelid); - } - break; + if (workerCount === undefined) { + workerCount = (Config.workers !== undefined ? Config.workers : 1); + } + for (let i = 0; i < workerCount; i++) { + spawnWorker(); + } +}; - case '.': - // .channelid, subchannelid, socketid - // move subchannel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - let nlLoc2 = data.indexOf('\n', nlLoc + 1); - subchannelid = data.slice(nlLoc + 1, nlLoc2); - socketid = data.slice(nlLoc2 + 1); - - subchannel = subchannels.get(channelid); - if (!subchannel) { - subchannel = new Map(); - subchannels.set(channelid, subchannel); - } - if (subchannelid === '0') { - subchannel.delete(socketid); - } else { - subchannel.set(socketid, subchannelid); - } - break; +/** + * @description Kills a worker process using the given worker object. + * @param {Worker} worker + * @returns {number} + */ +exports.killWorker = function (worker) { + let count = Users.socketDisconnectAll(worker); + try { + worker.kill(); + } catch (e) {} + workers.delete(worker.id); + return count; +}; - case ':': - // :channelid, message - // message to subchannel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels.get(channelid); - if (!channel) return; - - let messages = [null, null, null]; - message = data.substr(nlLoc + 1); - subchannel = subchannels.get(channelid); - channel.forEach((socket, socketid) => { - switch (subchannel ? subchannel.get(socketid) : '0') { - case '1': - if (!messages[1]) { - messages[1] = message.replace(/\n\|split\n[^\n]*\n([^\n]*)\n[^\n]*\n[^\n]*/g, '\n$1'); - } - socket.write(messages[1]); - break; - case '2': - if (!messages[2]) { - messages[2] = message.replace(/\n\|split\n[^\n]*\n[^\n]*\n([^\n]*)\n[^\n]*/g, '\n$1'); - } - socket.write(messages[2]); - break; - default: - if (!messages[0]) { - messages[0] = message.replace(/\n\|split\n([^\n]*)\n[^\n]*\n[^\n]*\n[^\n]*/g, '\n$1'); - } - socket.write(messages[0]); - break; - } - }); - break; +/** + * @description Kills a worker process using the given worker PID. + * @param {number} pid + * @returns {number | false} + */ +exports.killPid = function (pid) { + for (let [workerid, worker] of workers) { // eslint-disable-line no-unused-vars + if (pid === worker.process.pid) { + return this.killWorker(worker); } - }); + } + return false; +}; - process.on('disconnect', () => { - process.exit(); - }); +/** + * @description Sends a message to a socket in a given worker by ID. + * @param {Worker} worker + * @param {string} socketid + * @param {string} message + * @returns {void} + */ +exports.socketSend = function (worker, socketid, message) { + worker.send(`>${socketid}\n${message}`); +}; - // this is global so it can be hotpatched if necessary - let isTrustedProxyIp = Dnsbl.checker(Config.proxyip); - let socketCounter = 0; - server.on('connection', socket => { - if (!socket) { - // For reasons that are not entirely clear, SockJS sometimes triggers - // this event with a null `socket` argument. - return; - } else if (!socket.remoteAddress) { - // This condition occurs several times per day. It may be a SockJS bug. - try { - socket.end(); - } catch (e) {} - return; - } +/** + * @description Forcefully disconnects a socket in a given worker by ID. + * @param {Worker} worker + * @param {string} socketid + * @returns {void} + */ +exports.socketDisconnect = function (worker, socketid) { + worker.send(`!${socketid}`); +}; - let socketid = socket.id = '' + (++socketCounter); - sockets.set(socketid, socket); - - if (isTrustedProxyIp(socket.remoteAddress)) { - let ips = (socket.headers['x-forwarded-for'] || '').split(','); - let ip; - while ((ip = ips.pop())) { - ip = ip.trim(); - if (!isTrustedProxyIp(ip)) { - socket.remoteAddress = ip; - break; - } - } - } +/** + * @description Broadcasts a message to all sockets in a given channel across + * all workers. + * @param {string} channelid + * @param {string} message + * @returns {void} + */ +exports.channelBroadcast = function (channelid, message) { + workers.forEach(worker => { + worker.send(`#${channelid}\n${message}`); + }); +}; - process.send(`*${socketid}\n${socket.remoteAddress}\n${socket.protocol}`); +/** + * @description Broadcasts a message to all sockets in a given channel and a + * given worker. + * @param {Worker} worker + * @param {string} channelid + * @param {string} message + * @returns {void} + */ +exports.channelSend = function (worker, channelid, message) { + worker.send(`#${channelid}\n${message}`); +}; - socket.on('data', message => { - // drop empty messages (DDoS?) - if (!message) return; - // drop messages over 100KB - if (message.length > 100000) { - console.log(`Dropping client message ${message.length / 1024} KB...`); - console.log(message.slice(0, 160)); - return; - } - // drop legacy JSON messages - if (typeof message !== 'string' || message.startsWith('{')) return; - // drop blank messages (DDoS?) - let pipeIndex = message.indexOf('|'); - if (pipeIndex < 0 || pipeIndex === message.length - 1) return; +/** + * @description Adds a socket to a given channel in a given worker by ID. + * @param {Worker} worker + * @param {string} channelid + * @param {string} socketid + * @returns {void} + */ +exports.channelAdd = function (worker, channelid, socketid) { + worker.send(`+${channelid}\n${socketid}`); +}; - process.send(`<${socketid}\n${message}`); - }); +/** + * @description Removes a socket from a given channel in a given worker by ID. + * @param {Worker} worker + * @param {string} channelid + * @param {string} socketid + * @returns {void} + */ +exports.channelRemove = function (worker, channelid, socketid) { + worker.send(`-${channelid}\n${socketid}`); +}; - socket.on('close', () => { - process.send(`!${socketid}`); - sockets.delete(socketid); - channels.forEach(channel => channel.delete(socketid)); - }); +/** + * @description Broadcasts a message to be demuxed into three separate messages + * across three subchannels in a given channel across all workers. + * @param {string} channelid + * @param {string} message + * @returns {void} + */ +exports.subchannelBroadcast = function (channelid, message) { + workers.forEach(worker => { + worker.send(`:${channelid}\n${message}`); }); - server.installHandlers(app, {}); - app.listen(Config.port, Config.bindaddress); - console.log(`Worker ${cluster.worker.id} now listening on ${Config.bindaddress}:${Config.port}`); - - if (appssl) { - server.installHandlers(appssl, {}); - appssl.listen(Config.ssl.port, Config.bindaddress); - console.log(`Worker ${cluster.worker.id} now listening for SSL on port ${Config.ssl.port}`); - } - - console.log(`Test your server at http://${Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress}:${Config.port}`); +}; - require('./repl').start('sockets-', `${cluster.worker.id}-${process.pid}`, cmd => eval(cmd)); -} +/** + * @description Moves a given socket to a different subchannel in a channel by + * ID in the given worker. + * @param {Worker} worker + * @param {string} channelid + * @param {string} subchannelid + * @param {string} socketid + */ +exports.subchannelMove = function (worker, channelid, subchannelid, socketid) { + worker.send(`.${channelid}\n${subchannelid}\n${socketid}`); +}; diff --git a/sockets/lib/commands.go b/sockets/lib/commands.go new file mode 100644 index 0000000000000..613e0296b95d5 --- /dev/null +++ b/sockets/lib/commands.go @@ -0,0 +1,73 @@ +package sockets + +import "strings" + +const SOCKET_CONNECT string = "*" +const SOCKET_DISCONNECT string = "!" +const SOCKET_RECEIVE string = "<" +const SOCKET_SEND string = ">" +const CHANNEL_ADD string = "+" +const CHANNEL_REMOVE string = "-" +const CHANNEL_BROADCAST string = "#" +const SUBCHANNEL_MOVE string = "." +const SUBCHANNEL_BROADCAST string = ":" + +type Command struct { + token string + paramstr string + count int + target CommandIO +} + +type CommandIO interface { + Process(Command) (err error) +} + +func NewCommand(msg string, target CommandIO) Command { + var count int + token := string(msg[:1]) + paramstr := msg[1:] + + switch token { + case SOCKET_DISCONNECT: + count = 1 + case SOCKET_RECEIVE: + count = 2 + case SOCKET_SEND: + count = 2 + case CHANNEL_ADD: + count = 2 + case CHANNEL_REMOVE: + count = 2 + case CHANNEL_BROADCAST: + count = 2 + case SUBCHANNEL_BROADCAST: + count = 2 + case SUBCHANNEL_MOVE: + count = 3 + case SOCKET_CONNECT: + count = 4 + } + + return Command{ + token: token, + paramstr: paramstr, + count: count, + target: target} +} + +func (c Command) Token() string { + return c.token +} + +func (c Command) Params() []string { + return strings.SplitN(c.paramstr, "\n", c.count) +} + +func (c Command) Message() string { + return c.token + c.paramstr +} + +func (c Command) Process() { + c.target.Process(c) +} diff --git a/sockets/lib/commands_test.go b/sockets/lib/commands_test.go new file mode 100644 index 0000000000000..0a34f8b5701c1 --- /dev/null +++ b/sockets/lib/commands_test.go @@ -0,0 +1,31 @@ +package sockets + +import "testing" + +type testTarget struct { + CommandIO +} + +func TestCommands(t *testing.T) { + tokens := []string{ + SOCKET_CONNECT, + SOCKET_DISCONNECT, + SOCKET_RECEIVE, + SOCKET_SEND, + CHANNEL_ADD, + CHANNEL_REMOVE, + CHANNEL_BROADCAST, + SUBCHANNEL_MOVE, + SUBCHANNEL_BROADCAST} + + cmds := make([]Command, len(tokens)) + for i, token := range tokens { + cmds[i] = NewCommand(token+"1\n2\n3\n4", testTarget{}) + } + for _, cmd := range cmds { + params := cmd.Params() + if len(params) != cmd.count { + t.Errorf("Commands: command type %v was expected to return %v tokens but actually returned %v", cmd.token, cmd.count, len(params)) + } + } +} diff --git a/sockets/lib/config.go b/sockets/lib/config.go new file mode 100644 index 0000000000000..637ab7e72a4a0 --- /dev/null +++ b/sockets/lib/config.go @@ -0,0 +1,29 @@ +package sockets + +import ( + "encoding/json" + "os" +) + +type config struct { + Workers int `json:"workers"` + Port string `json:"port"` + BindAddress string `json:"bindAddress"` + SSL sslOpts `json:"ssl"` +} + +type sslOpts struct { + Port string `json:"port"` + Options sslKeys `json:"options"` +} + +type sslKeys struct { + Cert string `json:"cert"` + Key string `json:"key"` +} + +func NewConfig(envVar string) (c config, err error) { + configEnv := os.Getenv(envVar) + err = json.Unmarshal([]byte(configEnv), &c) + return +} diff --git a/sockets/lib/config_test.go b/sockets/lib/config_test.go new file mode 100644 index 0000000000000..b7adba05f7673 --- /dev/null +++ b/sockets/lib/config_test.go @@ -0,0 +1,40 @@ +package sockets + +import ( + "encoding/json" + "fmt" + "testing" +) + +func newTestConfig(w int, p string, ba string, s interface{}) (c config) { + c = config{ + Workers: w, + Port: p, + BindAddress: ba} + if ssl, ok := s.(sslOpts); ok { + c.SSL = ssl + } + return +} + +func TestConfig(t *testing.T) { + t.Parallel() + ws := []int{1, 2, 3, 4} + ps := []string{":1000", ":2000", ":4000", ":8000"} + bas := []string{"127.0.0.1", "0.0.0.0", "192.168.0.1", "localhost"} + ssl := sslOpts{Port: ":443", Options: sslKeys{Cert: "", Key: ""}} + for _, w := range ws { + for _, p := range ps { + for _, ba := range bas { + t.Run(fmt.Sprintf("%v %v%v", w, ba, p, ssl), func(t *testing.T) { + go func(w int, p string, ba string, ssl sslOpts) { + c := newTestConfig(w, p, ba, ssl) + if _, err := json.Marshal(c); err != nil { + t.Errorf("Config: failed to stringify config JSON: %v", err) + } + }(w, p, ba, ssl) + }) + } + } + } +} diff --git a/sockets/lib/ipc.go b/sockets/lib/ipc.go new file mode 100644 index 0000000000000..b6e2a0aaa0be4 --- /dev/null +++ b/sockets/lib/ipc.go @@ -0,0 +1,98 @@ +package sockets + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "os" +) + +const DELIM byte = '\u0003' + +type Connection struct { + port string + addr *net.TCPAddr + conn *net.TCPConn + mux *Multiplexer + listening bool +} + +func NewConnection(envVar string) (c *Connection, err error) { + port := os.Getenv(envVar) + addr, err := net.ResolveTCPAddr("tcp", "localhost"+port) + if err != nil { + return nil, fmt.Errorf("Sockets: failed to parse TCP address to connect to the parent process with: %v", err) + } + + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return nil, fmt.Errorf("Sockets: failed to connect to TCP server: %v", err) + } + + c = &Connection{ + port: port, + addr: addr, + conn: conn, + listening: false} + + return +} + +func (c *Connection) Listening() bool { + return c.listening +} + +func (c *Connection) Listen(mux *Multiplexer) { + if c.listening { + return + } + + c.mux = mux + c.listening = true + + go func() { + reader := bufio.NewReader(c.conn) + for { + var token []byte + token, err := reader.ReadBytes(DELIM) + if len(token) == 0 || err != nil { + continue + } + + var msg string + err = json.Unmarshal(token[:len(token)-1], &msg) + cmd := NewCommand(msg, c.mux) + CmdQueue <- cmd + } + }() + + return +} + +func (c *Connection) Process(cmd Command) (err error) { + // fmt.Printf("Sockets => IPC: %v\n", cmd.Message()) + if !c.listening { + return fmt.Errorf("Sockets: can't process connection commands when the connection isn't listening yet") + } + + msg := cmd.Message() + _, err = c.Write(msg) + return +} + +func (c *Connection) Close() error { + return c.conn.Close() +} + +func (c *Connection) Write(message string) (int, error) { + if !c.listening { + return 0, fmt.Errorf("Sockets: can't write messages over a connection that isn't listening yet...") + } + + msg, err := json.Marshal(message) + if err != nil { + return 0, fmt.Errorf("Sockets: failed to parse upstream IPC message: %v", err) + } + return c.conn.Write(append(msg, DELIM)) +} diff --git a/sockets/lib/ipc_test.go b/sockets/lib/ipc_test.go new file mode 100644 index 0000000000000..101cdafa97079 --- /dev/null +++ b/sockets/lib/ipc_test.go @@ -0,0 +1,64 @@ +package sockets + +import ( + "net" + "os" + "testing" +) + +type testMux struct { + CommandIO +} + +func (tm *testMux) Listen(conn CommandIO) (err error) { + return nil +} + +func (tm *testMux) Process(cmd Command) (err error) { + return nil +} + +func TestConnection(t *testing.T) { + port := ":3000" + ln, err := net.Listen("tcp", "localhost"+port) + defer ln.Close() + if err != nil { + t.Errorf("Sockets: failed to launch TCP server on port %v: %v", port, err) + } + + envVar := "PS_IPC_PORT" + err = os.Setenv(envVar, port) + if err != nil { + t.Errorf("Sockets: failed to set %v environment variable: %v", envVar, port) + } + + conn, err := NewConnection(envVar) + defer conn.Close() + if err != nil { + t.Errorf("%v", err) + } + if conn.port != port { + t.Errorf("Sockets: new connection expected to have port %v but had %v instead", port, conn.port) + } + + mux := NewMultiplexer() + conn.Listen(mux) + mux.Listen(conn) + + cmd := NewCommand(SOCKET_SEND+"0\n|ayy lmao", mux) + err = conn.Process(cmd) + if err != nil { + t.Errorf("%v", err) + } + + bc, err := conn.Write(string(DELIM)) + if err != nil { + t.Errorf("%v", err) + } + bc += 3 // For the escaped backslashes and additional DELIM character + + cbc := len([]byte(cmd.Message())) + if bc != cbc { + t.Errorf("Sockets: expected the number of bytes received by the connection to be %v, but actually received %v", bc, cbc) + } +} diff --git a/sockets/lib/master.go b/sockets/lib/master.go new file mode 100644 index 0000000000000..f5d2e9ca37ec3 --- /dev/null +++ b/sockets/lib/master.go @@ -0,0 +1,59 @@ +package sockets + +var CmdQueue = make(chan Command) + +type master struct { + wpool chan chan Command + count int +} + +func NewMaster(count int) *master { + wpool := make(chan chan Command, count) + return &master{ + wpool: wpool, + count: count} +} + +func (m *master) Spawn() { + for i := 0; i < m.count; i++ { + w := newWorker(m.wpool) + w.listen() + } +} + +func (m *master) Listen() { + for { + cmd := <-CmdQueue + cmdch := <-m.wpool + cmdch <- cmd + } +} + +type worker struct { + wpool chan chan Command + cmdch chan Command + quit chan bool +} + +func newWorker(wpool chan chan Command) *worker { + cmdch := make(chan Command) + quit := make(chan bool) + return &worker{ + wpool: wpool, + cmdch: cmdch, + quit: quit} +} + +func (w *worker) listen() { + go func() { + for { + w.wpool <- w.cmdch + select { + case cmd := <-w.cmdch: + cmd.target.Process(cmd) + case <-w.quit: + return + } + } + }() +} diff --git a/sockets/lib/master_test.go b/sockets/lib/master_test.go new file mode 100644 index 0000000000000..be1c44788fef4 --- /dev/null +++ b/sockets/lib/master_test.go @@ -0,0 +1,75 @@ +package sockets + +import ( + "net" + "os" + "testing" + + "github.com/igm/sockjs-go/sockjs" +) + +type testSocket struct { + sockjs.Session +} + +func (ts testSocket) Send(msg string) error { + return nil +} + +func (ts testSocket) Close(code uint32, signal string) error { + return nil +} + +func TestMasterListen(t *testing.T) { + t.Parallel() + ln, _ := net.Listen("tcp", ":3000") + defer ln.Close() + + envVar := "PS_IPC_PORT" + os.Setenv(envVar, ":3000") + conn, _ := NewConnection(envVar) + defer conn.Close() + mux := NewMultiplexer() + mux.Listen(conn) + conn.Listen(mux) + + m := NewMaster(4) + m.Spawn() + go m.Listen() + + for i := 0; i < m.count*250; i++ { + id := string(i) + t.Run("Worker/Multiplexer command #"+id, func(t *testing.T) { + go func(id string, mux *Multiplexer, conn *Connection) { + mux.smux.Lock() + sid := string(mux.nsid) + mux.sockets[sid] = testSocket{} + mux.nsid++ + mux.smux.Unlock() + + cmd := NewCommand(SOCKET_DISCONNECT+sid, mux) + cmd.Process() + if len(CmdQueue) != 0 { + t.Error("Sockets: master failed to pass command struct from worker to multiplexer") + } + }(id, mux, conn) + }) + t.Run("Worker/Connection command #"+id, func(t *testing.T) { + go func(id string, mux *Multiplexer, conn *Connection) { + mux.smux.Lock() + sid := string(mux.nsid) + mux.smux.Unlock() + + cmd := NewCommand(SOCKET_CONNECT+sid+"\n0.0.0.0\n\nwebsocket", conn) + cmd.Process() + if len(CmdQueue) != 0 { + t.Error("Sockets: master failed to pass command struct from worker to connection") + } + }(id, mux, conn) + }) + } + + for len(m.wpool) > 0 { + <-m.wpool + } +} diff --git a/sockets/lib/multiplexer.go b/sockets/lib/multiplexer.go new file mode 100644 index 0000000000000..a7e514a3cabc0 --- /dev/null +++ b/sockets/lib/multiplexer.go @@ -0,0 +1,308 @@ +package sockets + +import ( + "fmt" + "net" + "path" + "regexp" + "strconv" + "sync" + + "github.com/igm/sockjs-go/sockjs" +) + +const DEFAULT_SUBCHANNEL_ID string = "0" +const P1_SUBCHANNEL_ID string = "1" +const P2_SUBCHANNEL_ID string = "2" + +type Multiplexer struct { + nsid uint64 + sockets map[string]sockjs.Session + smux sync.Mutex + channels map[string]map[string]string + cmux sync.Mutex + scre *regexp.Regexp + conn *Connection +} + +func NewMultiplexer() *Multiplexer { + sockets := make(map[string]sockjs.Session) + channels := make(map[string]map[string]string) + scre := regexp.MustCompile(`\n/split(\n[^\n]*)(\n[^\n]*)(\n[^\n]*)\n[^\n]*`) + return &Multiplexer{ + sockets: sockets, + channels: channels, + scre: scre} +} + +func (m *Multiplexer) Listen(conn *Connection) { + m.conn = conn +} + +func (m *Multiplexer) Process(cmd Command) (err error) { + // fmt.Printf("IPC => Sockets: %v\n", cmd.Message()) + params := cmd.Params() + + switch token := cmd.Token(); token { + case SOCKET_DISCONNECT: + sid := params[0] + err = m.socketRemove(sid, true) + case SOCKET_SEND: + sid := params[0] + msg := params[1] + err = m.socketSend(sid, msg) + case SOCKET_RECEIVE: + sid := params[0] + msg := params[1] + err = m.socketReceive(sid, msg) + case CHANNEL_ADD: + cid := params[0] + sid := params[1] + err = m.channelAdd(cid, sid) + case CHANNEL_REMOVE: + cid := params[0] + sid := params[1] + err = m.channelRemove(cid, sid) + case CHANNEL_BROADCAST: + cid := params[0] + msg := params[1] + err = m.channelBroadcast(cid, msg) + case SUBCHANNEL_MOVE: + cid := params[0] + scid := params[1] + sid := params[2] + err = m.subchannelMove(cid, scid, sid) + case SUBCHANNEL_BROADCAST: + cid := params[0] + msg := params[1] + err = m.subchannelBroadcast(cid, msg) + } + + if err != nil { + // Something went wrong somewhere, but it's likely a timing issue from + // the parent process. Let's just log the error instead of crashing. + fmt.Printf("%v\n", err) + } + + return +} + +func (m *Multiplexer) socketAdd(s sockjs.Session) (sid string) { + m.smux.Lock() + defer m.smux.Unlock() + + sid = strconv.FormatUint(m.nsid, 10) + m.nsid++ + m.sockets[sid] = s + + if m.conn.Listening() { + req := s.Request() + ip, _, _ := net.SplitHostPort(req.RemoteAddr) + ips := req.Header.Get("X-Forwarded-For") + protocol := path.Base(req.URL.Path) + + cmd := NewCommand(SOCKET_CONNECT+sid+"\n"+ip+"\n"+ips+"\n"+protocol, m.conn) + CmdQueue <- cmd + } + + return +} + +func (m *Multiplexer) socketRemove(sid string, forced bool) error { + m.smux.Lock() + defer m.smux.Unlock() + + m.cmux.Lock() + for cid, c := range m.channels { + if _, ok := c[sid]; ok { + delete(c, sid) + if len(c) == 0 { + delete((*m).channels, cid) + } + } + } + m.cmux.Unlock() + + s, ok := m.sockets[sid] + if ok { + delete((*m).sockets, sid) + } else { + return fmt.Errorf("Sockets: attempted to remove socket of ID %v that doesn't exist", sid) + } + + if forced { + s.Close(2010, "Normal closure") + } else { + // User disconnected on their own. Poke the parent process to clean up. + if m.conn.Listening() { + cmd := NewCommand(SOCKET_DISCONNECT+sid, m.conn) + CmdQueue <- cmd + } + } + + return nil +} + +func (m *Multiplexer) socketReceive(sid string, msg string) error { + m.smux.Lock() + defer m.smux.Unlock() + + if _, ok := m.sockets[sid]; ok { + if m.conn.Listening() { + cmd := NewCommand(SOCKET_RECEIVE+sid+"\n"+msg, m.conn) + CmdQueue <- cmd + } + return nil + } + + return fmt.Errorf("Sockets: received a message for a socket of ID %v that does not exist: %v", sid, msg) +} + +func (m *Multiplexer) socketSend(sid string, msg string) error { + m.smux.Lock() + defer m.smux.Unlock() + + if s, ok := m.sockets[sid]; ok && m.conn.Listening() { + s.Send(msg) + return nil + } + + return fmt.Errorf("Sockets: attempted to send to socket of ID %v, which does not exist", sid) +} + +func (m *Multiplexer) channelAdd(cid string, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + c = make(map[string]string) + m.channels[cid] = c + } + + c[sid] = DEFAULT_SUBCHANNEL_ID + + return nil +} + +func (m *Multiplexer) channelRemove(cid string, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if ok { + if _, ok = c[sid]; !ok { + return fmt.Errorf("Sockets: failed to remove nonexistent socket of ID %v from channel %v", sid, cid) + } + } else { + // This occasionally happens on user disconnect. + return nil + } + + delete(c, sid) + if len(c) == 0 { + delete((*m).channels, cid) + } + + return nil +} + +func (m *Multiplexer) channelBroadcast(cid string, msg string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + // This happens occasionally when the last user leaves a room. Mitigate + return nil + } + + m.smux.Lock() + defer m.smux.Unlock() + + for sid, _ := range c { + var s sockjs.Session + if s, ok = m.sockets[sid]; ok { + if m.conn.Listening() { + s.Send(msg) + } + } else { + delete(c, sid) + } + } + + return nil +} + +func (m *Multiplexer) subchannelMove(cid string, scid string, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + return fmt.Errorf("Sockets: attempted to move socket of ID %v in channel %v, which does not exist, to subchannel %v", sid, cid, scid) + } + + c[sid] = scid + + return nil +} + +func (m *Multiplexer) subchannelBroadcast(cid string, msg string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + return fmt.Errorf("Sockets: attempted to broadcast to subchannels in channel %v, which doesn't exist: %v", cid, msg) + } + + m.smux.Lock() + defer m.smux.Unlock() + + match := m.scre.FindAllStringSubmatch(msg, len(msg)) + for sid, scid := range c { + s, ok := m.sockets[sid] + if !ok { + return fmt.Errorf("Sockets: attempted to broadcast to subchannels in channel %v, but socket of ID %v doesn't exist: %v", cid, sid, msg) + } + + var msg string + for _, msgs := range match { + switch scid { + case DEFAULT_SUBCHANNEL_ID: + msg = msgs[1] + case P1_SUBCHANNEL_ID: + msg = msgs[2] + case P2_SUBCHANNEL_ID: + msg = msgs[3] + } + } + + if m.conn.Listening() { + s.Send(msg) + } + } + + return nil +} + +func (m *Multiplexer) Handler(s sockjs.Session) { + sid := m.socketAdd(s) + for { + if msg, err := s.Recv(); err == nil { + if err = m.socketReceive(sid, msg); err != nil { + // Likely a SockJS glitch if this happens at all. + fmt.Printf("%v\n", err) + break + } + continue + } + break + } + + if err := m.socketRemove(sid, false); err != nil { + // Socket was already removed by a message from the parent process. + fmt.Printf("%v\n", err) + } +} diff --git a/sockets/lib/multiplexer_test.go b/sockets/lib/multiplexer_test.go new file mode 100644 index 0000000000000..93ad3becb74a5 --- /dev/null +++ b/sockets/lib/multiplexer_test.go @@ -0,0 +1,111 @@ +package sockets + +import ( + "net" + "testing" +) + +func TestMultiplexer(t *testing.T) { + port := ":3000" + ln, _ := net.Listen("tcp", "localhost"+port) + defer ln.Close() + conn, _ := NewConnection("PS_IPC_PORT") + defer conn.Close() + mux := NewMultiplexer() + mux.Listen(conn) + + t.Run("*Multiplexer.socketAdd", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + if len(mux.sockets) != 1 { + t.Errorf("Sockets: adding sockets to multiplexer doesn't keep them instead") + } + mux.socketRemove(sid, true) + }) + t.Run("*Multiplexer.socketRemove", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + if err := mux.socketRemove(sid, true); err != nil { + t.Errorf("%v", err) + } + if len(mux.sockets) != 0 { + t.Errorf("Sockets: forcibly removing sockets from multiplexer keeps them instead") + } + sid = mux.socketAdd(testSocket{}) + if err := mux.socketRemove(sid, false); err != nil { + t.Errorf("%v", err) + } + if len(mux.sockets) != 0 { + t.Fatalf("Sockets: sockets removing themselves from multiplexer keeps them instead") + } + }) + t.Run("*Multiplexer.channelAdd", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + if err := mux.channelAdd("global", sid); err != nil { + t.Errorf("%v", err) + } + if len(mux.channels) != 1 { + t.Errorf("Sockets: adding channels to multiplexer doesn't keep them instead") + } + if err := mux.channelAdd("global", sid); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove("global", sid) + mux.socketRemove(sid, true) + }) + t.Run("*Multiplexer.channelRemove", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + mux.channelAdd("global", sid) + if err := mux.channelRemove("global", sid); err != nil { + t.Errorf("%v", err) + } + if len(mux.channels) != 0 { + t.Errorf("Sockets: removing channels from multiplexer keeps them instead") + } + if err := mux.channelRemove("global", sid); err != nil { + t.Errorf("%v", err) + } + mux.socketRemove(sid, true) + }) + t.Run("*Multiplexer.channelBroadcast", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + mux.channelAdd("global", sid) + if err := mux.channelBroadcast("global", "|raw|ayy lmao"); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove("global", sid) + if err := mux.channelBroadcast("global", "|raw|ayy lmao"); err != nil { + t.Errorf("%v", err) + } + mux.socketRemove(sid, true) + }) + t.Run("*Multiplexer.subchannelMove", func(t *testing.T) { + sid := mux.socketAdd(testSocket{}) + mux.channelAdd("global", sid) + if err := mux.subchannelMove("global", P1_SUBCHANNEL_ID, sid); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove("global", sid) + mux.socketRemove(sid, true) + }) + t.Run("*Multiplexer.subchannelBroadcast", func(t *testing.T) { + msg := "\n/split\n1\n2\n3\nrest" + matches := mux.scre.FindAllStringSubmatch(msg, len(msg)) + msgs := matches[0] + if msgs[1] != "\n1" { + t.Errorf("Sockets: expected broadcast to subchannel '0' to be %v, but was actually %v", "\n1", msgs[1]) + } + if msgs[2] != "\n2" { + t.Errorf("Sockets: expected broadcast to subchannel '1' to be %v, but was actually %v", "\n2", msgs[2]) + } + if msgs[3] != "\n3" { + t.Errorf("Sockets: expected broadcast to subchannel '2' to be %v, but was actually %v", "\n3", msgs[3]) + } + + sid := mux.socketAdd(testSocket{}) + mux.channelAdd("global", sid) + if err := mux.subchannelBroadcast("global", msg); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove("global", sid) + mux.socketRemove(sid, true) + }) +} diff --git a/sockets/main.go b/sockets/main.go new file mode 100644 index 0000000000000..32a17166ed921 --- /dev/null +++ b/sockets/main.go @@ -0,0 +1,118 @@ +package main + +import ( + "crypto/tls" + "fmt" + "log" + "net" + "net/http" + "path/filepath" + + "github.com/Zarel/Pokemon-Showdown/sockets/lib" + + routing "github.com/gorilla/mux" + "github.com/igm/sockjs-go/sockjs" +) + +func notFoundHandler(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/404.html", http.StatusSeeOther) +} + +func main() { + // Parse our config settings passed through the $PS_CONFIG environment + // variable by the parent process. + config, err := sockets.NewConfig("PS_CONFIG") + if err != nil { + log.Fatal("Sockets: failed to read parent's config settings from environment") + } + + // Instantiate the socket multiplexer and IPC struct.. + mux := sockets.NewMultiplexer() + conn, err := sockets.NewConnection("PS_IPC_PORT") + if err != nil { + log.Fatal(err) + } + defer conn.Close() + + // Begin listening for incoming messages from sockets and the TCP + // connection to the parent process. For now, they'll just get enqueued + // for workers to manage later.. + mux.Listen(conn) + err = conn.Listen(mux) + if err != nil { + log.Fatal("%v", err) + } + + // Set up server routing. + r := routing.NewRouter() + + avatarDir, _ := filepath.Abs("./config/avatars") + r.PathPrefix("/avatars/"). + Handler(http.FileServer(http.Dir(avatarDir))) + + customCSSDir, _ := filepath.Abs("./config") + r.Handle("/custom.css", http.FileServer(http.Dir(customCSSDir))) + + // Set up the SockJS server. + opts := sockjs.Options{ + SockJSURL: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", + Websocket: true, + HeartbeatDelay: sockjs.DefaultOptions.HeartbeatDelay, + DisconnectDelay: sockjs.DefaultOptions.DisconnectDelay, + JSessionID: sockjs.DefaultOptions.JSessionID} + + r.PathPrefix("/showdown"). + Handler(sockjs.NewHandler("/showdown", opts, mux.Handler)) + + staticDir, _ := filepath.Abs("/static") + r.Handle("/", http.StripPrefix("/static", http.FileServer(http.Dir(staticDir)))) + + r.NotFoundHandler = http.HandlerFunc(notFoundHandler) + + // Begin serving over HTTPS if configured to do so. + if config.SSL.Options.Cert != "" && config.SSL.Options.Key != "" { + go func(ba string, port string, cert string, key string) { + certs, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + log.Fatalf("Sockets: failed to load certificate and key files for TLS: %v", err) + } + + srv := &http.Server{ + Handler: r, + Addr: ba + port, + TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}} + + var ln net.Listener + ln, err = tls.Listen("tcp4", srv.Addr, srv.TLSConfig) + defer ln.Close() + if err != nil { + log.Fatalf("Sockets: failed to listen on %v over HTTPS", srv.Addr) + } + + fmt.Printf("Sockets: now serving on https://%v%v/\n", ba, port) + log.Fatal(http.Serve(ln, r)) + }(config.BindAddress, config.SSL.Port, config.SSL.Options.Cert, config.SSL.Options.Key) + } + + // Begin serving over HTTP. + go func(ba string, port string) { + srv := &http.Server{ + Handler: r, + Addr: ba + port} + + ln, err := net.Listen("tcp4", srv.Addr) + defer ln.Close() + if err != nil { + log.Fatalf("Sockets: failed to listen on %v over HTTP", srv.Addr) + } + + fmt.Printf("Sockets: now serving on http://%v%v/\n", ba, port) + log.Fatal(http.Serve(ln, r)) + }(config.BindAddress, config.Port) + + // Finally, spawn workers.to pipe messages received at the multiplexer or + // IPC connection to each other concurrently. + master := sockets.NewMaster(config.Workers) + master.Spawn() + master.Listen() +} diff --git a/test/application/sockets.js b/test/application/sockets.js index 7fc963ebe9bf2..37d9284c6092a 100644 --- a/test/application/sockets.js +++ b/test/application/sockets.js @@ -1,212 +1,79 @@ 'use strict'; const assert = require('assert'); -const cluster = require('cluster'); -describe.skip('Sockets', function () { - const spawnWorker = () => ( - new Promise(resolve => { - let worker = Sockets.spawnWorker(); - worker.removeAllListeners('message'); - resolve(worker); - }) - ); +const {createSocket} = require('../../dev-tools/sockets'); +describe('Sockets workers', function () { before(function () { - cluster.settings.silent = true; - cluster.removeAllListeners('disconnect'); + this.mux = new (require('../../sockets-workers')).Multiplexer(); + clearInterval(this.mux.cleanupInterval); + this.mux.cleanupInterval = null; + this.mux.sendUpstream = () => {}; }); - afterEach(function () { - Sockets.workers.forEach((worker, workerid) => { - worker.kill(); - Sockets.workers.delete(workerid); - }); + beforeEach(function () { + this.socket = createSocket(); }); - describe('master', function () { - it('should be able to spawn workers', function () { - Sockets.spawnWorker(); - assert.strictEqual(Sockets.workers.size, 1); - }); - - it('should be able to spawn workers on listen', function () { - Sockets.listen(0, '127.0.0.1', 1); - assert.strictEqual(Sockets.workers.size, 1); - }); - - it('should be able to kill workers', function () { - return spawnWorker().then(worker => { - Sockets.killWorker(worker); - assert.strictEqual(Sockets.workers.size, 0); - }); - }); - - it('should be able to kill workers by PID', function () { - return spawnWorker().then(worker => { - Sockets.killPid(worker.process.pid); - assert.strictEqual(Sockets.workers.size, 0); - }); - }); + afterEach(function () { + this.mux.tryDestroySocket(this.socket); + this.mux.channels.clear(); }); - describe('workers', function () { - // This composes a sequence of HOFs that send a message to a worker, - // wait for its response, then return the worker for the next function - // to use. - const chain = (eventHandler, msg) => worker => { - worker.once('message', eventHandler(worker)); - msg = msg || `$ - const {Session} = require('sockjs/lib/transport'); - const socket = new Session('aaaaaaaa', server); - socket.remoteAddress = '127.0.0.1'; - if (!('headers' in socket)) socket.headers = {}; - socket.headers['x-forwarded-for'] = ''; - socket.protocol = 'websocket'; - socket.write = msg => process.send(msg); - server.emit('connection', socket);`; - worker.send(msg); - return worker; - }; - - const spawnSocket = eventHandler => spawnWorker().then(chain(eventHandler)); - - it('should allow sockets to connect', function () { - return spawnSocket(worker => data => { - let cmd = data.charAt(0); - let [sid, ip, protocol] = data.substr(1).split('\n'); - assert.strictEqual(cmd, '*'); - assert.strictEqual(sid, '1'); - assert.strictEqual(ip, '127.0.0.1'); - assert.strictEqual(protocol, 'websocket'); - }); - }); - - it('should allow sockets to disconnect', function () { - let querySocket; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - querySocket = `$ - let socket = sockets.get(${sid}); - process.send(!socket);`; - Sockets.socketDisconnect(worker, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySocket)); - }); - - it('should allow sockets to send messages', function () { - let msg = 'ayy lmao'; - let socketSend; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - socketSend = `>${sid}\n${msg}`; - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, socketSend)); - }); + after(function () { + this.socket = null; + this.mux.sockets.clear(); + this.mux = null; + }); - it('should allow sockets to receive messages', function () { - let sid; - let msg; - let mockReceive; - return spawnSocket(worker => data => { - sid = data.substr(1, data.indexOf('\n')); - msg = '|/cmd rooms'; - mockReceive = `$ - let socket = sockets.get(${sid}); - socket.emit('data', ${msg});`; - }).then(chain(worker => data => { - let cmd = data.charAt(0); - let params = data.substr(1).split('\n'); - assert.strictEqual(cmd, '<'); - assert.strictEqual(sid, params[0]); - assert.strictEqual(msg, params[1]); - }, mockReceive)); - }); + it('should parse more than two params', function () { + let params = '1\n1\n0\n'; + let ret = this.mux.parseParams(params, 4); + assert.deepStrictEqual(ret, ['1', '1', '0', '']); + }); - it('should create a channel for the first socket to get added to it', function () { - let queryChannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'global'; - queryChannel = `$ - let channel = channels.get(${cid}); - process.send(channel && channel.has(${sid}));`; - Sockets.channelAdd(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, queryChannel)); - }); + it('should parse params with multiple newlines', function () { + let params = '0\n|1\n|2'; + let ret = this.mux.parseParams(params, 2); + assert.deepStrictEqual(ret, ['0', '|1\n|2']); + }); - it('should remove a channel if the last socket gets removed from it', function () { - let queryChannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'global'; - queryChannel = `$ - process.send(!sockets.has(${sid}) && !channels.has(${cid}));`; - Sockets.channelAdd(worker, cid, sid); - Sockets.channelRemove(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, queryChannel)); - }); + it('should add sockets on connect', function () { + let res = this.mux.onSocketConnect(this.socket); + assert.ok(res); + }); - it('should send to all sockets in a channel', function () { - let msg = 'ayy lmao'; - let cid = 'global'; - let channelSend = `#${cid}\n${msg}`; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - Sockets.channelAdd(worker, cid, sid); - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, channelSend)); - }); + it('should remove sockets on disconnect', function () { + this.mux.onSocketConnect(this.socket); + let res = this.mux.onSocketDisconnect('0', this.socket); + assert.ok(res); + }); - it('should create a subchannel when moving a socket to it', function () { - let querySubchannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'battle-ou-1'; - let scid = '1'; - querySubchannel = `$ - let subchannel = subchannels[${cid}]; - process.send(!!subchannel && (subchannel.get(${sid}) === ${scid}));`; - Sockets.subchannelMove(worker, cid, scid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySubchannel)); - }); + it('should add sockets to channels', function () { + this.mux.onSocketConnect(this.socket); + let res = this.mux.onChannelAdd('global', '0'); + assert.ok(res); + res = this.mux.onChannelAdd('global', '0'); + assert.ok(!res); + this.mux.channels.set('lobby', new Map()); + res = this.mux.onChannelAdd('lobby', '0'); + assert.ok(res); + }); - it('should remove a subchannel when removing its last socket', function () { - let querySubchannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'battle-ou-1'; - let scid = '1'; - querySubchannel = `$ - let subchannel = subchannels.get(${cid}); - process.send(!!subchannel && (subchannel.get(${sid}) === ${scid}));`; - Sockets.subchannelMove(worker, cid, scid, sid); - Sockets.channelRemove(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySubchannel)); - }); + it('should remove sockets from channels', function () { + this.mux.onSocketConnect(this.socket); + this.mux.onChannelAdd('global', '0'); + let res = this.mux.onChannelRemove('global', '0'); + assert.ok(res); + res = this.mux.onChannelRemove('global', '0'); + assert.ok(!res); + }); - it('should send to sockets in a subchannel', function () { - let cid = 'battle-ou-1'; - let msg = 'ayy lmao'; - let subchannelSend = `.${cid}\n\n|split\n\n${msg}\n\n`; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let scid = '1'; - Sockets.subchannelMove(worker, cid, scid, sid); - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, subchannelSend)); - }); + it('should move sockets to subchannels', function () { + this.mux.onSocketConnect(this.socket); + this.mux.onChannelAdd('global', '0'); + let res = this.mux.onSubchannelMove('global', '1', '0'); + assert.ok(res); }); }); diff --git a/tsconfig.json b/tsconfig.json index 7e55bd7bbb814..9f97de79d15ec 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -7,11 +7,13 @@ "allowJs": true, "checkJs": true }, - "types": ["node"], + "types": ["node", "sockjs-node"], "include": [ "./dev-tools/globals.ts", "./sim/dex-data.js", "./sim/dex.js", - "./sim/prng.js" + "./sim/prng.js", + "./sockets.js", + "./sockets-workers.js" ] } diff --git a/users.js b/users.js index 50252b5c1869e..2b8e054ec7c37 100644 --- a/users.js +++ b/users.js @@ -1610,3 +1610,20 @@ Users.socketReceive = function (worker, workerid, socketid, message) { Monitor.warn(`[slow] ${deltaTime}ms - ${user.name} <${connection.ip}>: ${roomId}|${message}`); } }; + +/** + * @description Clears all connections whose sockets were contained by a + * worker. Called after a worker's process crashes or gets killed. + * @param {object} worker + * @returns {number} + */ +Users.socketDisconnectAll = function (worker) { + let count = 0; + connections.forEach(connection => { + if (connection.worker === worker) { + Users.socketDisconnect(worker, worker.id, connection.socketid); + count++; + } + }); + return count; +};