From 2c6ed23e0789b2188dd23941cedf2fd52b08080c Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Wed, 10 Jul 2024 01:48:00 -0400 Subject: [PATCH] fix: formats table name everywhere --- drivers/postgres.go | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/drivers/postgres.go b/drivers/postgres.go index 23412a0..a586c6e 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -317,6 +317,7 @@ func (db *Postgres) GetIndexes(table string) (indexes [][]string, error error) { } func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (records [][]string, totalRecords int, err error) { + table = db.formatTableName(table) defaultLimit := 300 isPaginationEnabled := offset >= 0 && limit >= 0 @@ -324,20 +325,14 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re defaultLimit = limit } - splittedTableName := strings.Split(table, ".") - schema := splittedTableName[0] - tableName := splittedTableName[1] - - formattedTableName := fmt.Sprintf("\"%s\".\"%s\"", schema, tableName) - - query := fmt.Sprintf("SELECT * FROM %s s LIMIT %d OFFSET %d", formattedTableName, defaultLimit, offset) + query := fmt.Sprintf("SELECT * FROM %s s LIMIT %d OFFSET %d", table, defaultLimit, offset) if where != "" { - query = fmt.Sprintf("SELECT * FROM %s %s LIMIT %d OFFSET %d", formattedTableName, where, defaultLimit, offset) + query = fmt.Sprintf("SELECT * FROM %s %s LIMIT %d OFFSET %d", table, where, defaultLimit, offset) } if sort != "" { - query = fmt.Sprintf("SELECT * FROM %s %s ORDER BY %s LIMIT %d OFFSET %d", formattedTableName, where, sort, defaultLimit, offset) + query = fmt.Sprintf("SELECT * FROM %s %s ORDER BY %s LIMIT %d OFFSET %d", table, where, sort, defaultLimit, offset) } paginatedRows, err := db.Connection.Query(query) @@ -346,7 +341,7 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re } if isPaginationEnabled { - queryWithoutLimit := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", formattedTableName, where) + queryWithoutLimit := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", table, where) rows := db.Connection.QueryRow(queryWithoutLimit) @@ -387,6 +382,7 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re } func (db *Postgres) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) (err error) { + table = db.formatTableName(table) query := fmt.Sprintf("UPDATE %s SET %s = '%s' WHERE \"%s\" = '%s'", table, column, value, primaryKeyColumnName, primaryKeyValue) _, err = db.Connection.Exec(query) @@ -394,6 +390,7 @@ func (db *Postgres) UpdateRecord(table, column, value, primaryKeyColumnName, pri } func (db *Postgres) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) (err error) { + table = db.formatTableName(table) query := fmt.Sprintf("DELETE FROM %s WHERE \"%s\" = '%s'", table, primaryKeyColumnName, primaryKeyValue) _, err = db.Connection.Exec(query) @@ -457,7 +454,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts // Group changes by RowId and Table for _, change := range changes { if change.Type == "UPDATE" { - key := fmt.Sprintf("%s|%s|%s", change.Table, change.PrimaryKeyColumnName, change.PrimaryKeyValue) + key := fmt.Sprintf("%s|%s|%s", db.formatTableName(change.Table), change.PrimaryKeyColumnName, change.PrimaryKeyValue) groupedUpdated[key] = append(groupedUpdated[key], change) } else if change.Type == "DELETE" { groupedDeletes = append(groupedDeletes, change) @@ -470,7 +467,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts // Split key into table and rowId splitted := strings.Split(key, "|") - table := splitted[0] + table := db.formatTableName(splitted[0]) PrimaryKeyColumnName := splitted[1] primaryKeyValue := splitted[2] @@ -491,7 +488,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts query := "" statementType = "DELETE FROM" - query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, del.Table, del.PrimaryKeyColumnName, del.PrimaryKeyValue) + query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, db.formatTableName(del.Table), del.PrimaryKeyColumnName, del.PrimaryKeyValue) if query != "" { queries = append(queries, query) @@ -511,7 +508,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts } } - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", insert.Table, strings.Join(insert.Columns, ", "), strings.Join(values, ", ")) + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", db.formatTableName(insert.Table), strings.Join(insert.Columns, ", "), strings.Join(values, ", ")) queries = append(queries, query) } @@ -577,3 +574,18 @@ func (db *Postgres) SwitchDatabase(database string) error { return nil } + +func (db *Postgres) formatTableName(table string) string { + splittedTableName := strings.Split(table, ".") + + if len(splittedTableName) == 1 { + return table + } + + schema := splittedTableName[0] + tableName := splittedTableName[1] + + formattedTableName := fmt.Sprintf("\"%s\".\"%s\"", schema, tableName) + + return formattedTableName +}