diff --git a/conn.go b/conn.go index 96ed452d9..a7a5ef73d 100644 --- a/conn.go +++ b/conn.go @@ -1359,12 +1359,12 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error } if c.descriptionCache != nil { - c.descriptionCache.HandleInvalidated() + c.descriptionCache.RemoveInvalidated() } var invalidatedStatements []*pgconn.StatementDescription if c.statementCache != nil { - invalidatedStatements = c.statementCache.HandleInvalidated() + invalidatedStatements = c.statementCache.GetInvalidated() } if len(invalidatedStatements) == 0 { @@ -1376,7 +1376,6 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error for _, sd := range invalidatedStatements { pipeline.SendDeallocate(sd.Name) - delete(c.preparedStatements, sd.Name) } err := pipeline.Sync() @@ -1389,5 +1388,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error return fmt.Errorf("failed to deallocate cached statement(s): %w", err) } + c.statementCache.RemoveInvalidated() + for _, sd := range invalidatedStatements { + delete(c.preparedStatements, sd.Name) + } + return nil } diff --git a/conn_test.go b/conn_test.go index a7f7f2f88..e9415b229 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1338,3 +1338,32 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) { t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not") }) } + +// https://github.com/jackc/pgx/issues/1847 +func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var n int32 + err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + // Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was + // encountered by the query. Use this to purposely invalidate the query. If we had access to private fields of conn + // we could call conn.statementCache.InvalidateAll() instead. + err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n) + require.Error(t, err) + + ctx2, cancel2 := context.WithCancel(ctx) + cancel2() + err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + + err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + }) +} diff --git a/internal/stmtcache/lru_cache.go b/internal/stmtcache/lru_cache.go index 859345fcb..dec83f47b 100644 --- a/internal/stmtcache/lru_cache.go +++ b/internal/stmtcache/lru_cache.go @@ -81,12 +81,16 @@ func (c *LRUCache) InvalidateAll() { c.l = list.New() } -// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. -// Typically, the caller will then deallocate them. -func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription { - invalidStmts := c.invalidStmts +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *LRUCache) RemoveInvalidated() { c.invalidStmts = nil - return invalidStmts } // Len returns the number of cached prepared statement descriptions. diff --git a/internal/stmtcache/stmtcache.go b/internal/stmtcache/stmtcache.go index b2940e230..d57bdd29e 100644 --- a/internal/stmtcache/stmtcache.go +++ b/internal/stmtcache/stmtcache.go @@ -29,8 +29,13 @@ type Cache interface { // InvalidateAll invalidates all statement descriptions. InvalidateAll() - // HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. - HandleInvalidated() []*pgconn.StatementDescription + // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. + GetInvalidated() []*pgconn.StatementDescription + + // RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a + // call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were + // never seen by the call to GetInvalidated. + RemoveInvalidated() // Len returns the number of cached prepared statement descriptions. Len() int diff --git a/internal/stmtcache/unlimited_cache.go b/internal/stmtcache/unlimited_cache.go index f5f59396e..696413291 100644 --- a/internal/stmtcache/unlimited_cache.go +++ b/internal/stmtcache/unlimited_cache.go @@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() { c.m = make(map[string]*pgconn.StatementDescription) } -func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription { - invalidStmts := c.invalidStmts +// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. +func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription { + return c.invalidStmts +} + +// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a +// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were +// never seen by the call to GetInvalidated. +func (c *UnlimitedCache) RemoveInvalidated() { c.invalidStmts = nil - return invalidStmts } // Len returns the number of cached prepared statement descriptions.