From 97fcd9d72688002e0be7ecfa7e40b06cf321a805 Mon Sep 17 00:00:00 2001 From: oldme <45782393+oldme-git@users.noreply.github.com> Date: Wed, 6 Mar 2024 19:05:13 +0800 Subject: [PATCH] enhance: improve `FormatUpsert` implements for pgsql (#3349) --- contrib/drivers/pgsql/pgsql.go | 2 +- contrib/drivers/pgsql/pgsql_do_insert.go | 55 --------------- contrib/drivers/pgsql/pgsql_format_upsert.go | 68 +++++++++++++++++++ .../drivers/pgsql/pgsql_z_unit_model_test.go | 55 +++++++++++++-- 4 files changed, 118 insertions(+), 62 deletions(-) create mode 100644 contrib/drivers/pgsql/pgsql_format_upsert.go diff --git a/contrib/drivers/pgsql/pgsql.go b/contrib/drivers/pgsql/pgsql.go index 492d679c224..e863d213ae9 100644 --- a/contrib/drivers/pgsql/pgsql.go +++ b/contrib/drivers/pgsql/pgsql.go @@ -7,7 +7,7 @@ // Package pgsql implements gdb.Driver, which supports operations for database PostgreSQL. // // Note: -// 1. It does not support Save/Replace features. +// 1. It does not support Replace features. // 2. It does not support Insert Ignore features. package pgsql diff --git a/contrib/drivers/pgsql/pgsql_do_insert.go b/contrib/drivers/pgsql/pgsql_do_insert.go index 84995ff3710..4404a34e03b 100644 --- a/contrib/drivers/pgsql/pgsql_do_insert.go +++ b/contrib/drivers/pgsql/pgsql_do_insert.go @@ -9,13 +9,10 @@ package pgsql import ( "context" "database/sql" - "fmt" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" - "github.com/gogf/gf/v2/text/gstr" - "github.com/gogf/gf/v2/util/gconv" ) // DoInsert inserts or updates data forF given table. @@ -47,55 +44,3 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list } return d.Core.DoInsert(ctx, link, table, list, option) } - -// FormatUpsert returns SQL clause of type upsert for PgSQL. -// For example: ON CONFLICT (id) DO UPDATE SET ... -func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) { - if len(option.OnConflict) == 0 { - return "", gerror.New("Please specify conflict columns") - } - - var onDuplicateStr string - if option.OnDuplicateStr != "" { - onDuplicateStr = option.OnDuplicateStr - } else if len(option.OnDuplicateMap) > 0 { - for k, v := range option.OnDuplicateMap { - if len(onDuplicateStr) > 0 { - onDuplicateStr += "," - } - switch v.(type) { - case gdb.Raw, *gdb.Raw: - onDuplicateStr += fmt.Sprintf( - "%s=%s", - d.Core.QuoteWord(k), - v, - ) - default: - onDuplicateStr += fmt.Sprintf( - "%s=EXCLUDED.%s", - d.Core.QuoteWord(k), - d.Core.QuoteWord(gconv.String(v)), - ) - } - } - } else { - for _, column := range columns { - // If it's SAVE operation, do not automatically update the creating time. - if d.Core.IsSoftCreatedFieldName(column) { - continue - } - if len(onDuplicateStr) > 0 { - onDuplicateStr += "," - } - onDuplicateStr += fmt.Sprintf( - "%s=EXCLUDED.%s", - d.Core.QuoteWord(column), - d.Core.QuoteWord(column), - ) - } - } - - conflictKeys := gstr.Join(option.OnConflict, ",") - - return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET ", conflictKeys) + onDuplicateStr, nil -} diff --git a/contrib/drivers/pgsql/pgsql_format_upsert.go b/contrib/drivers/pgsql/pgsql_format_upsert.go new file mode 100644 index 00000000000..bc082243c2d --- /dev/null +++ b/contrib/drivers/pgsql/pgsql_format_upsert.go @@ -0,0 +1,68 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package pgsql + +import ( + "fmt" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +// FormatUpsert returns SQL clause of type upsert for PgSQL. +// For example: ON CONFLICT (id) DO UPDATE SET ... +func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) { + if len(option.OnConflict) == 0 { + return "", gerror.New("Please specify conflict columns") + } + + var onDuplicateStr string + if option.OnDuplicateStr != "" { + onDuplicateStr = option.OnDuplicateStr + } else if len(option.OnDuplicateMap) > 0 { + for k, v := range option.OnDuplicateMap { + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + switch v.(type) { + case gdb.Raw, *gdb.Raw: + onDuplicateStr += fmt.Sprintf( + "%s=%s", + d.Core.QuoteWord(k), + v, + ) + default: + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s", + d.Core.QuoteWord(k), + d.Core.QuoteWord(gconv.String(v)), + ) + } + } + } else { + for _, column := range columns { + // If it's SAVE operation, do not automatically update the creating time. + if d.Core.IsSoftCreatedFieldName(column) { + continue + } + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s", + d.Core.QuoteWord(column), + d.Core.QuoteWord(column), + ) + } + } + + conflictKeys := gstr.Join(option.OnConflict, ",") + + return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET ", conflictKeys) + onDuplicateStr, nil +} diff --git a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go index 81411ec49cb..4c9d6e90642 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go @@ -7,6 +7,7 @@ package pgsql_test import ( + "database/sql" "fmt" "testing" @@ -260,16 +261,58 @@ func Test_Model_Save(t *testing.T) { table := createTable() defer dropTable(table) gtest.C(t, func(t *gtest.T) { - result, err := db.Model(table).Data(g.Map{ + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime *gtime.Time + } + var ( + user User + count int + result sql.Result + err error + ) + + result, err = db.Model(table).Data(g.Map{ "id": 1, - "passport": "t111", - "password": "25d55ad283aa400af464c76d713c07ad", - "nickname": "T111", - "create_time": "2018-10-24 10:00:00", + "passport": "p1", + "password": "pw1", + "nickname": "n1", + "create_time": CreateTime, }).OnConflict("id").Save() - t.AssertNil(err) + t.AssertNil(nil) n, _ := result.RowsAffected() t.Assert(n, 1) + + err = db.Model(table).Scan(&user) + t.Assert(err, nil) + t.Assert(user.Id, 1) + t.Assert(user.Passport, "p1") + t.Assert(user.Password, "pw1") + t.Assert(user.NickName, "n1") + t.Assert(user.CreateTime.String(), CreateTime) + + _, err = db.Model(table).Data(g.Map{ + "id": 1, + "passport": "p1", + "password": "pw2", + "nickname": "n2", + "create_time": CreateTime, + }).OnConflict("id").Save() + t.AssertNil(err) + + err = db.Model(table).Scan(&user) + t.Assert(err, nil) + t.Assert(user.Passport, "p1") + t.Assert(user.Password, "pw2") + t.Assert(user.NickName, "n2") + t.Assert(user.CreateTime.String(), CreateTime) + + count, err = db.Model(table).Count() + t.Assert(err, nil) + t.Assert(count, 1) }) }