From 832b4f97718c2d9d2eb16bbd2fef1d05ede7aab5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Feb 2024 12:25:57 -0600 Subject: [PATCH] Fix: prepared statement already exists When a conn is going to execute a query, the first thing it does is to deallocate any invalidated prepared statements from the statement cache. However, the statements were removed from the cache regardless of whether the deallocation succeeded. This would cause subsequent calls of the same SQL to fail with "prepared statement already exists" error. This problem is easy to trigger by running a query with a context that is already canceled. This commit changes the deallocate invalidated cached statements logic so that the statements are only removed from the cache if the deallocation was successful on the server. https://github.com/jackc/pgx/issues/1847 --- conn.go | 10 ++++++--- conn_test.go | 29 +++++++++++++++++++++++++++ internal/stmtcache/lru_cache.go | 14 ++++++++----- internal/stmtcache/stmtcache.go | 9 +++++++-- internal/stmtcache/unlimited_cache.go | 12 ++++++++--- 5 files changed, 61 insertions(+), 13 deletions(-) diff --git a/conn.go b/conn.go index 96ed452d9..a7a5ef73d 100644 --- a/conn.go +++ b/conn.go @@ -1359,12 +1359,12 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error } if c.descriptionCache != nil { - c.descriptionCache.HandleInvalidated() + c.descriptionCache.RemoveInvalidated() } var invalidatedStatements []*pgconn.StatementDescription if c.statementCache != nil { - invalidatedStatements = c.statementCache.HandleInvalidated() + invalidatedStatements = c.statementCache.GetInvalidated() } if len(invalidatedStatements) == 0 { @@ -1376,7 +1376,6 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error for _, sd := range invalidatedStatements { pipeline.SendDeallocate(sd.Name) - delete(c.preparedStatements, sd.Name) } err := pipeline.Sync() @@ -1389,5 +1388,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error return fmt.Errorf("failed to deallocate cached statement(s): %w", err) } + c.statementCache.RemoveInvalidated() + for _, sd := range invalidatedStatements { + delete(c.preparedStatements, sd.Name) + } + return nil } diff --git a/conn_test.go b/conn_test.go index a7f7f2f88..e9415b229 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1338,3 +1338,32 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) { t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not") }) } + +// https://github.com/jackc/pgx/issues/1847 +func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var n int32 + err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + // Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was + // encountered by the query. Use this to purposely invalidate the query. If we had access to private fields of conn + // we could call conn.statementCache.InvalidateAll() instead. + err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n) + require.Error(t, err) + + ctx2, cancel2 := context.WithCancel(ctx) + cancel2() + err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + + err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + }) +} diff --git a/internal/stmtcache/lru_cache.go b/internal/stmtcache/lru_cache.go index 859345fcb..dec83f47b 100644 --- a/internal/stmtcache/lru_cache.go +++ b/internal/stmtcache/lru_cache.go @@ -81,12 +81,16 @@ func (c *LRUCache) InvalidateAll() { c.l = list.New() } -// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. -// Typically, the caller will then deallocate them. -func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription { - invalidStmts := c.invalidStmts +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *LRUCache) RemoveInvalidated() { c.invalidStmts = nil - return invalidStmts } // Len returns the number of cached prepared statement descriptions. diff --git a/internal/stmtcache/stmtcache.go b/internal/stmtcache/stmtcache.go index b2940e230..d57bdd29e 100644 --- a/internal/stmtcache/stmtcache.go +++ b/internal/stmtcache/stmtcache.go @@ -29,8 +29,13 @@ type Cache interface { // InvalidateAll invalidates all statement descriptions. InvalidateAll() - // HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. - HandleInvalidated() []*pgconn.StatementDescription + // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. + GetInvalidated() []*pgconn.StatementDescription + + // RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a + // call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were + // never seen by the call to GetInvalidated. + RemoveInvalidated() // Len returns the number of cached prepared statement descriptions. Len() int diff --git a/internal/stmtcache/unlimited_cache.go b/internal/stmtcache/unlimited_cache.go index f5f59396e..696413291 100644 --- a/internal/stmtcache/unlimited_cache.go +++ b/internal/stmtcache/unlimited_cache.go @@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() { c.m = make(map[string]*pgconn.StatementDescription) } -func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription { - invalidStmts := c.invalidStmts +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *UnlimitedCache) RemoveInvalidated() { c.invalidStmts = nil - return invalidStmts } // Len returns the number of cached prepared statement descriptions.