diff --git a/pkg/kgo/consumer_group.go b/pkg/kgo/consumer_group.go index 998d7aef..6dc6bbd2 100644 --- a/pkg/kgo/consumer_group.go +++ b/pkg/kgo/consumer_group.go @@ -2289,18 +2289,19 @@ func PreCommitFnContext(ctx context.Context, fn func(*kmsg.OffsetCommitRequest) return context.WithValue(ctx, commitContextFn, fn) } -type commitTxnContextFnT struct{} +type txnCommitContextFnT struct{} -var commitTxnContextFn commitTxnContextFnT +var txnCommitContextFn txnCommitContextFnT -// PreCommitTxnFnContext attaches fn to the context through WithValue. Using +// PreTxnCommitFnContext attaches fn to the context through WithValue. Using // the context while committing a transaction allows fn to be called just // before the commit is issued. This can be used to modify the actual commit, // such as by associating metadata with partitions (for transactions, the // default internal metadata is the client's current member ID). If fn returns -// an error, the commit is not attempted. -func PreCommitTxnFnContext(ctx context.Context, fn func(*kmsg.TxnOffsetCommitRequest) error) context.Context { - return context.WithValue(ctx, commitTxnContextFn, fn) +// an error, the commit is not attempted. This context can be used in either +// GroupTransactSession.End or in Client.EndTransaction. +func PreTxnCommitFnContext(ctx context.Context, fn func(*kmsg.TxnOffsetCommitRequest) error) context.Context { + return context.WithValue(ctx, txnCommitContextFn, fn) } // CommitRecords issues a synchronous offset commit for the offsets contained diff --git a/pkg/kgo/txn.go b/pkg/kgo/txn.go index c7620f68..bc1382b0 100644 --- a/pkg/kgo/txn.go +++ b/pkg/kgo/txn.go @@ -287,7 +287,7 @@ func (s *GroupTransactSession) End(ctx context.Context, commit TransactionEndTry var commitErrs []string committed := make(chan struct{}) - g = s.cl.commitTransactionOffsets(context.Background(), postcommit, + g = s.cl.commitTransactionOffsets(ctx, postcommit, func(_ *kmsg.TxnOffsetCommitRequest, resp *kmsg.TxnOffsetCommitResponse, err error) { defer close(committed) if err != nil { @@ -1222,7 +1222,7 @@ func (g *groupConsumer) commitTxn( req.Topics = append(req.Topics, reqTopic) } - if fn, ok := ctx.Value(commitTxnContextFn).(func(*kmsg.TxnOffsetCommitRequest) error); ok { + if fn, ok := ctx.Value(txnCommitContextFn).(func(*kmsg.TxnOffsetCommitRequest) error); ok { if err := fn(req); err != nil { onDone(req, nil, err) return