From dfc09f9f030f6f160670112e6024644789d02281 Mon Sep 17 00:00:00 2001 From: Marc Scholten Date: Fri, 4 Feb 2022 20:20:26 +0100 Subject: [PATCH] Added transactions to DataSync This adds a new high-level `withTransaction` function and a lower-level `Transaction` object to deal with postgres transactions from DataSync. The following operations can be used from within a transaction: - createRecord, createRecords - updateRecord, updateRecords - deleteRecord, deleteRecords - query --- IHP/DataSync/Controller.hs | 147 +++++++++++++++++++++++---- IHP/DataSync/Types.hs | 32 ++++-- lib/IHP/DataSync/ihp-datasync.js | 37 ++++--- lib/IHP/DataSync/ihp-querybuilder.js | 8 +- lib/IHP/DataSync/index.js | 8 +- lib/IHP/DataSync/transaction.js | 94 +++++++++++++++++ 6 files changed, 285 insertions(+), 41 deletions(-) create mode 100644 lib/IHP/DataSync/transaction.js diff --git a/IHP/DataSync/Controller.hs b/IHP/DataSync/Controller.hs index d40f6fd6e..b53c59117 100644 --- a/IHP/DataSync/Controller.hs +++ b/IHP/DataSync/Controller.hs @@ -25,6 +25,7 @@ import qualified IHP.PGListener as PGListener import IHP.ApplicationContext import Data.Set (Set) import qualified Data.Set as Set +import qualified Data.Pool as Pool instance ( PG.ToField (PrimaryKey (GetTableName CurrentUserRecord)) @@ -36,7 +37,7 @@ instance ( initialState = DataSyncController run = do - setState DataSyncReady { subscriptions = HashMap.empty } + setState DataSyncReady { subscriptions = HashMap.empty, transactions = HashMap.empty } ensureRLSEnabled <- makeCachedEnsureRLSEnabled installTableChangeTriggers <- ChangeNotifications.makeCachedInstallTableChangeTriggers @@ -45,12 +46,12 @@ instance ( let handleMessage :: DataSyncMessage -> IO () - handleMessage DataSyncQuery { query, requestId } = do + handleMessage DataSyncQuery { query, requestId, transactionId } = do ensureRLSEnabled (get #table query) let (theQuery, theParams) = compileQuery query - result :: [[Field]] <- sqlQueryWithRLS theQuery theParams + result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId theQuery theParams sendJSON DataSyncResult { result, requestId } @@ -131,7 +132,7 @@ instance ( sendJSON DidDeleteDataSubscription { subscriptionId, requestId } - handleMessage CreateRecordMessage { table, record, requestId } = do + handleMessage CreateRecordMessage { table, record, requestId, transactionId } = do ensureRLSEnabled table let query = "INSERT INTO ? ? VALUES ? RETURNING *" @@ -145,7 +146,7 @@ instance ( let params = (PG.Identifier table, PG.In (map PG.Identifier columns), PG.In values) - result :: [[Field]] <- sqlQueryWithRLS query params + result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params case result of [record] -> sendJSON DidCreateRecord { requestId, record } @@ -153,7 +154,7 @@ instance ( pure () - handleMessage CreateRecordsMessage { table, records, requestId } = do + handleMessage CreateRecordsMessage { table, records, requestId, transactionId } = do ensureRLSEnabled table let query = "INSERT INTO ? ? ? RETURNING *" @@ -175,13 +176,13 @@ instance ( let params = (PG.Identifier table, PG.In (map PG.Identifier columns), PG.Values [] values) - records :: [[Field]] <- sqlQueryWithRLS query params + records :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params sendJSON DidCreateRecords { requestId, records } pure () - handleMessage UpdateRecordMessage { table, id, patch, requestId } = do + handleMessage UpdateRecordMessage { table, id, patch, requestId, transactionId } = do ensureRLSEnabled table let columns = patch @@ -204,7 +205,7 @@ instance ( <> (join (map (\(key, value) -> [PG.toField key, value]) keyValues)) <> [PG.toField id] - result :: [[Field]] <- sqlQueryWithRLS (PG.Query query) params + result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params case result of [record] -> sendJSON DidUpdateRecord { requestId, record } @@ -212,7 +213,7 @@ instance ( pure () - handleMessage UpdateRecordsMessage { table, ids, patch, requestId } = do + handleMessage UpdateRecordsMessage { table, ids, patch, requestId, transactionId } = do ensureRLSEnabled table let columns = patch @@ -235,26 +236,63 @@ instance ( <> (join (map (\(key, value) -> [PG.toField key, value]) keyValues)) <> [PG.toField (PG.In ids)] - records <- sqlQueryWithRLS (PG.Query query) params + records <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params sendJSON DidUpdateRecords { requestId, records } pure () - handleMessage DeleteRecordMessage { table, id, requestId } = do + handleMessage DeleteRecordMessage { table, id, requestId, transactionId } = do ensureRLSEnabled table - sqlExecWithRLS "DELETE FROM ? WHERE id = ?" (PG.Identifier table, id) + sqlExecWithRLSAndTransactionId transactionId "DELETE FROM ? WHERE id = ?" (PG.Identifier table, id) sendJSON DidDeleteRecord { requestId } - handleMessage DeleteRecordsMessage { table, ids, requestId } = do + handleMessage DeleteRecordsMessage { table, ids, requestId, transactionId } = do ensureRLSEnabled table - sqlExecWithRLS "DELETE FROM ? WHERE id IN ?" (PG.Identifier table, PG.In ids) + sqlExecWithRLSAndTransactionId transactionId "DELETE FROM ? WHERE id IN ?" (PG.Identifier table, PG.In ids) sendJSON DidDeleteRecords { requestId } + handleMessage StartTransaction { requestId } = do + ensureBelowTransactionLimit + + transactionId <- UUID.nextRandom + + (connection, localPool) <- ?modelContext + |> get #connectionPool + |> Pool.takeResource + + let transaction = DataSyncTransaction + { id = transactionId + , connection + , releaseConnection = Pool.putResource localPool connection + } + + let globalModelContext = ?modelContext + let ?modelContext = globalModelContext { transactionConnection = Just connection } in sqlExecWithRLS "BEGIN" () + + modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.insert transactionId transaction)) + + sendJSON DidStartTransaction { requestId, transactionId } + + handleMessage RollbackTransaction { requestId, id } = do + sqlExecWithRLSAndTransactionId (Just id) "ROLLBACK" () + + closeTransaction id + + sendJSON DidRollbackTransaction { requestId, transactionId = id } + + handleMessage CommitTransaction { requestId, id } = do + sqlExecWithRLSAndTransactionId (Just id) "COMMIT" () + + closeTransaction id + + sendJSON DidCommitTransaction { requestId, transactionId = id } + + forever do message <- Aeson.eitherDecodeStrict' <$> receiveData @ByteString @@ -289,13 +327,15 @@ cleanupAllSubscriptions = do let pgListener = ?applicationContext |> get #pgListener case state of - DataSyncReady { subscriptions } -> do + DataSyncReady { subscriptions, transactions } -> do let channelSubscriptions = subscriptions |> HashMap.elems |> map (get #channelSubscription) forEach channelSubscriptions \channelSubscription -> do pgListener |> PGListener.unsubscribe channelSubscription + forEach (HashMap.elems transactions) (get #releaseConnection) + pure () _ -> pure () @@ -310,8 +350,81 @@ queryFieldNamesToColumnNames sqlQuery = sqlQuery where convertOrderByClause OrderByClause { orderByColumn, orderByDirection } = OrderByClause { orderByColumn = cs (fieldNameToColumnName (cs orderByColumn)), orderByDirection } + +runInModelContextWithTransaction :: (?state :: IORef DataSyncController, _) => ((?modelContext :: ModelContext) => IO result) -> Maybe UUID -> IO result +runInModelContextWithTransaction function (Just transactionId) = do + let globalModelContext = ?modelContext + + DataSyncTransaction { connection } <- findTransactionById transactionId + let + ?modelContext = globalModelContext { transactionConnection = Just connection } + in + function +runInModelContextWithTransaction function Nothing = function + +findTransactionById :: (?state :: IORef DataSyncController) => UUID -> IO DataSyncTransaction +findTransactionById transactionId = do + transactions <- get #transactions <$> readIORef ?state + case HashMap.lookup transactionId transactions of + Just transaction -> pure transaction + Nothing -> error "No transaction with that id" + +closeTransaction transactionId = do + DataSyncTransaction { releaseConnection } <- findTransactionById transactionId + modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.delete transactionId)) + releaseConnection + +-- | Allow max 10 concurrent transactions per connection to avoid running out of database connections +-- +-- Each transaction removes a database connection from the connection pool. If we don't limit the transactions, +-- a single user could take down the application by starting more than 'IHP.FrameworkConfig.DBPoolMaxConnections' +-- concurrent transactions. Then all database connections are removed from the connection pool and further database +-- queries for other users will fail. +-- +ensureBelowTransactionLimit :: (?state :: IORef DataSyncController) => IO () +ensureBelowTransactionLimit = do + transactions <- get #transactions <$> readIORef ?state + let transactionCount = HashMap.size transactions + let maxTransactionsPerConnection = 10 + when (transactionCount >= maxTransactionsPerConnection) do + error ("You've reached the transaction limit of " <> tshow maxTransactionsPerConnection <> " transactions") + +sqlQueryWithRLSAndTransactionId :: + ( ?modelContext :: ModelContext + , PG.ToRow parameters + , ?context :: ControllerContext + , userId ~ Id CurrentUserRecord + , Show (PrimaryKey (GetTableName CurrentUserRecord)) + , HasNewSessionUrl CurrentUserRecord + , Typeable CurrentUserRecord + , ?context :: ControllerContext + , HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord)) + , PG.ToField userId + , FromRow result + , ?state :: IORef DataSyncController + ) => Maybe UUID -> PG.Query -> parameters -> IO [result] +sqlQueryWithRLSAndTransactionId transactionId theQuery theParams = runInModelContextWithTransaction (sqlQueryWithRLS theQuery theParams) transactionId + +sqlExecWithRLSAndTransactionId :: + ( ?modelContext :: ModelContext + , PG.ToRow parameters + , ?context :: ControllerContext + , userId ~ Id CurrentUserRecord + , Show (PrimaryKey (GetTableName CurrentUserRecord)) + , HasNewSessionUrl CurrentUserRecord + , Typeable CurrentUserRecord + , ?context :: ControllerContext + , HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord)) + , PG.ToField userId + , ?state :: IORef DataSyncController + ) => Maybe UUID -> PG.Query -> parameters -> IO Int64 +sqlExecWithRLSAndTransactionId transactionId theQuery theParams = runInModelContextWithTransaction (sqlExecWithRLS theQuery theParams) transactionId + $(deriveFromJSON defaultOptions 'DataSyncQuery) $(deriveToJSON defaultOptions 'DataSyncResult) instance SetField "subscriptions" DataSyncController (HashMap UUID Subscription) where - setField subscriptions record = record { subscriptions } \ No newline at end of file + setField subscriptions record = record { subscriptions } + +instance SetField "transactions" DataSyncController (HashMap UUID DataSyncTransaction) where + setField transactions record = record { transactions } \ No newline at end of file diff --git a/IHP/DataSync/Types.hs b/IHP/DataSync/Types.hs index 7c010a35d..d2b193e8a 100644 --- a/IHP/DataSync/Types.hs +++ b/IHP/DataSync/Types.hs @@ -6,17 +6,21 @@ import IHP.QueryBuilder import IHP.DataSync.DynamicQuery import Data.HashMap.Strict (HashMap) import qualified IHP.PGListener as PGListener +import qualified Database.PostgreSQL.Simple as PG data DataSyncMessage - = DataSyncQuery { query :: !DynamicSQLQuery, requestId :: !Int } + = DataSyncQuery { query :: !DynamicSQLQuery, requestId :: !Int, transactionId :: !(Maybe UUID) } | CreateDataSubscription { query :: !DynamicSQLQuery, requestId :: !Int } | DeleteDataSubscription { subscriptionId :: !UUID, requestId :: !Int } - | CreateRecordMessage { table :: !Text, record :: !(HashMap Text Value), requestId :: !Int } - | CreateRecordsMessage { table :: !Text, records :: ![HashMap Text Value], requestId :: !Int } - | UpdateRecordMessage { table :: !Text, id :: !UUID, patch :: !(HashMap Text Value), requestId :: !Int } - | UpdateRecordsMessage { table :: !Text, ids :: ![UUID], patch :: !(HashMap Text Value), requestId :: !Int } - | DeleteRecordMessage { table :: !Text, id :: !UUID, requestId :: !Int } - | DeleteRecordsMessage { table :: !Text, ids :: ![UUID], requestId :: !Int } + | CreateRecordMessage { table :: !Text, record :: !(HashMap Text Value), requestId :: !Int, transactionId :: !(Maybe UUID) } + | CreateRecordsMessage { table :: !Text, records :: ![HashMap Text Value], requestId :: !Int, transactionId :: !(Maybe UUID) } + | UpdateRecordMessage { table :: !Text, id :: !UUID, patch :: !(HashMap Text Value), requestId :: !Int, transactionId :: !(Maybe UUID) } + | UpdateRecordsMessage { table :: !Text, ids :: ![UUID], patch :: !(HashMap Text Value), requestId :: !Int, transactionId :: !(Maybe UUID) } + | DeleteRecordMessage { table :: !Text, id :: !UUID, requestId :: !Int, transactionId :: !(Maybe UUID) } + | DeleteRecordsMessage { table :: !Text, ids :: ![UUID], requestId :: !Int, transactionId :: !(Maybe UUID) } + | StartTransaction { requestId :: !Int } + | RollbackTransaction { requestId :: !Int, id :: !UUID } + | CommitTransaction { requestId :: !Int, id :: !UUID } deriving (Eq, Show) data DataSyncResponse @@ -34,9 +38,21 @@ data DataSyncResponse | DidUpdateRecords { requestId :: !Int, records :: ![[Field]] } -- ^ Response to 'UpdateRecordsMessage' | DidDeleteRecord { requestId :: !Int } | DidDeleteRecords { requestId :: !Int } + | DidStartTransaction { requestId :: !Int, transactionId :: !UUID } + | DidRollbackTransaction { requestId :: !Int, transactionId :: !UUID } + | DidCommitTransaction { requestId :: !Int, transactionId :: !UUID } data Subscription = Subscription { id :: !UUID, channelSubscription :: !PGListener.Subscription } +data DataSyncTransaction + = DataSyncTransaction + { id :: !UUID + , connection :: !PG.Connection + , releaseConnection :: IO () + } data DataSyncController = DataSyncController - | DataSyncReady { subscriptions :: !(HashMap UUID Subscription) } + | DataSyncReady + { subscriptions :: !(HashMap UUID Subscription) + , transactions :: !(HashMap UUID DataSyncTransaction) + } diff --git a/lib/IHP/DataSync/ihp-datasync.js b/lib/IHP/DataSync/ihp-datasync.js index 70e98fa0a..adf9e237d 100644 --- a/lib/IHP/DataSync/ihp-datasync.js +++ b/lib/IHP/DataSync/ihp-datasync.js @@ -130,6 +130,13 @@ class DataSyncController { this.eventListeners[event].push(callback); } + removeEventListener(event, callback) { + const index = this.eventListeners[event].indexOf(callback); + if (index > -1) { + this.eventListeners[event].splice(index, 1); + } + } + retryToReconnect() { if (this.connection) { return; @@ -332,7 +339,7 @@ function initIHPBackend({ host }) { DataSyncController.ihpBackendHost = host; } -export async function createRecord(table, record) { +export async function createRecord(table, record, options = {}) { if (typeof table !== "string") { throw new Error(`Table name needs to be a string, you passed ${JSON.stringify(table)} in a call to createRecord(${JSON.stringify(table)}, ${JSON.stringify(record, null, 4)})`); } @@ -340,7 +347,8 @@ export async function createRecord(table, record) { throw new Error(`Record needs to be an object, you passed ${JSON.stringify(record)} in a call to createRecord(${JSON.stringify(table)}, ${JSON.stringify(record, null, 4)})`); } - const request = { tag: 'CreateRecordMessage', table, record }; + const transactionId = 'transactionId' in options ? options.transactionId : null; + const request = { tag: 'CreateRecordMessage', table, record, transactionId }; try { const response = await DataSyncController.getInstance().sendMessage(request); @@ -354,7 +362,7 @@ export async function createRecord(table, record) { } } -export async function updateRecord(table, id, patch) { +export async function updateRecord(table, id, patch, options = {}) { if (typeof table !== "string") { throw new Error(`Table name needs to be a string, you passed ${JSON.stringify(table)} in a call to updateRecord(${JSON.stringify(table)}, ${JSON.stringify(id)}, ${JSON.stringify(patch, null, 4)})`); } @@ -365,7 +373,8 @@ export async function updateRecord(table, id, patch) { throw new Error(`Patch needs to be an object, you passed ${JSON.stringify(patch)} in a call to updateRecord(${JSON.stringify(table)}, ${JSON.stringify(id)}, ${JSON.stringify(patch, null, 4)})`); } - const request = { tag: 'UpdateRecordMessage', table, id, patch }; + const transactionId = 'transactionId' in options ? options.transactionId : null; + const request = { tag: 'UpdateRecordMessage', table, id, patch, transactionId }; try { const response = await DataSyncController.getInstance().sendMessage(request); @@ -376,7 +385,7 @@ export async function updateRecord(table, id, patch) { } } -export async function updateRecords(table, ids, patch) { +export async function updateRecords(table, ids, patch, options = {}) { if (typeof table !== "string") { throw new Error(`Table name needs to be a string, you passed ${JSON.stringify(table)} in a call to updateRecords(${JSON.stringify(table)}, ${JSON.stringify(ids)}, ${JSON.stringify(patch, null, 4)})`); } @@ -387,7 +396,8 @@ export async function updateRecords(table, ids, patch) { throw new Error(`Patch needs to be an object, you passed ${JSON.stringify(patch)} in a call to updateRecords(${JSON.stringify(table)}, ${JSON.stringify(ids)}, ${JSON.stringify(patch, null, 4)})`); } - const request = { tag: 'UpdateRecordsMessage', table, ids, patch }; + const transactionId = 'transactionId' in options ? options.transactionId : null; + const request = { tag: 'UpdateRecordsMessage', table, ids, patch, transactionId }; try { const response = await DataSyncController.getInstance().sendMessage(request); @@ -398,7 +408,7 @@ export async function updateRecords(table, ids, patch) { } } -export async function deleteRecord(table, id) { +export async function deleteRecord(table, id, options = {}) { if (typeof table !== "string") { throw new Error(`Table name needs to be a string, you passed ${JSON.stringify(table)} in a call to deleteRecord(${JSON.stringify(table)}, ${JSON.stringify(id)})`); } @@ -406,7 +416,8 @@ export async function deleteRecord(table, id) { throw new Error(`ID needs to be an UUID, you passed ${JSON.stringify(id)} in a call to deleteRecord(${JSON.stringify(table)}, ${JSON.stringify(id)})`); } - const request = { tag: 'DeleteRecordMessage', table, id }; + const transactionId = 'transactionId' in options ? options.transactionId : null; + const request = { tag: 'DeleteRecordMessage', table, id, transactionId }; try { const response = await DataSyncController.getInstance().sendMessage(request); @@ -417,7 +428,7 @@ export async function deleteRecord(table, id) { } } -export async function deleteRecords(table, ids) { +export async function deleteRecords(table, ids, options = {}) { if (typeof table !== "string") { throw new Error(`Table name needs to be a string, you passed ${JSON.stringify(table)} in a call to deleteRecords(${JSON.stringify(table)}, ${JSON.stringify(ids)})`); } @@ -425,7 +436,8 @@ export async function deleteRecords(table, ids) { throw new Error(`IDs needs to be an array, you passed ${JSON.stringify(ids)} in a call to deleteRecords(${JSON.stringify(table)}, ${JSON.stringify(ids)})`); } - const request = { tag: 'DeleteRecordsMessage', table, ids }; + const transactionId = 'transactionId' in options ? options.transactionId : null; + const request = { tag: 'DeleteRecordsMessage', table, ids, transactionId }; try { const response = await DataSyncController.getInstance().sendMessage(request); @@ -436,7 +448,7 @@ export async function deleteRecords(table, ids) { } } -export async function createRecords(table, records) { +export async function createRecords(table, records, options = {}) { if (typeof table !== "string") { throw new Error(`Table name needs to be a string, you passed ${JSON.stringify(table)} in a call to createRecords(${JSON.stringify(table)}, ${JSON.stringify(records, null, 4)})`); } @@ -444,7 +456,8 @@ export async function createRecords(table, records) { throw new Error(`Records need to be an array, you passed ${JSON.stringify(records)} in a call to createRecords(${JSON.stringify(table)}, ${JSON.stringify(records, null, 4)})`); } - const request = { tag: 'CreateRecordsMessage', table, records }; + const transactionId = 'transactionId' in options ? options.transactionId : null; + const request = { tag: 'CreateRecordsMessage', table, records, transactionId }; try { const response = await DataSyncController.getInstance().sendMessage(request); diff --git a/lib/IHP/DataSync/ihp-querybuilder.js b/lib/IHP/DataSync/ihp-querybuilder.js index 99b5fb657..046dd2c0e 100644 --- a/lib/IHP/DataSync/ihp-querybuilder.js +++ b/lib/IHP/DataSync/ihp-querybuilder.js @@ -27,6 +27,7 @@ class QueryBuilder { limit: null, offset: null }; + this.transactionId = null; } filterWhere(field, value) { @@ -73,8 +74,11 @@ class QueryBuilder { } async fetch() { - // return fetch('/Query').then(response => response.json()); - const { result } = await DataSyncController.getInstance().sendMessage({ tag: 'DataSyncQuery', query: this.query }); + const { result } = await DataSyncController.getInstance().sendMessage({ + tag: 'DataSyncQuery', + query: this.query, + transactionId: this.transactionId + }); return result; } diff --git a/lib/IHP/DataSync/index.js b/lib/IHP/DataSync/index.js index 42a0d570e..58903f598 100644 --- a/lib/IHP/DataSync/index.js +++ b/lib/IHP/DataSync/index.js @@ -1,10 +1,14 @@ import { QueryBuilder, query, ihpBackendUrl, fetchAuthenticated } from './ihp-querybuilder.js'; -import { DataSyncController, DataSubscription, initIHPBackend, createRecord, updateRecord, deleteRecord, createRecords } from './ihp-datasync.js'; +import { DataSyncController, DataSubscription, initIHPBackend, createRecord, createRecords, updateRecord, updateRecords, deleteRecord, deleteRecords } from './ihp-datasync.js'; +import { Transaction, withTransaction } from './transaction.js'; export { /* ihp-querybuilder.js */ QueryBuilder, query, ihpBackendUrl, fetchAuthenticated, /* ihp-datasync.js */ - DataSyncController, DataSubscription, initIHPBackend, createRecord, updateRecord, deleteRecord, createRecords + DataSyncController, DataSubscription, initIHPBackend, createRecord, createRecords, updateRecord, updateRecords, deleteRecord, deleteRecords, + + /* transaction.js */ + Transaction, withTransaction }; diff --git a/lib/IHP/DataSync/transaction.js b/lib/IHP/DataSync/transaction.js new file mode 100644 index 000000000..575297745 --- /dev/null +++ b/lib/IHP/DataSync/transaction.js @@ -0,0 +1,94 @@ +import { DataSyncController, createRecord, createRecords, updateRecord, updateRecords, deleteRecord, deleteRecords } from "./ihp-datasync.js"; +import { query } from "./ihp-querybuilder.js"; + +export class Transaction { + constructor() { + this.transactionId = null; + this.onClose = this.onClose.bind(this); + this.dataSyncController = DataSyncController.getInstance(); + } + + async start() { + const { transactionId } = await this.dataSyncController.sendMessage({ tag: 'StartTransaction' }); + + this.transactionId = transactionId; + + this.dataSyncController.addEventListener('close', this.onClose); + } + + async commit() { + if (this.transactionId === null) { + throw new Error('You need to call `.start()` before you can commit the transaction'); + } + + await this.dataSyncController.sendMessage({ tag: 'CommitTransaction', id: this.transactionId }); + } + + async rollback() { + if (this.transactionId === null) { + throw new Error('You need to call `.start()` before you can rollback the transaction'); + } + + await this.dataSyncController.sendMessage({ tag: 'RollbackTransaction', id: this.transactionId }); + } + + onClose() { + this.transactionId = null; + this.dataSyncController.removeEventListener('close', this.onClose); + } + + getIdOrFail() { + if (this.transactionId === null) { + throw new Error('You need to call `.start()` before you can use this transaction'); + } + + return this.transactionId; + } + + buildOptions() { + return { transactionId: this.getIdOrFail() }; + } + + query(table) { + const tableQuery = query(table); + tableQuery.transactionId = this.getIdOrFail(); + return tableQuery; + } + + createRecord(table, record) { + return createRecord(table, record, this.buildOptions()); + } + + createRecords(table, records) { + return createRecords(table, records, this.buildOptions()); + } + + updateRecord(table, id, patch) { + return updateRecord(table, id, patch, this.buildOptions()); + } + + updateRecords(table, ids, patch) { + return updateRecords(table, ids, patch, this.buildOptions()); + } + + deleteRecord(table, id) { + return deleteRecord(table, id, this.buildOptions()); + } + + deleteRecords(table, ids) { + return deleteRecords(table, ids, this.buildOptions()); + } +} + +export async function withTransaction(callback) { + const transaction = new Transaction(); + await transaction.start(); + try { + const result = await callback(transaction); + await transaction.commit(); + return result; + } catch (exception) { + await transaction.rollback(); + throw exception; + } +}