Skip to content

Commit

Permalink
Fix prepared statement already exists on batch prepare failure
Browse files Browse the repository at this point in the history
When a batch successfully prepared some statements, but then failed to
prepare others, the prepared statements that were successfully prepared
were not properly cleaned up. This could lead to a "prepared statement
already exists" error on subsequent attempts to prepare the same
statement.

jackc#1847 (comment)
  • Loading branch information
jackc authored and ninedraft committed Dec 9, 2024
1 parent e400c5e commit 808da06
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 29 deletions.
30 changes: 30 additions & 0 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,36 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
assert.False(t, rows.Next())
}

// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887
func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")

mustExec(t, conn, `create temporary table foo(col1 text primary key);`)

batch := &pgx.Batch{}
batch.Queue("select col1 from foo")
batch.Queue("select col1 from baz")
err := conn.SendBatch(ctx, batch).Close()
require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`)

mustExec(t, conn, `create temporary table baz(col1 text primary key);`)

// Since table baz now exists, the batch should succeed.

batch = &pgx.Batch{}
batch.Queue("select col1 from foo")
batch.Queue("select col1 from baz")
err = conn.SendBatch(ctx, batch).Close()
require.NoError(t, err)
})
}

func ExampleConn_SendBatch() {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
Expand Down
75 changes: 46 additions & 29 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1126,47 +1126,64 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d

// Prepare any needed queries
if len(distinctNewQueries) > 0 {
for _, sd := range distinctNewQueries {
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
}
err := func() (err error) {
for _, sd := range distinctNewQueries {
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
}

err := pipeline.Sync()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}
// Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will
// clean them up later.
if sdCache != nil {
for _, sd := range distinctNewQueries {
sdCache.Put(sd)
}
}

// If something goes wrong preparing the statements, we need to invalidate the cache entries we just added.
defer func() {
if err != nil && sdCache != nil {
for _, sd := range distinctNewQueries {
sdCache.Invalidate(sd.SQL)
}
}
}()

err = pipeline.Sync()
if err != nil {
return err
}

for _, sd := range distinctNewQueries {
results, err := pipeline.GetResults()
if err != nil {
return err
}

resultSD, ok := results.(*pgconn.StatementDescription)
if !ok {
return fmt.Errorf("expected statement description, got %T", results)
}

// Fill in the previously empty / pending statement descriptions.
sd.ParamOIDs = resultSD.ParamOIDs
sd.Fields = resultSD.Fields
}

for _, sd := range distinctNewQueries {
results, err := pipeline.GetResults()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
return err
}

resultSD, ok := results.(*pgconn.StatementDescription)
_, ok := results.(*pgconn.PipelineSync)
if !ok {
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}
return fmt.Errorf("expected sync, got %T", results)
}

// Fill in the previously empty / pending statement descriptions.
sd.ParamOIDs = resultSD.ParamOIDs
sd.Fields = resultSD.Fields
}

results, err := pipeline.GetResults()
return nil
}()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}

_, ok := results.(*pgconn.PipelineSync)
if !ok {
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}
}
}

// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
if sdCache != nil {
for _, sd := range distinctNewQueries {
sdCache.Put(sd)
}
}

// Queue the queries.
Expand Down

0 comments on commit 808da06

Please sign in to comment.