Skip to content

Commit

Permalink
Allow detect whether it's in a database transaction for a context.Con…
Browse files Browse the repository at this point in the history
…text (#21756)

Fix #19513

This PR introduce a new db method `InTransaction(context.Context)`,
and also builtin check on `db.TxContext` and `db.WithTx`.
There is also a new method `db.AutoTx` has been introduced but could be used by other PRs.

`WithTx` will always open a new transaction, if a transaction exist in context, return an error.
`AutoTx` will try to open a new transaction if no transaction exist in context.
That means it will always enter a transaction if there is no error.

Co-authored-by: delvh <[email protected]>
Co-authored-by: 6543 <[email protected]>
  • Loading branch information
3 people authored Nov 12, 2022
1 parent a0a425a commit 34283a7
Show file tree
Hide file tree
Showing 91 changed files with 252 additions and 176 deletions.
2 changes: 1 addition & 1 deletion models/activities/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ func NotifyWatchers(actions ...*Action) error {

// NotifyWatchersActions creates batch of actions for every watcher.
func NotifyWatchersActions(acts []*Action) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions models/activities/notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func CountNotifications(opts *FindNotificationOptions) (int64, error) {

// CreateRepoTransferNotification creates notification for the user a repository was transferred to
func CreateRepoTransferNotification(doer, newOwner *user_model.User, repo *repo_model.Repository) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down Expand Up @@ -185,7 +185,7 @@ func CreateRepoTransferNotification(doer, newOwner *user_model.User, repo *repo_
// for each watcher, or updates it if already exists
// receiverID > 0 just send to receiver, else send to all watcher
func CreateOrUpdateIssueNotifications(issueID, commentID, notificationAuthorID, receiverID int64) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/gpg_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func DeleteGPGKey(doer *user_model.User, id int64) (err error) {
return ErrGPGKeyAccessDenied{doer.ID, key.ID}
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/gpg_key_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
return nil, err
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/gpg_key_verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (

// VerifyGPGKey marks a GPG key as verified
func VerifyGPGKey(ownerID int64, keyID, token, signature string) (string, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return "", err
}
Expand Down
4 changes: 2 additions & 2 deletions models/asymkey/ssh_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
return nil, err
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -321,7 +321,7 @@ func PublicKeyIsExternallyManaged(id int64) (bool, error) {
// deleteKeysMarkedForDeletion returns true if ssh keys needs update
func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
// Start session
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return false, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey
accessMode = perm.AccessModeWrite
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_principals.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (

// AddPrincipalKey adds new principal to database and authorized_principals file.
func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*PublicKey, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

// VerifySSHKey marks a SSH key as verified
func VerifySSHKey(ownerID int64, fingerprint, token, signature string) (string, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return "", err
}
Expand Down
4 changes: 2 additions & 2 deletions models/auth/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ type UpdateOAuth2ApplicationOptions struct {

// UpdateOAuth2Application updates an oauth2 application
func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -265,7 +265,7 @@ func deleteOAuth2Application(ctx context.Context, id, userid int64) error {

// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
func DeleteOAuth2Application(id, userid int64) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions models/auth/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func ReadSession(key string) (*Session, error) {
Key: key,
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -73,7 +73,7 @@ func DestroySession(key string) error {

// RegenerateSession regenerates a session from the old id
func RegenerateSession(oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/avatars/avatar.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func saveEmailHash(email string) string {
Hash: emailHash,
}
// OK we're going to open a session just because I think that that might hide away any problems with postgres reporting errors
if err := db.WithTx(func(ctx context.Context) error {
if err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
has, err := db.GetEngine(ctx).Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash))
if has || err != nil {
// Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time
Expand Down
54 changes: 47 additions & 7 deletions models/db/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"database/sql"

"xorm.io/xorm"
"xorm.io/xorm/schemas"
)

Expand Down Expand Up @@ -86,7 +87,11 @@ type Committer interface {
}

// TxContext represents a transaction Context
func TxContext() (*Context, Committer, error) {
func TxContext(parentCtx context.Context) (*Context, Committer, error) {
if InTransaction(parentCtx) {
return nil, nil, ErrAlreadyInTransaction
}

sess := x.NewSession()
if err := sess.Begin(); err != nil {
sess.Close()
Expand All @@ -97,14 +102,24 @@ func TxContext() (*Context, Committer, error) {
}

// WithTx represents executing database operations on a transaction
// you can optionally change the context to a parent one
func WithTx(f func(ctx context.Context) error, stdCtx ...context.Context) error {
parentCtx := DefaultContext
if len(stdCtx) != 0 && stdCtx[0] != nil {
// TODO: make sure parent context has no open session
parentCtx = stdCtx[0]
// This function will always open a new transaction, if a transaction exist in parentCtx return an error.
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
if InTransaction(parentCtx) {
return ErrAlreadyInTransaction
}
return txWithNoCheck(parentCtx, f)
}

// AutoTx represents executing database operations on a transaction, if the transaction exist,
// this function will reuse it otherwise will create a new one and close it when finished.
func AutoTx(parentCtx context.Context, f func(ctx context.Context) error) error {
if InTransaction(parentCtx) {
return f(newContext(parentCtx, GetEngine(parentCtx), true))
}
return txWithNoCheck(parentCtx, f)
}

func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
Expand Down Expand Up @@ -180,3 +195,28 @@ func EstimateCount(ctx context.Context, bean interface{}) (int64, error) {
}
return rows, err
}

// InTransaction returns true if the engine is in a transaction otherwise return false
func InTransaction(ctx context.Context) bool {
var e Engine
if engined, ok := ctx.(Engined); ok {
e = engined.Engine()
} else {
enginedInterface := ctx.Value(enginedContextKey)
if enginedInterface != nil {
e = enginedInterface.(Engined).Engine()
}
}
if e == nil {
return false
}

switch t := e.(type) {
case *xorm.Engine:
return false
case *xorm.Session:
return t.IsInTx()
default:
return false
}
}
33 changes: 33 additions & 0 deletions models/db/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package db_test

import (
"context"
"testing"

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"

"github.com/stretchr/testify/assert"
)

func TestInTransaction(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.False(t, db.InTransaction(db.DefaultContext))
assert.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
return nil
}))

ctx, committer, err := db.TxContext(db.DefaultContext)
assert.NoError(t, err)
defer committer.Close()
assert.True(t, db.InTransaction(ctx))
assert.Error(t, db.WithTx(ctx, func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
return nil
}))
}
3 changes: 3 additions & 0 deletions models/db/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
package db

import (
"errors"
"fmt"

"code.gitea.io/gitea/modules/util"
)

var ErrAlreadyInTransaction = errors.New("database connection has already been in a transaction")

// ErrCancelled represents an error due to context cancellation
type ErrCancelled struct {
Message string
Expand Down
8 changes: 4 additions & 4 deletions models/db/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
assert.EqualValues(t, 62, maxIndex)

// commit transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 73)
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
Expand All @@ -73,7 +73,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
assert.EqualValues(t, 73, maxIndex)

// rollback transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 84)
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
assert.NoError(t, err)
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestGetNextResourceIndex(t *testing.T) {
assert.EqualValues(t, 2, maxIndex)

// commit transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 3, maxIndex)
Expand All @@ -114,7 +114,7 @@ func TestGetNextResourceIndex(t *testing.T) {
assert.EqualValues(t, 3, maxIndex)

// rollback transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 4, maxIndex)
Expand Down
2 changes: 1 addition & 1 deletion models/git/branches.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ func FindRenamedBranch(repoID int64, from string) (branch *RenamedBranch, exist

// RenameBranch rename a branch
func RenameBranch(repo *repo_model.Repository, from, to string, gitAction func(isDefault bool) error) (err error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/git/branches_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestRenameBranch(t *testing.T) {
repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
_isDefault := false

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
defer committer.Close()
assert.NoError(t, err)
assert.NoError(t, git_model.UpdateProtectBranch(ctx, repo1, &git_model.ProtectedBranch{
Expand Down
4 changes: 2 additions & 2 deletions models/git/commit_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func GetNextCommitStatusIndex(repoID int64, sha string) (int64, error) {

// getNextCommitStatusIndex return the next index
func getNextCommitStatusIndex(repoID int64, sha string) (int64, error) {
ctx, commiter, err := db.TxContext()
ctx, commiter, err := db.TxContext(db.DefaultContext)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -297,7 +297,7 @@ func NewCommitStatus(opts NewCommitStatusOptions) error {
return fmt.Errorf("generate commit status index failed: %w", err)
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return fmt.Errorf("NewCommitStatus[repo_id: %d, user_id: %d, sha: %s]: %w", opts.Repo.ID, opts.Creator.ID, opts.SHA, err)
}
Expand Down
6 changes: 3 additions & 3 deletions models/git/lfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ var ErrLFSObjectNotExist = db.ErrNotExist{Resource: "LFS Meta object"}
func NewLFSMetaObject(m *LFSMetaObject) (*LFSMetaObject, error) {
var err error

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -185,7 +185,7 @@ func RemoveLFSMetaObjectByOid(repoID int64, oid string) (int64, error) {
return 0, ErrLFSObjectNotExist
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -242,7 +242,7 @@ func LFSObjectIsAssociated(oid string) (bool, error) {

// LFSAutoAssociate auto associates accessible LFSMetaObjects
func LFSAutoAssociate(metas []*LFSMetaObject, user *user_model.User, repoID int64) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions models/git/lfs_lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func cleanPath(p string) string {

// CreateLFSLock creates a new lock.
func CreateLFSLock(repo *repo_model.Repository, lock *LFSLock) (*LFSLock, error) {
dbCtx, committer, err := db.TxContext()
dbCtx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -137,7 +137,7 @@ func CountLFSLockByRepoID(repoID int64) (int64, error) {

// DeleteLFSLockByID deletes a lock by given ID.
func DeleteLFSLockByID(id int64, repo *repo_model.Repository, u *user_model.User, force bool) (*LFSLock, error) {
dbCtx, committer, err := db.TxContext()
dbCtx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/issues/assignees.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func IsUserAssignedToIssue(ctx context.Context, issue *Issue, user *user_model.U

// ToggleIssueAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it.
func ToggleIssueAssignee(issue *Issue, doer *user_model.User, assigneeID int64) (removed bool, comment *Comment, err error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return false, nil, err
}
Expand Down
Loading

0 comments on commit 34283a7

Please sign in to comment.