From 7bbc471d485fa6227fe459c0ff6f91e27a56a752 Mon Sep 17 00:00:00 2001 From: Guangcong Luo Date: Thu, 30 Nov 2023 17:01:08 -0600 Subject: [PATCH] Support Postgres in database.ts (#20) --- src/actions.ts | 17 ++-- src/database.ts | 208 +++++++++++++++++++++++++++++------------------- src/tables.ts | 19 +++-- 3 files changed, 148 insertions(+), 96 deletions(-) diff --git a/src/actions.ts b/src/actions.ts index c52e9a1..80b8911 100644 --- a/src/actions.ts +++ b/src/actions.ts @@ -11,6 +11,7 @@ import {Replays} from './replays'; import {ActionError, QueryHandler, Server} from './server'; import {toID, updateserver, bash, time, escapeHTML} from './utils'; import * as tables from './tables'; +import {SQL} from './database'; import * as pathModule from 'path'; import IPTools from './ip-tools'; import * as crypto from 'crypto'; @@ -662,9 +663,9 @@ export const actions: {[k: string]: QueryHandler} = { } let teams = []; try { - teams = await tables.pgdb.query( - 'SELECT teamid, team, format, title as name FROM teams WHERE ownerid = $1', [this.user.id] - ) ?? []; + teams = await tables.teams.selectAll( + SQL`teamid, team, format, title as name` + )`WHERE ownerid = ${this.user.id}`; } catch (e) { Server.crashlog(e, 'a teams database query', params); throw new ActionError('The server could not load your teams. Please try again later.'); @@ -693,13 +694,13 @@ export const actions: {[k: string]: QueryHandler} = { throw new ActionError("Invalid team ID"); } try { - const data = await tables.pgdb.query( - `SELECT ownerid, team, private as privacy FROM teams WHERE teamid = $1`, [teamid] - ); - if (!data || !data.length || data[0].ownerid !== this.user.id) { + const data = await tables.teams.selectOne( + SQL`ownerid, team, private as privacy` + )`WHERE teamid = ${teamid}`; + if (!data || data.ownerid !== this.user.id) { return {team: null}; } - return data[0]; + return data; } catch (e) { Server.crashlog(e, 'a teams database request', params); throw new ActionError("Failed to fetch team. Please try again later."); diff --git a/src/database.ts b/src/database.ts index fb778b7..7b4fc53 100644 --- a/src/database.ts +++ b/src/database.ts @@ -37,10 +37,11 @@ export class SQLStatement { } else if (value === undefined) { this.sql[this.sql.length - 1] += nextString; } else if (Array.isArray(value)) { - if (this.sql[this.sql.length - 1].endsWith(`\``)) { + if ('"`'.includes(this.sql[this.sql.length - 1].slice(-1))) { // "`a`, `b`" syntax + const quoteChar = this.sql[this.sql.length - 1].slice(-1); for (const col of value) { - this.append(col, `\`, \``); + this.append(col, `${quoteChar}, ${quoteChar}`); } this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + nextString; } else { @@ -52,21 +53,21 @@ export class SQLStatement { } } else if (this.sql[this.sql.length - 1].endsWith('(')) { // "(`a`, `b`) VALUES (1, 2)" syntax - this.sql[this.sql.length - 1] += `\``; + this.sql[this.sql.length - 1] += `"`; for (const col in value) { - this.append(col, `\`, \``); + this.append(col, `", "`); } - this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + `\`) VALUES (`; + this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -4) + `") VALUES (`; for (const col in value) { this.append(value[col], `, `); } this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -2) + nextString; } else if (this.sql[this.sql.length - 1].toUpperCase().endsWith(' SET ')) { // "`a` = 1, `b` = 2" syntax - this.sql[this.sql.length - 1] += `\``; + this.sql[this.sql.length - 1] += `"`; for (const col in value) { - this.append(col, `\` = `); - this.append(value[col], `, \``); + this.append(col, `" = `); + this.append(value[col], `, "`); } this.sql[this.sql.length - 1] = this.sql[this.sql.length - 1].slice(0, -3) + nextString; } else { @@ -83,27 +84,29 @@ export class SQLStatement { * Tag function for SQL, with some magic. * * * `` SQL`UPDATE table SET a = ${'hello"'}` `` - * * `` 'UPDATE table SET a = "hello"' `` + * * `` `UPDATE table SET a = 'hello'` `` * - * Values surrounded by `` \` `` become names: + * Values surrounded by `"` or `` ` `` become identifiers: * - * * ``` SQL`SELECT * FROM \`${'table'}\`` ``` - * * `` 'SELECT * FROM `table`' `` + * * ``` SQL`SELECT * FROM "${'table'}"` ``` + * * `` `SELECT * FROM "table"` `` + * + * (Make sure to use `"` for Postgres and `` ` `` for MySQL.) * * Objects preceded by SET become setters: * * * `` SQL`UPDATE table SET ${{a: 1, b: 2}}` `` - * * `` 'UPDATE table SET `a` = 1, `b` = 2' `` + * * `` `UPDATE table SET "a" = 1, "b" = 2` `` * * Objects surrounded by `()` become keys and values: * * * `` SQL`INSERT INTO table (${{a: 1, b: 2}})` `` - * * `` 'INSERT INTO table (`a`, `b`) VALUES (1, 2)' `` + * * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` `` * - * Arrays become lists; surrounding by `` \` `` turns them into lists of names: + * Arrays become lists; surrounding by `"` or `` ` `` turns them into lists of names: * - * * `` SQL`INSERT INTO table (\`${['a', 'b']}\`) VALUES (${[1, 2]})` `` - * * `` 'INSERT INTO table (`a`, `b`) VALUES (1, 2)' `` + * * `` SQL`INSERT INTO table ("${['a', 'b']}") VALUES (${[1, 2]})` `` + * * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` `` */ export function SQL(strings: TemplateStringsArray, ...values: SQLValue[]) { return new SQLStatement(strings, values); @@ -113,53 +116,24 @@ export interface ResultRow {[k: string]: BasicSQLValue} export const connectedDatabases: Database[] = []; -export class Database { - connection: mysql.Pool; +export abstract class Database { + connection: Pool; prefix: string; - constructor(config: mysql.PoolOptions & {prefix?: string}) { - this.prefix = config.prefix || ""; - if (config.prefix) { - config = {...config}; - delete config.prefix; - } - this.connection = mysql.createPool(config); + constructor(connection: Pool, prefix = '') { + this.prefix = prefix; + this.connection = connection; connectedDatabases.push(this); } - resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] { - let sql = query.sql[0]; - const values = []; - for (let i = 0; i < query.values.length; i++) { - const value = query.values[i]; - if (query.sql[i + 1].startsWith('`')) { - sql = sql.slice(0, -1) + this.connection.escapeId('' + value) + query.sql[i + 1].slice(1); - } else { - sql += '?' + query.sql[i + 1]; - values.push(value); - } - } - return [sql, values]; - } + abstract _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]]; + abstract _query(sql: string, values: BasicSQLValue[]): Promise; + abstract escapeId(param: string): string; query(sql: SQLStatement): Promise; query(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise; query(sql?: SQLStatement) { if (!sql) return (strings: any, ...rest: any) => this.query(new SQLStatement(strings, rest)); - return new Promise((resolve, reject) => { - const [query, values] = this.resolveSQL(sql); - this.connection.query(query, values, (e, results: any) => { - if (e) { - return reject(new Error(`${e.message} (${query}) (${values}) [${e.code}]`)); - } - if (Array.isArray(results)) { - for (const row of results) { - for (const col in row) { - if (Buffer.isBuffer(row[col])) row[col] = row[col].toString(); - } - } - } - return resolve(results); - }); - }); + const [query, values] = this._resolveSQL(sql); + return this._query(query, values); } queryOne(sql: SQLStatement): Promise; queryOne(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise; @@ -168,14 +142,14 @@ export class Database { return this.query(sql).then(res => Array.isArray(res) ? res[0] : res); } - queryExec(sql: SQLStatement): Promise; - queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise; + queryExec(sql: SQLStatement): Promise; + queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise; queryExec(sql?: SQLStatement) { if (!sql) return (strings: any, ...rest: any) => this.queryExec(new SQLStatement(strings, rest)); - return this.queryOne(sql); + return this.queryOne(sql); } close() { - this.connection.end(); + void this.connection.end(); } } @@ -198,7 +172,7 @@ export class DatabaseTable { this.primaryKeyName = primaryKeyName; } escapeId(param: string) { - return this.db.connection.escapeId(param); + return this.db.escapeId(param); } // raw @@ -224,45 +198,52 @@ export class DatabaseTable { selectAll(entries?: (keyof Row & string)[] | SQLStatement): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise { if (!entries) entries = SQL`*`; - if (Array.isArray(entries)) entries = SQL`\`${entries}\``; + if (Array.isArray(entries)) entries = SQL`"${entries}"`; return (strings, ...rest) => - this.query()`SELECT ${entries} FROM \`${this.name}\` ${new SQLStatement(strings, rest)}`; + this.query()`SELECT ${entries} FROM "${this.name}" ${new SQLStatement(strings, rest)}`; } selectOne(entries?: (keyof Row & string)[] | SQLStatement): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise { if (!entries) entries = SQL`*`; - if (Array.isArray(entries)) entries = SQL`\`${entries}\``; + if (Array.isArray(entries)) entries = SQL`"${entries}"`; return (strings, ...rest) => - this.queryOne()`SELECT ${entries} FROM \`${this.name}\` ${new SQLStatement(strings, rest)} LIMIT 1`; + this.queryOne()`SELECT ${entries} FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`; } updateAll(partialRow: PartialOrSQL): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise { return (strings, ...rest) => - this.queryExec()`UPDATE \`${this.name}\` SET ${partialRow as any} ${new SQLStatement(strings, rest)}`; + this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(strings, rest)}`; } updateOne(partialRow: PartialOrSQL): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise { return (s, ...r) => - this.queryExec()`UPDATE \`${this.name}\` SET ${partialRow as any} ${new SQLStatement(s, r)} LIMIT 1`; + this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(s, r)} LIMIT 1`; } deleteAll(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise { return (strings, ...rest) => - this.queryExec()`DELETE FROM \`${this.name}\` ${new SQLStatement(strings, rest)}`; + this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)}`; } deleteOne(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise { return (strings, ...rest) => - this.queryExec()`DELETE FROM \`${this.name}\` ${new SQLStatement(strings, rest)} LIMIT 1`; + this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`; + } + eval(): + (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise { + return (strings, ...rest) => + this.queryOne<{result: T}>( + )`SELECT ${new SQLStatement(strings, rest)} AS result FROM "${this.name}" LIMIT 1` + .then(row => row?.result); } // high-level insert(partialRow: PartialOrSQL, where?: SQLStatement) { - return this.queryExec()`INSERT INTO \`${this.name}\` (${partialRow as SQLValue}) ${where}`; + return this.queryExec()`INSERT INTO "${this.name}" (${partialRow as SQLValue}) ${where}`; } insertIgnore(partialRow: PartialOrSQL, where?: SQLStatement) { - return this.queryExec()`INSERT IGNORE INTO \`${this.name}\` (${partialRow as SQLValue}) ${where}`; + return this.queryExec()`INSERT IGNORE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`; } async tryInsert(partialRow: PartialOrSQL, where?: SQLStatement) { try { @@ -279,28 +260,89 @@ export class DatabaseTable { return this.replace(partialRow, where); } replace(partialRow: PartialOrSQL, where?: SQLStatement) { - return this.queryExec()`REPLACE INTO \`${this.name}\` (${partialRow as SQLValue}) ${where}`; + return this.queryExec()`REPLACE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`; } get(primaryKey: BasicSQLValue, entries?: (keyof Row & string)[] | SQLStatement) { - return this.selectOne(entries)`WHERE \`${this.primaryKeyName}\` = ${primaryKey}`; + return this.selectOne(entries)`WHERE "${this.primaryKeyName}" = ${primaryKey}`; } delete(primaryKey: BasicSQLValue) { - return this.deleteAll()`WHERE \`${this.primaryKeyName}\` = ${primaryKey} LIMIT 1`; + return this.deleteAll()`WHERE "${this.primaryKeyName}" = ${primaryKey} LIMIT 1`; } update(primaryKey: BasicSQLValue, data: PartialOrSQL) { - return this.updateAll(data)`WHERE \`${this.primaryKeyName}\` = ${primaryKey} LIMIT 1`; + return this.updateAll(data)`WHERE "${this.primaryKeyName}" = ${primaryKey} LIMIT 1`; } } -export class PGDatabase { - database: pg.Pool | null; - constructor(config: pg.PoolConfig | null) { - this.database = config ? new pg.Pool(config) : null; +export class MySQLDatabase extends Database { + constructor(config: mysql.PoolOptions & {prefix?: string}) { + const prefix = config.prefix || ""; + if (config.prefix) { + config = {...config}; + delete config.prefix; + } + super(mysql.createPool(config), prefix); } - async query(query: string, values: BasicSQLValue[]) { - if (!this.database) return null; - const result = await this.database.query(query, values); - return result.rows as O[]; + override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] { + let sql = query.sql[0]; + const values = []; + for (let i = 0; i < query.values.length; i++) { + const value = query.values[i]; + if (query.sql[i + 1].startsWith('`') || query.sql[i + 1].startsWith('"')) { + sql = sql.slice(0, -1) + this.escapeId('' + value) + query.sql[i + 1].slice(1); + } else { + sql += '?' + query.sql[i + 1]; + values.push(value); + } + } + return [sql, values]; + } + override _query(query: string, values: BasicSQLValue[]): Promise { + return new Promise((resolve, reject) => { + this.connection.query(query, values, (e, results: any) => { + if (e) { + return reject(new Error(`${e.message} (${query}) (${values}) [${e.code}]`)); + } + if (Array.isArray(results)) { + for (const row of results) { + for (const col in row) { + if (Buffer.isBuffer(row[col])) row[col] = row[col].toString(); + } + } + } + return resolve(results); + }); + }); + } + override escapeId(id: string) { + return this.connection.escapeId(id); } } +export class PGDatabase extends Database { + constructor(config: pg.PoolConfig) { + super(new pg.Pool(config)); + } + override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] { + let sql = query.sql[0]; + const values = []; + let paramCount = 0; + for (let i = 0; i < query.values.length; i++) { + const value = query.values[i]; + if (query.sql[i + 1].startsWith('`') || query.sql[i + 1].startsWith('"')) { + sql = sql.slice(0, -1) + this.escapeId('' + value) + query.sql[i + 1].slice(1); + } else { + paramCount++; + sql += `$${paramCount}` + query.sql[i + 1]; + values.push(value); + } + } + return [sql, values]; + } + override _query(query: string, values: BasicSQLValue[]) { + return this.connection.query(query, values).then(res => res.rows); + } + override escapeId(id: string) { + // @ts-expect-error @types/pg really needs to be updated + return pg.escapeIdentifier(id); + } +} diff --git a/src/tables.ts b/src/tables.ts index 5860ff2..f65caa6 100644 --- a/src/tables.ts +++ b/src/tables.ts @@ -1,17 +1,17 @@ /** * Login server database tables */ -import {Database, DatabaseTable, PGDatabase} from './database'; +import {DatabaseTable, MySQLDatabase, PGDatabase} from './database'; import {Config} from './config-loader'; import type {LadderEntry} from './ladder'; import type {ReplayData} from './replays'; // direct access -export const psdb = new Database(Config.mysql); -export const pgdb = new PGDatabase(Config.postgres); -export const replaysDB = Config.replaysdb ? new Database(Config.replaysdb!) : psdb; -export const ladderDB = Config.ladderdb ? new Database(Config.ladderdb!) : psdb; +export const psdb = new MySQLDatabase(Config.mysql); +export const pgdb = new PGDatabase(Config.postgres!); +export const replaysDB = Config.replaysdb ? new MySQLDatabase(Config.replaysdb!) : psdb; +export const ladderDB = Config.ladderdb ? new MySQLDatabase(Config.ladderdb!) : psdb; export const users = new DatabaseTable<{ userid: string; @@ -117,3 +117,12 @@ export const oauthTokens = new DatabaseTable<{ id: string; time: number; }>(psdb, 'oauth_tokens', 'id'); + +export const teams = new DatabaseTable<{ + teamid: string; + ownerid: string; + team: string; + format: string; + title: string; + private: number; +}>(pgdb, 'teams', 'teamid');