From 51295787ed2d48280f040ca868989252ba09c572 Mon Sep 17 00:00:00 2001 From: Andrei Matei Date: Mon, 31 Oct 2022 14:07:26 -0400 Subject: [PATCH] singleflight: spruce up singleflight This patch improves singleflight with a couple of features that were done manually by most callers. The singleflight now optionally makes a flight's context not respond to the cancelation of the caller's ctx, the flights now run in stopper tasks, and stopper quiescence cancels the flights' contexts. The callers were doing some of these things, but inconsistently. Also, the flights now get a tracing span and callers that join an existing span get a copy of the recording of the flight leader. We've wanted this multiple times when debugging, for example in the context of table descriptor lease acquisitions and range descriptor resolving. The interface for getting the results out of DoChan() had to change a bit as a result. Release note: None --- .../testdata/benchmark_expectations | 34 +- pkg/ccl/kvccl/kvtenantccl/connector.go | 57 +-- pkg/kv/bulk/sst_batcher_test.go | 2 +- pkg/kv/kvclient/kvcoord/dist_sender.go | 3 +- pkg/kv/kvclient/kvcoord/dist_sender_test.go | 2 +- pkg/kv/kvclient/rangecache/BUILD.bazel | 1 - pkg/kv/kvclient/rangecache/range_cache.go | 72 ++-- .../kvclient/rangecache/range_cache_test.go | 35 +- pkg/kv/kvserver/BUILD.bazel | 1 + pkg/kv/kvserver/liveness/liveness.go | 30 +- pkg/kv/kvserver/protectedts/ptcache/cache.go | 51 +-- .../protectedts/ptcache/cache_test.go | 2 +- pkg/kv/kvserver/rangefeed/BUILD.bazel | 1 - pkg/kv/kvserver/rangefeed/metrics.go | 4 +- pkg/kv/kvserver/replica_rangefeed.go | 28 +- pkg/kv/kvserver/store.go | 18 +- pkg/kv/kvserver/txnrecovery/manager.go | 93 +++-- pkg/sql/authorization.go | 37 +- pkg/sql/cacheutil/BUILD.bazel | 1 - pkg/sql/cacheutil/cache.go | 40 +- pkg/sql/catalog/hydrateddesc/hydratedcache.go | 5 +- pkg/sql/catalog/lease/lease.go | 88 +++-- pkg/sql/distsql_physical_planner_test.go | 4 +- pkg/sql/opt/exec/execbuilder/testdata/upsert | 5 + pkg/sql/rowexec/tablereader_test.go | 18 +- pkg/sql/sessioninit/BUILD.bazel | 1 - pkg/sql/sessioninit/cache.go | 38 +- pkg/sql/sqlliveness/slstorage/BUILD.bazel | 1 - pkg/sql/sqlliveness/slstorage/slstorage.go | 74 ++-- pkg/testutils/storageutils/mocking.go | 25 +- pkg/util/syncutil/singleflight/BUILD.bazel | 21 +- .../syncutil/singleflight/singleflight.go | 341 ++++++++++++++++-- .../singleflight/singleflight_test.go | 286 +++++++++++++-- pkg/util/tracing/crdbspan.go | 26 ++ pkg/util/tracing/span.go | 67 +++- pkg/util/tracing/span_inner.go | 41 +-- 36 files changed, 1035 insertions(+), 518 deletions(-) diff --git a/pkg/bench/rttanalysis/testdata/benchmark_expectations b/pkg/bench/rttanalysis/testdata/benchmark_expectations index d435dfe6622a..1445cab7176e 100644 --- a/pkg/bench/rttanalysis/testdata/benchmark_expectations +++ b/pkg/bench/rttanalysis/testdata/benchmark_expectations @@ -50,27 +50,27 @@ exp,benchmark 9,Grant/grant_all_on_1_table 9,Grant/grant_all_on_2_tables 9,Grant/grant_all_on_3_tables -11,GrantRole/grant_1_role -15,GrantRole/grant_2_roles +13,GrantRole/grant_1_role +18,GrantRole/grant_2_roles 3,ORMQueries/activerecord_type_introspection_query 5,ORMQueries/django_table_introspection_1_table 5,ORMQueries/django_table_introspection_4_tables 5,ORMQueries/django_table_introspection_8_tables -1,ORMQueries/has_column_privilege_using_attnum -1,ORMQueries/has_column_privilege_using_column_name -0,ORMQueries/has_schema_privilege_1 -0,ORMQueries/has_schema_privilege_3 -0,ORMQueries/has_schema_privilege_5 
-1,ORMQueries/has_sequence_privilege_1 -3,ORMQueries/has_sequence_privilege_3 -5,ORMQueries/has_sequence_privilege_5 -1,ORMQueries/has_table_privilege_1 -3,ORMQueries/has_table_privilege_3 -5,ORMQueries/has_table_privilege_5 +5,ORMQueries/has_column_privilege_using_attnum +5,ORMQueries/has_column_privilege_using_column_name +5,ORMQueries/has_schema_privilege_1 +15,ORMQueries/has_schema_privilege_3 +25,ORMQueries/has_schema_privilege_5 +5,ORMQueries/has_sequence_privilege_1 +15,ORMQueries/has_sequence_privilege_3 +25,ORMQueries/has_sequence_privilege_5 +5,ORMQueries/has_table_privilege_1 +15,ORMQueries/has_table_privilege_3 +25,ORMQueries/has_table_privilege_5 85,ORMQueries/hasura_column_descriptions 85,ORMQueries/hasura_column_descriptions_8_tables 5,ORMQueries/hasura_column_descriptions_modified -5,ORMQueries/information_schema._pg_index_position +9,ORMQueries/information_schema._pg_index_position 4,ORMQueries/pg_attribute 4,ORMQueries/pg_class 6,ORMQueries/pg_is_other_temp_schema @@ -82,8 +82,8 @@ exp,benchmark 9,Revoke/revoke_all_on_1_table 9,Revoke/revoke_all_on_2_tables 9,Revoke/revoke_all_on_3_tables -9,RevokeRole/revoke_1_role -11,RevokeRole/revoke_2_roles +10,RevokeRole/revoke_1_role +12,RevokeRole/revoke_2_roles 1,SystemDatabaseQueries/select_system.users_with_empty_database_Name 1,SystemDatabaseQueries/select_system.users_with_schema_Name 1,SystemDatabaseQueries/select_system.users_without_schema_Name @@ -95,5 +95,5 @@ exp,benchmark 12,Truncate/truncate_2_column_2_rows 1,VirtualTableQueries/select_crdb_internal.invalid_objects_with_1_fk 1,VirtualTableQueries/select_crdb_internal.tables_with_1_fk -3,VirtualTableQueries/virtual_table_cache_with_point_lookups +11,VirtualTableQueries/virtual_table_cache_with_point_lookups 7,VirtualTableQueries/virtual_table_cache_with_schema_change diff --git a/pkg/ccl/kvccl/kvtenantccl/connector.go b/pkg/ccl/kvccl/kvtenantccl/connector.go index 10016515fd77..cddccae83248 100644 --- a/pkg/ccl/kvccl/kvtenantccl/connector.go +++ b/pkg/ccl/kvccl/kvtenantccl/connector.go @@ -71,7 +71,7 @@ type Connector struct { rpcContext *rpc.Context rpcRetryOptions retry.Options rpcDialTimeout time.Duration // for testing - rpcDial singleflight.Group + rpcDial *singleflight.Group defaultZoneCfg *zonepb.ZoneConfig addrs []string @@ -142,6 +142,7 @@ func NewConnector(cfg kvtenant.ConnectorConfig, addrs []string) *Connector { tenantID: cfg.TenantID, AmbientContext: cfg.AmbientCtx, rpcContext: cfg.RPCContext, + rpcDial: singleflight.NewGroup("dial tenant connector", singleflight.NoTags), rpcRetryOptions: cfg.RPCRetryOptions, defaultZoneCfg: cfg.DefaultZoneConfig, addrs: addrs, @@ -673,41 +674,41 @@ func (c *Connector) withClient( // blocks until either a connection is successfully established or the provided // context is canceled. 
func (c *Connector) getClient(ctx context.Context) (*client, error) { + ctx = c.AnnotateCtx(ctx) c.mu.RLock() if client := c.mu.client; client != nil { c.mu.RUnlock() return client, nil } - ch, _ := c.rpcDial.DoChan("dial", func() (interface{}, error) { - dialCtx := c.AnnotateCtx(context.Background()) - dialCtx, cancel := c.rpcContext.Stopper.WithCancelOnQuiesce(dialCtx) - defer cancel() - var client *client - err := c.rpcContext.Stopper.RunTaskWithErr(dialCtx, "kvtenant.Connector: dial", - func(ctx context.Context) error { - var err error - client, err = c.dialAddrs(ctx) - return err - }) - if err != nil { - return nil, err - } - c.mu.Lock() - defer c.mu.Unlock() - c.mu.client = client - return client, nil - }) + future, _ := c.rpcDial.DoChan(ctx, + "dial", + singleflight.DoOpts{ + Stop: c.rpcContext.Stopper, + InheritCancelation: false, + }, + func(ctx context.Context) (interface{}, error) { + var client *client + err := c.rpcContext.Stopper.RunTaskWithErr(ctx, "kvtenant.Connector: dial", + func(ctx context.Context) error { + var err error + client, err = c.dialAddrs(ctx) + return err + }) + if err != nil { + return nil, err + } + c.mu.Lock() + defer c.mu.Unlock() + c.mu.client = client + return client, nil + }) c.mu.RUnlock() - select { - case res := <-ch: - if res.Err != nil { - return nil, res.Err - } - return res.Val.(*client), nil - case <-ctx.Done(): - return nil, ctx.Err() + res := future.WaitForResult(ctx) + if res.Err != nil { + return nil, res.Err } + return res.Val.(*client), nil } // dialAddrs attempts to dial each of the configured addresses in a retry loop. diff --git a/pkg/kv/bulk/sst_batcher_test.go b/pkg/kv/bulk/sst_batcher_test.go index 249df6f299a7..9b44851b1e4f 100644 --- a/pkg/kv/bulk/sst_batcher_test.go +++ b/pkg/kv/bulk/sst_batcher_test.go @@ -290,7 +290,7 @@ func runTestImport(t *testing.T, batchSizeValue int64) { // populate it after the first split but before the second split. 
ds := s.DistSenderI().(*kvcoord.DistSender) mockCache := rangecache.NewRangeCache(s.ClusterSettings(), ds, - func() int64 { return 2 << 10 }, s.Stopper(), s.TracerI().(*tracing.Tracer)) + func() int64 { return 2 << 10 }, s.Stopper()) for _, k := range []int{0, split1} { ent, err := ds.RangeDescriptorCache().Lookup(ctx, keys.MustAddr(key(k))) require.NoError(t, err) diff --git a/pkg/kv/kvclient/kvcoord/dist_sender.go b/pkg/kv/kvclient/kvcoord/dist_sender.go index b72ac9b30a60..2eb47402abfa 100644 --- a/pkg/kv/kvclient/kvcoord/dist_sender.go +++ b/pkg/kv/kvclient/kvcoord/dist_sender.go @@ -454,8 +454,7 @@ func NewDistSender(cfg DistSenderConfig) *DistSender { getRangeDescCacheSize := func() int64 { return rangeDescriptorCacheSize.Get(&ds.st.SV) } - ds.rangeCache = rangecache.NewRangeCache(ds.st, rdb, getRangeDescCacheSize, - cfg.RPCContext.Stopper, cfg.AmbientCtx.Tracer) + ds.rangeCache = rangecache.NewRangeCache(ds.st, rdb, getRangeDescCacheSize, cfg.RPCContext.Stopper) if tf := cfg.TestingKnobs.TransportFactory; tf != nil { ds.transportFactory = tf } else { diff --git a/pkg/kv/kvclient/kvcoord/dist_sender_test.go b/pkg/kv/kvclient/kvcoord/dist_sender_test.go index 283a0a75901b..cf603c60f2f2 100644 --- a/pkg/kv/kvclient/kvcoord/dist_sender_test.go +++ b/pkg/kv/kvclient/kvcoord/dist_sender_test.go @@ -5047,7 +5047,7 @@ func TestSendToReplicasSkipsStaleReplicas(t *testing.T) { getRangeDescCacheSize := func() int64 { return 1 << 20 } - rc := rangecache.NewRangeCache(st, nil /* db */, getRangeDescCacheSize, stopper, tr) + rc := rangecache.NewRangeCache(st, nil /* db */, getRangeDescCacheSize, stopper) rc.Insert(ctx, roachpb.RangeInfo{ Desc: tc.initialDesc, Lease: roachpb.Lease{ diff --git a/pkg/kv/kvclient/rangecache/BUILD.bazel b/pkg/kv/kvclient/rangecache/BUILD.bazel index 1c1dc6be4cfc..1c039f12e9c4 100644 --- a/pkg/kv/kvclient/rangecache/BUILD.bazel +++ b/pkg/kv/kvclient/rangecache/BUILD.bazel @@ -18,7 +18,6 @@ go_library( "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/syncutil/singleflight", - "//pkg/util/tracing", "@com_github_biogo_store//llrb", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_logtags//:logtags", diff --git a/pkg/kv/kvclient/rangecache/range_cache.go b/pkg/kv/kvclient/rangecache/range_cache.go index 52da49105acf..4d3f14f5465d 100644 --- a/pkg/kv/kvclient/rangecache/range_cache.go +++ b/pkg/kv/kvclient/rangecache/range_cache.go @@ -31,7 +31,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight" - "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/errors" "github.com/cockroachdb/logtags" ) @@ -124,7 +123,6 @@ type RangeDescriptorDB interface { type RangeCache struct { st *cluster.Settings stopper *stop.Stopper - tracer *tracing.Tracer // RangeDescriptorDB is used to retrieve range descriptors from the // database, which will be cached by this structure. db RangeDescriptorDB @@ -140,7 +138,7 @@ type RangeCache struct { // lookup requests for the same inferred range descriptor to be // multiplexed onto the same database lookup. See makeLookupRequestKey // for details on this inference. - lookupRequests singleflight.Group + lookupRequests *singleflight.Group // coalesced, if not nil, is sent on every time a request is coalesced onto // another in-flight one. 
Used by tests to block until a lookup request is @@ -223,13 +221,12 @@ func makeLookupRequestKey( // NewRangeCache returns a new RangeCache which uses the given RangeDescriptorDB // as the underlying source of range descriptors. func NewRangeCache( - st *cluster.Settings, - db RangeDescriptorDB, - size func() int64, - stopper *stop.Stopper, - tracer *tracing.Tracer, + st *cluster.Settings, db RangeDescriptorDB, size func() int64, stopper *stop.Stopper, ) *RangeCache { - rdc := &RangeCache{st: st, db: db, stopper: stopper, tracer: tracer} + rdc := &RangeCache{ + st: st, db: db, stopper: stopper, + lookupRequests: singleflight.NewGroup("range lookup", "lookup"), + } rdc.rangeCache.cache = cache.NewOrderedCache(cache.Config{ Policy: cache.CacheLRU, ShouldEvict: func(n int, _, _ interface{}) bool { @@ -797,23 +794,14 @@ func (rc *RangeCache) tryLookup( } } - // Fork a context with a new span before reqCtx is captured by the DoChan - // closure below; the parent span might get finished by the time the closure - // starts. In the "leader" case, the closure will take ownership of the new - // span. - reqCtx, reqSpan := tracing.EnsureChildSpan(ctx, rc.tracer, "range lookup") - resC, leader := rc.lookupRequests.DoChan(requestKey, func() (interface{}, error) { - defer reqSpan.Finish() - var lookupRes lookupResult - if err := rc.stopper.RunTaskWithErr(reqCtx, "rangecache: range lookup", func(ctx context.Context) (err error) { - // Clear the context's cancelation. This request services potentially many - // callers waiting for its result, and using the flight's leader's - // cancelation doesn't make sense. - ctx, cancel := rc.stopper.WithCancelOnQuiesce( - logtags.WithTags(context.Background(), logtags.FromContext(ctx))) - defer cancel() - ctx = tracing.ContextWithSpan(ctx, reqSpan) - + future, leader := rc.lookupRequests.DoChan(ctx, + requestKey, + singleflight.DoOpts{ + Stop: rc.stopper, + InheritCancelation: false, + }, + func(ctx context.Context) (interface{}, error) { + var lookupRes lookupResult // Attempt to perform the lookup by reading from a follower. If the // result is too old for the leader of this group, then we'll fall back // to reading from the leaseholder. Note that it's possible that the @@ -823,21 +811,24 @@ func (rc *RangeCache) tryLookup( // in that goroutine re-fetching. lookupRes.consistency = ReadFromFollower { + var err error lookupRes.EvictionToken, err = tryLookupImpl(ctx, rc, key, lookupRes.consistency, useReverseScan) - shouldReturnError := err != nil && !errors.Is(err, errFailedToFindNewerDescriptor) + if err != nil && !errors.Is(err, errFailedToFindNewerDescriptor) { + return nil, err + } gotFreshResult := err == nil && !lookupResultIsStale(lookupRes) - if shouldReturnError || gotFreshResult { - return err + if gotFreshResult { + return lookupRes, nil } } + var err error lookupRes.consistency = ReadFromLeaseholder lookupRes.EvictionToken, err = tryLookupImpl(ctx, rc, key, lookupRes.consistency, useReverseScan) - return err - }); err != nil { - return nil, err - } - return lookupRes, nil - }) + if err != nil { + return nil, err + } + return lookupRes, nil + }) // We must use DoChan above so that we can always unlock this mutex. This must // be done *after* the request has been added to the lookupRequests group, or @@ -849,19 +840,10 @@ func (rc *RangeCache) tryLookup( if rc.coalesced != nil { rc.coalesced <- struct{}{} } - // In the leader case, the callback takes ownership of reqSpan. 
If we're not - // the leader, we've created the span for no reason and have to finish it. - reqSpan.Finish() } // Wait for the inflight request. - var res singleflight.Result - select { - case res = <-resC: - case <-ctx.Done(): - return EvictionToken{}, errors.Wrap(ctx.Err(), "aborted during range descriptor lookup") - } - + res := future.WaitForResult(ctx) var s string if res.Err != nil { s = res.Err.Error() diff --git a/pkg/kv/kvclient/rangecache/range_cache_test.go b/pkg/kv/kvclient/rangecache/range_cache_test.go index f95a93de1aab..950434d37dc7 100644 --- a/pkg/kv/kvclient/rangecache/range_cache_test.go +++ b/pkg/kv/kvclient/rangecache/range_cache_test.go @@ -245,7 +245,6 @@ func staticSize(size int64) func() int64 { func initTestDescriptorDB(t *testing.T) *testDescriptorDB { st := cluster.MakeTestingClusterSettings() - tr := tracing.NewTracer() db := newTestDescriptorDB() for i, char := range "abcdefghijklmnopqrstuvwx" { // Create splits on each character: @@ -259,7 +258,7 @@ func initTestDescriptorDB(t *testing.T) *testDescriptorDB { } // TODO(andrei): don't leak this Stopper. Someone needs to Stop() it. db.stopper = stop.NewStopper() - db.cache = NewRangeCache(st, db, staticSize(2<<10), db.stopper, tr) + db.cache = NewRangeCache(st, db, staticSize(2<<10), db.stopper) return db } @@ -497,10 +496,9 @@ func TestLookupByKeyMin(t *testing.T) { ctx := context.Background() st := cluster.MakeTestingClusterSettings() - tr := tracing.NewTracer() stopper := stop.NewStopper() defer stopper.Stop(context.Background()) - cache := NewRangeCache(st, nil, staticSize(2<<10), stopper, tr) + cache := NewRangeCache(st, nil, staticSize(2<<10), stopper) startToMeta2Desc := roachpb.RangeDescriptor{ StartKey: roachpb.RKeyMin, EndKey: keys.RangeMetaKey(roachpb.RKey("a")), @@ -630,6 +628,11 @@ func TestRangeCacheContextCancellation(t *testing.T) { // Cancel the leader and check that it gets an error. cancel() expectContextCancellation(t, errC1) + select { + case err := <-errC2: + t.Fatalf("unexpected err: %v", err) + case <-time.After(time.Millisecond): + } // While lookups are still blocked, launch another one. This new request // should join the flight just like c2. @@ -905,7 +908,8 @@ func TestRangeCacheHandleDoubleSplit(t *testing.T) { var desc *roachpb.RangeDescriptor // Each request goes to a different key. var err error - ctx, getRecAndFinish := tracing.ContextWithRecordingSpan(ctx, db.cache.tracer, "test") + tracer := tracing.NewTracer() + ctx, getRecAndFinish := tracing.ContextWithRecordingSpan(ctx, tracer, "test") defer getRecAndFinish() tok, err := db.cache.lookupInternal( ctx, key, oldToken, reverseScan) @@ -1022,10 +1026,9 @@ func TestRangeCacheClearOverlapping(t *testing.T) { } st := cluster.MakeTestingClusterSettings() - tr := tracing.NewTracer() stopper := stop.NewStopper() defer stopper.Stop(ctx) - cache := NewRangeCache(st, nil, staticSize(2<<10), stopper, tr) + cache := NewRangeCache(st, nil, staticSize(2<<10), stopper) cache.addEntryLocked(&CacheEntry{desc: *defDesc}) // Now, add a new, overlapping set of descriptors. 
@@ -1199,10 +1202,9 @@ func TestRangeCacheClearOlderOverlapping(t *testing.T) { for _, tc := range testCases { t.Run("", func(t *testing.T) { st := cluster.MakeTestingClusterSettings() - tr := tracing.NewTracer() stopper := stop.NewStopper() defer stopper.Stop(ctx) - cache := NewRangeCache(st, nil /* db */, staticSize(2<<10), stopper, tr) + cache := NewRangeCache(st, nil /* db */, staticSize(2<<10), stopper) for _, d := range tc.cachedDescs { cache.Insert(ctx, roachpb.RangeInfo{Desc: d}) } @@ -1252,10 +1254,9 @@ func TestRangeCacheClearOverlappingMeta(t *testing.T) { } st := cluster.MakeTestingClusterSettings() - tr := tracing.NewTracer() stopper := stop.NewStopper() defer stopper.Stop(ctx) - cache := NewRangeCache(st, nil, staticSize(2<<10), stopper, tr) + cache := NewRangeCache(st, nil, staticSize(2<<10), stopper) cache.Insert(ctx, roachpb.RangeInfo{Desc: firstDesc}, roachpb.RangeInfo{Desc: restDesc}) @@ -1291,10 +1292,9 @@ func TestGetCachedRangeDescriptorInverted(t *testing.T) { } st := cluster.MakeTestingClusterSettings() - tr := tracing.NewTracer() stopper := stop.NewStopper() defer stopper.Stop(ctx) - cache := NewRangeCache(st, nil, staticSize(2<<10), stopper, tr) + cache := NewRangeCache(st, nil, staticSize(2<<10), stopper) for _, rd := range testData { cache.Insert(ctx, roachpb.RangeInfo{ Desc: rd, @@ -1430,10 +1430,9 @@ func TestRangeCacheGeneration(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { st := cluster.MakeTestingClusterSettings() - tr := tracing.NewTracer() stopper := stop.NewStopper() defer stopper.Stop(ctx) - cache := NewRangeCache(st, nil, staticSize(2<<10), stopper, tr) + cache := NewRangeCache(st, nil, staticSize(2<<10), stopper) cache.Insert(ctx, roachpb.RangeInfo{Desc: *descAM2}, roachpb.RangeInfo{Desc: *descMZ4}) cache.Insert(ctx, roachpb.RangeInfo{Desc: *tc.insertDesc}) @@ -1497,10 +1496,9 @@ func TestRangeCacheEvictAndReplace(t *testing.T) { startKey := desc1.StartKey st := cluster.MakeTestingClusterSettings() - tr := tracing.NewTracer() stopper := stop.NewStopper() defer stopper.Stop(ctx) - cache := NewRangeCache(st, nil, staticSize(2<<10), stopper, tr) + cache := NewRangeCache(st, nil, staticSize(2<<10), stopper) ri := roachpb.RangeInfo{ Desc: desc1, @@ -1634,10 +1632,9 @@ func TestRangeCacheSyncTokenAndMaybeUpdateCache(t *testing.T) { startKey := desc1.StartKey st := cluster.MakeTestingClusterSettings() - tr := tracing.NewTracer() stopper := stop.NewStopper() defer stopper.Stop(ctx) - cache := NewRangeCache(st, nil, staticSize(2<<10), stopper, tr) + cache := NewRangeCache(st, nil, staticSize(2<<10), stopper) const lag, lead = roachpb.LAG_BY_CLUSTER_SETTING, roachpb.LEAD_FOR_GLOBAL_READS testCases := []struct { name string diff --git a/pkg/kv/kvserver/BUILD.bazel b/pkg/kv/kvserver/BUILD.bazel index a8bc4dbbc856..a3de6e70326c 100644 --- a/pkg/kv/kvserver/BUILD.bazel +++ b/pkg/kv/kvserver/BUILD.bazel @@ -195,6 +195,7 @@ go_library( "//pkg/util/slidingwindow", "//pkg/util/stop", "//pkg/util/syncutil", + "//pkg/util/syncutil/singleflight", "//pkg/util/timeutil", "//pkg/util/tracing", "//pkg/util/tracing/tracingpb", diff --git a/pkg/kv/kvserver/liveness/liveness.go b/pkg/kv/kvserver/liveness/liveness.go index b32728e453e1..23aec0aa1ac3 100644 --- a/pkg/kv/kvserver/liveness/liveness.go +++ b/pkg/kv/kvserver/liveness/liveness.go @@ -201,7 +201,7 @@ type NodeLiveness struct { metrics Metrics onNodeDecommissioned func(livenesspb.Liveness) // noop if nil onNodeDecommissioning OnNodeDecommissionCallback // noop if nil - engineSyncs 
singleflight.Group + engineSyncs *singleflight.Group mu struct { syncutil.RWMutex @@ -302,6 +302,7 @@ func NewNodeLiveness(opts NodeLivenessOptions) *NodeLiveness { heartbeatToken: make(chan struct{}, 1), onNodeDecommissioned: opts.OnNodeDecommissioned, onNodeDecommissioning: opts.OnNodeDecommissioning, + engineSyncs: singleflight.NewGroup("engine sync", "engine"), } nl.metrics = Metrics{ LiveNodes: metric.NewFunctionalGauge(metaLiveNodes, nl.numLiveNodes), @@ -1270,24 +1271,23 @@ func (nl *NodeLiveness) updateLiveness( nl.mu.RLock() engines := nl.mu.engines nl.mu.RUnlock() - resultCs := make([]<-chan singleflight.Result, len(engines)) + resultCs := make([]singleflight.Future, len(engines)) for i, eng := range engines { eng := eng // pin the loop variable - resultCs[i], _ = nl.engineSyncs.DoChan(strconv.Itoa(i), func() (interface{}, error) { - return nil, nl.stopper.RunTaskWithErr(ctx, "liveness-hb-diskwrite", - func(ctx context.Context) error { - return storage.WriteSyncNoop(eng) - }) - }) + resultCs[i], _ = nl.engineSyncs.DoChan(ctx, + strconv.Itoa(i), + singleflight.DoOpts{ + Stop: nl.stopper, + InheritCancelation: false, + }, + func(ctx context.Context) (interface{}, error) { + return nil, storage.WriteSyncNoop(eng) + }) } for _, resultC := range resultCs { - select { - case r := <-resultC: - if r.Err != nil { - return Record{}, errors.Wrapf(r.Err, "disk write failed while updating node liveness") - } - case <-ctx.Done(): - return Record{}, ctx.Err() + r := resultC.WaitForResult(ctx) + if r.Err != nil { + return Record{}, errors.Wrapf(r.Err, "disk write failed while updating node liveness") } } diff --git a/pkg/kv/kvserver/protectedts/ptcache/cache.go b/pkg/kv/kvserver/protectedts/ptcache/cache.go index 6d277ffb6561..af1e2b6375d6 100644 --- a/pkg/kv/kvserver/protectedts/ptcache/cache.go +++ b/pkg/kv/kvserver/protectedts/ptcache/cache.go @@ -35,7 +35,7 @@ type Cache struct { storage protectedts.Storage stopper *stop.Stopper settings *cluster.Settings - sf singleflight.Group + sf *singleflight.Group mu struct { syncutil.RWMutex @@ -67,6 +67,7 @@ func New(config Config) *Cache { db: config.DB, storage: config.Storage, settings: config.Settings, + sf: singleflight.NewGroup("refresh-protectedts-cache", singleflight.NoTags), } c.mu.recordsByID = make(map[uuid.UUID]*ptpb.Record) return c @@ -113,14 +114,17 @@ const refreshKey = "" // Refresh is part of the protectedts.Cache interface. func (c *Cache) Refresh(ctx context.Context, asOf hlc.Timestamp) error { for !c.upToDate(asOf) { - ch, _ := c.sf.DoChan(refreshKey, c.doSingleFlightUpdate) - select { - case <-ctx.Done(): - return ctx.Err() - case res := <-ch: - if res.Err != nil { - return res.Err - } + future, _ := c.sf.DoChan(ctx, + refreshKey, + singleflight.DoOpts{ + Stop: c.stopper, + InheritCancelation: false, + }, + c.doSingleFlightUpdate, + ) + res := future.WaitForResult(ctx) + if res.Err != nil { + return res.Err } } return nil @@ -168,7 +172,7 @@ func (c *Cache) periodicallyRefreshProtectedtsCache(ctx context.Context) { defer timer.Stop() timer.Reset(0) // Read immediately upon startup var lastReset time.Time - var doneCh <-chan singleflight.Result + var future singleflight.Future // TODO(ajwerner): consider resetting the timer when the state is updated // due to a call to Refresh. for { @@ -176,7 +180,14 @@ func (c *Cache) periodicallyRefreshProtectedtsCache(ctx context.Context) { case <-timer.C: // Let's not reset the timer until we get our response. 
timer.Read = true - doneCh, _ = c.sf.DoChan(refreshKey, c.doSingleFlightUpdate) + future, _ = c.sf.DoChan(ctx, + refreshKey, + singleflight.DoOpts{ + Stop: c.stopper, + InheritCancelation: false, + }, + c.doSingleFlightUpdate, + ) case <-settingChanged: if timer.Read { // we're currently fetching continue @@ -187,12 +198,14 @@ func (c *Cache) periodicallyRefreshProtectedtsCache(ctx context.Context) { nextUpdate := interval - timeutil.Since(lastReset) timer.Reset(nextUpdate) lastReset = timeutil.Now() - case res := <-doneCh: + case <-future.C(): + res := future.WaitForResult(ctx) if res.Err != nil { if ctx.Err() == nil { log.Errorf(ctx, "failed to refresh protected timestamps: %v", res.Err) } } + future.Reset() timer.Reset(protectedts.PollInterval.Get(&c.settings.SV)) lastReset = timeutil.Now() case <-c.stopper.ShouldQuiesce(): @@ -201,21 +214,13 @@ func (c *Cache) periodicallyRefreshProtectedtsCache(ctx context.Context) { } } -func (c *Cache) doSingleFlightUpdate() (interface{}, error) { - // TODO(ajwerner): add log tags to the context. - ctx, cancel := c.stopper.WithCancelOnQuiesce(context.Background()) - defer cancel() - return nil, c.stopper.RunTaskWithErr(ctx, - "refresh-protectedts-cache", c.doUpdate) -} - func (c *Cache) getMetadata() ptpb.Metadata { c.mu.RLock() defer c.mu.RUnlock() return c.mu.state.Metadata } -func (c *Cache) doUpdate(ctx context.Context) error { +func (c *Cache) doSingleFlightUpdate(ctx context.Context) (interface{}, error) { // NB: doUpdate is only ever called underneath c.singleFlight and thus is // never called concurrently. Due to the lack of concurrency there are no // concerns about races as this is the only method which writes to the Cache's @@ -248,7 +253,7 @@ func (c *Cache) doUpdate(ctx context.Context) error { return nil }) if err != nil { - return err + return nil, err } c.mu.Lock() defer c.mu.Unlock() @@ -263,7 +268,7 @@ func (c *Cache) doUpdate(ctx context.Context) error { c.mu.recordsByID[r.ID.GetUUID()] = r } } - return nil + return nil, nil } // upToDate returns true if the lastUpdate for the cache is at least asOf. 
diff --git a/pkg/kv/kvserver/protectedts/ptcache/cache_test.go b/pkg/kv/kvserver/protectedts/ptcache/cache_test.go index adc2ec298e14..6ba7c63662ad 100644 --- a/pkg/kv/kvserver/protectedts/ptcache/cache_test.go +++ b/pkg/kv/kvserver/protectedts/ptcache/cache_test.go @@ -176,7 +176,7 @@ func TestRefresh(t *testing.T) { return nil }) go func() { time.Sleep(time.Millisecond); cancel() }() - require.EqualError(t, c.Refresh(withCancel, s.Clock().Now()), + require.ErrorContains(t, c.Refresh(withCancel, s.Clock().Now()), context.Canceled.Error()) close(done) }) diff --git a/pkg/kv/kvserver/rangefeed/BUILD.bazel b/pkg/kv/kvserver/rangefeed/BUILD.bazel index bac4f6ebbc90..084b10d96504 100644 --- a/pkg/kv/kvserver/rangefeed/BUILD.bazel +++ b/pkg/kv/kvserver/rangefeed/BUILD.bazel @@ -33,7 +33,6 @@ go_library( "//pkg/util/retry", "//pkg/util/stop", "//pkg/util/syncutil", - "//pkg/util/syncutil/singleflight", "//pkg/util/timeutil", "//pkg/util/uuid", "@com_github_cockroachdb_errors//:errors", diff --git a/pkg/kv/kvserver/rangefeed/metrics.go b/pkg/kv/kvserver/rangefeed/metrics.go index 280fcf6a41f8..258f950c9124 100644 --- a/pkg/kv/kvserver/rangefeed/metrics.go +++ b/pkg/kv/kvserver/rangefeed/metrics.go @@ -15,7 +15,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/metric" - "github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight" ) var ( @@ -45,8 +44,7 @@ type Metrics struct { RangeFeedBudgetExhausted *metric.Counter RangeFeedBudgetBlocked *metric.Counter - RangeFeedSlowClosedTimestampLogN log.EveryN - RangeFeedSlowClosedTimestampNudge singleflight.Group + RangeFeedSlowClosedTimestampLogN log.EveryN // RangeFeedSlowClosedTimestampNudgeSem bounds the amount of work that can be // spun up on behalf of the RangeFeed nudger. We don't expect to hit this // limit, but it's here to limit the effect on stability in case something diff --git a/pkg/kv/kvserver/replica_rangefeed.go b/pkg/kv/kvserver/replica_rangefeed.go index d1e1abd37a23..5ad89fe8e095 100644 --- a/pkg/kv/kvserver/replica_rangefeed.go +++ b/pkg/kv/kvserver/replica_rangefeed.go @@ -32,8 +32,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight" "github.com/cockroachdb/cockroach/pkg/util/timeutil" - "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" ) @@ -721,37 +721,31 @@ func (r *Replica) handleClosedTimestampUpdateRaftMuLocked( } // Asynchronously attempt to nudge the closed timestamp in case it's stuck. - key := fmt.Sprintf(`rangefeed-slow-closed-timestamp-nudge-r%d`, r.RangeID) + key := fmt.Sprintf(`r%d`, r.RangeID) // Ignore the result of DoChan since, to keep this all async, it always // returns nil and any errors are logged by the closure passed to the // `DoChan` call. - taskCtx, sp := tracing.EnsureForkSpan(ctx, r.AmbientContext.Tracer, key) - _, leader := m.RangeFeedSlowClosedTimestampNudge.DoChan(key, func() (interface{}, error) { - defer sp.Finish() - // Also ignore the result of RunTask, since it only returns errors when - // the task didn't start because we're shutting down. 
- _ = r.store.stopper.RunTask(taskCtx, key, func(ctx context.Context) { + _, _ = r.store.rangeFeedSlowClosedTimestampNudge.DoChan(ctx, + key, + singleflight.DoOpts{ + Stop: r.store.stopper, + InheritCancelation: false, + }, + func(ctx context.Context) (interface{}, error) { // Limit the amount of work this can suddenly spin up. In particular, // this is to protect against the case of a system-wide slowdown on // closed timestamps, which would otherwise potentially launch a huge // number of lease acquisitions all at once. select { case m.RangeFeedSlowClosedTimestampNudgeSem <- struct{}{}: - case <-r.store.stopper.ShouldQuiesce(): - return + case <-ctx.Done(): } defer func() { <-m.RangeFeedSlowClosedTimestampNudgeSem }() if err := r.ensureClosedTimestampStarted(ctx); err != nil { log.Infof(ctx, `RangeFeed failed to nudge: %s`, err) } + return nil, nil }) - return nil, nil - }) - if !leader { - // In the leader case, we've passed ownership of sp to the task. If the - // task was not triggered, though, it's up to us to Finish() it. - sp.Finish() - } } // If the closed timestamp is not empty, inform the Processor. diff --git a/pkg/kv/kvserver/store.go b/pkg/kv/kvserver/store.go index c7d4e25ccb60..1c08bbe76e6c 100644 --- a/pkg/kv/kvserver/store.go +++ b/pkg/kv/kvserver/store.go @@ -80,6 +80,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/shuffle" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" @@ -994,6 +995,8 @@ type Store struct { computeInitialMetrics sync.Once systemConfigUpdateQueueRateLimiter *quotapool.RateLimiter spanConfigUpdateQueueRateLimiter *quotapool.RateLimiter + + rangeFeedSlowClosedTimestampNudge *singleflight.Group } var _ kv.Sender = &Store{} @@ -1214,13 +1217,14 @@ func NewStore( iot := ioThresholds{} iot.Replace(nil, 1.0) // init as empty s := &Store{ - cfg: cfg, - db: cfg.DB, // TODO(tschottdorf): remove redundancy. - engine: eng, - nodeDesc: nodeDesc, - metrics: newStoreMetrics(cfg.HistogramWindowInterval), - ctSender: cfg.ClosedTimestampSender, - ioThresholds: &iot, + cfg: cfg, + db: cfg.DB, // TODO(tschottdorf): remove redundancy. + engine: eng, + nodeDesc: nodeDesc, + metrics: newStoreMetrics(cfg.HistogramWindowInterval), + ctSender: cfg.ClosedTimestampSender, + ioThresholds: &iot, + rangeFeedSlowClosedTimestampNudge: singleflight.NewGroup("rangfeed-ct-nudge", "range"), } s.ioThreshold.t = &admissionpb.IOThreshold{} var allocatorStorePool storepool.AllocatorStorePool diff --git a/pkg/kv/kvserver/txnrecovery/manager.go b/pkg/kv/kvserver/txnrecovery/manager.go index 4406ba532006..ba22fe197380 100644 --- a/pkg/kv/kvserver/txnrecovery/manager.go +++ b/pkg/kv/kvserver/txnrecovery/manager.go @@ -63,7 +63,7 @@ type manager struct { db *kv.DB stopper *stop.Stopper metrics Metrics - txns singleflight.Group + txns *singleflight.Group sem chan struct{} } @@ -76,6 +76,7 @@ func NewManager(ac log.AmbientContext, clock *hlc.Clock, db *kv.DB, stopper *sto db: db, stopper: stopper, metrics: makeMetrics(), + txns: singleflight.NewGroup("resolve indeterminate commit", "txn"), sem: make(chan struct{}, defaultTaskLimit), } } @@ -92,23 +93,23 @@ func (m *manager) ResolveIndeterminateCommit( // Launch a single-flight task to recover the transaction. 
This may be // coalesced with other recovery attempts for the same transaction. log.VEventf(ctx, 2, "recovering txn %s from indeterminate commit", txn.ID.Short()) - resC, _ := m.txns.DoChan(txn.ID.String(), func() (interface{}, error) { - return m.resolveIndeterminateCommitForTxn(txn) - }) - - // Wait for the inflight request. - select { - case res := <-resC: - if res.Err != nil { - log.VEventf(ctx, 2, "recovery error: %v", res.Err) - return nil, res.Err - } - txn := res.Val.(*roachpb.Transaction) - log.VEventf(ctx, 2, "recovered txn %s with status: %s", txn.ID.Short(), txn.Status) - return txn, nil - case <-ctx.Done(): - return nil, errors.Wrap(ctx.Err(), "abandoned indeterminate commit recovery") + future, _ := m.txns.DoChan(ctx, + txn.ID.String(), + singleflight.DoOpts{ + InheritCancelation: false, + Stop: m.stopper, + }, + func(ctx context.Context) (interface{}, error) { + return m.resolveIndeterminateCommitForTxn(ctx, txn) + }) + res := future.WaitForResult(ctx) + if res.Err != nil { + log.VEventf(ctx, 2, "recovery error: %v", res.Err) + return nil, errors.Wrap(res.Err, "failed indeterminate commit recovery") } + txn = res.Val.(*roachpb.Transaction) + log.VEventf(ctx, 2, "recovered txn %s with status: %s", txn.ID.Short(), txn.Status) + return txn, nil } // resolveIndeterminateCommitForTxn attempts to to resolve the status of @@ -120,48 +121,36 @@ func (m *manager) ResolveIndeterminateCommit( // is determined, the method issues a RecoverTxn request with a summary of their // outcome. func (m *manager) resolveIndeterminateCommitForTxn( - txn *roachpb.Transaction, + ctx context.Context, txn *roachpb.Transaction, ) (resTxn *roachpb.Transaction, resErr error) { // Record the recovery attempt in the Manager's metrics. onComplete := m.updateMetrics() defer func() { onComplete(resTxn, resErr) }() - // TODO(nvanbenschoten): Set up tracing. - ctx := m.AnnotateCtx(context.Background()) - - // Launch the recovery task. - resErr = m.stopper.RunTaskWithErr(ctx, - "recovery.manager: resolving indeterminate commit", - func(ctx context.Context) error { - // Grab semaphore with defaultTaskLimit. - select { - case m.sem <- struct{}{}: - defer func() { <-m.sem }() - case <-m.stopper.ShouldQuiesce(): - return stop.ErrUnavailable - } + // Grab semaphore with defaultTaskLimit. + select { + case m.sem <- struct{}{}: + defer func() { <-m.sem }() + case <-ctx.Done(): + return nil, ctx.Err() + } - // We probe to determine whether the transaction is implicitly - // committed or not. If not, we prevent it from ever becoming - // implicitly committed at this (epoch, timestamp) pair. - preventedIntent, changedTxn, err := m.resolveIndeterminateCommitForTxnProbe(ctx, txn) - if err != nil { - return err - } - if changedTxn != nil { - resTxn = changedTxn - return nil - } + // We probe to determine whether the transaction is implicitly + // committed or not. If not, we prevent it from ever becoming + // implicitly committed at this (epoch, timestamp) pair. + preventedIntent, changedTxn, err := m.resolveIndeterminateCommitForTxnProbe(ctx, txn) + if err != nil { + return nil, err + } + if changedTxn != nil { + return changedTxn, nil + } - // Now that we know whether the transaction was implicitly committed - // or not (implicitly committed = !preventedIntent), we attempt to - // recover it. If this succeeds, it will either move the transaction - // record to a COMMITTED or ABORTED status. 
- resTxn, err = m.resolveIndeterminateCommitForTxnRecover(ctx, txn, preventedIntent) - return err - }, - ) - return resTxn, resErr + // Now that we know whether the transaction was implicitly committed + // or not (implicitly committed = !preventedIntent), we attempt to + // recover it. If this succeeds, it will either move the transaction + // record to a COMMITTED or ABORTED status. + return m.resolveIndeterminateCommitForTxnRecover(ctx, txn, preventedIntent) } // resolveIndeterminateCommitForTxnProbe performs the "probing phase" of the diff --git a/pkg/sql/authorization.go b/pkg/sql/authorization.go index b6dc3f9c8926..18f6d5154b6a 100644 --- a/pkg/sql/authorization.go +++ b/pkg/sql/authorization.go @@ -45,7 +45,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight" "github.com/cockroachdb/errors" - "github.com/cockroachdb/logtags" ) // MembershipCache is a shared cache for role membership information. @@ -57,15 +56,16 @@ type MembershipCache struct { userCache map[username.SQLUsername]userRoleMembership // populateCacheGroup ensures that there is at most one request in-flight // for each key. - populateCacheGroup singleflight.Group + populateCacheGroup *singleflight.Group stopper *stop.Stopper } // NewMembershipCache initializes a new MembershipCache. func NewMembershipCache(account mon.BoundAccount, stopper *stop.Stopper) *MembershipCache { return &MembershipCache{ - boundAccount: account, - stopper: stopper, + boundAccount: account, + populateCacheGroup: singleflight.NewGroup("lookup role membership", "key"), + stopper: stopper, } } @@ -510,31 +510,24 @@ func MemberOfWithAdminOption( // in-flight for each user. The role_memberships table version is also part // of the request key so that we don't read data from an old version of the // table. - ch, _ := roleMembersCache.populateCacheGroup.DoChan( + future, _ := roleMembersCache.populateCacheGroup.DoChan(ctx, fmt.Sprintf("%s-%d", member.Normalized(), tableVersion), - func() (interface{}, error) { - // Use a different context to fetch, so that it isn't possible for - // one query to timeout and cause all the goroutines that are waiting - // to get a timeout error. - ctx, cancel := roleMembersCache.stopper.WithCancelOnQuiesce( - logtags.WithTags(context.Background(), logtags.FromContext(ctx))) - defer cancel() + singleflight.DoOpts{ + Stop: roleMembersCache.stopper, + InheritCancelation: false, + }, + func(ctx context.Context) (interface{}, error) { return resolveMemberOfWithAdminOption( ctx, member, ie, txn, useSingleQueryForRoleMembershipCache.Get(execCfg.SV()), ) - }, - ) + }) var memberships map[username.SQLUsername]bool - select { - case res := <-ch: - if res.Err != nil { - return nil, res.Err - } - memberships = res.Val.(map[username.SQLUsername]bool) - case <-ctx.Done(): - return nil, ctx.Err() + res := future.WaitForResult(ctx) + if res.Err != nil { + return nil, res.Err } + memberships = res.Val.(map[username.SQLUsername]bool) func() { // Update membership if the table version hasn't changed. 
diff --git a/pkg/sql/cacheutil/BUILD.bazel b/pkg/sql/cacheutil/BUILD.bazel index 4a3216ce92d2..b1409dcd6605 100644 --- a/pkg/sql/cacheutil/BUILD.bazel +++ b/pkg/sql/cacheutil/BUILD.bazel @@ -14,7 +14,6 @@ go_library( "//pkg/util/syncutil", "//pkg/util/syncutil/singleflight", "@com_github_cockroachdb_errors//:errors", - "@com_github_cockroachdb_logtags//:logtags", ], ) diff --git a/pkg/sql/cacheutil/cache.go b/pkg/sql/cacheutil/cache.go index 3f1cf682c123..e3ead6ebb9c4 100644 --- a/pkg/sql/cacheutil/cache.go +++ b/pkg/sql/cacheutil/cache.go @@ -21,7 +21,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight" "github.com/cockroachdb/errors" - "github.com/cockroachdb/logtags" ) // Cache is a shared cache for hashed passwords and other information used @@ -32,7 +31,7 @@ type Cache struct { tableVersions []descpb.DescriptorVersion // TODO(richardjcai): In go1.18 we can use generics. cache map[interface{}]interface{} - populateCacheGroup singleflight.Group + populateCacheGroup *singleflight.Group stopper *stop.Stopper } @@ -42,9 +41,10 @@ type Cache struct { func NewCache(account mon.BoundAccount, stopper *stop.Stopper, numSystemTables int) *Cache { tableVersions := make([]descpb.DescriptorVersion, numSystemTables) return &Cache{ - tableVersions: tableVersions, - boundAccount: account, - stopper: stopper, + tableVersions: tableVersions, + boundAccount: account, + populateCacheGroup: singleflight.NewGroup("load-value", "key"), + stopper: stopper, } } @@ -61,25 +61,19 @@ func (c *Cache) GetValueLocked(key interface{}) (interface{}, bool) { func (c *Cache) LoadValueOutsideOfCacheSingleFlight( ctx context.Context, requestKey string, fn func(loadCtx context.Context) (interface{}, error), ) (interface{}, error) { - ch, _ := c.populateCacheGroup.DoChan(requestKey, func() (interface{}, error) { - // Use a different context to fetch, so that it isn't possible for - // one query to timeout and cause all the goroutines that are waiting - // to get a timeout error. - loadCtx, cancel := c.stopper.WithCancelOnQuiesce( - logtags.WithTags(context.Background(), logtags.FromContext(ctx)), - ) - defer cancel() - return fn(loadCtx) - }) - select { - case res := <-ch: - if res.Err != nil { - return nil, res.Err - } - return res.Val, nil - case <-ctx.Done(): - return nil, ctx.Err() + future, _ := c.populateCacheGroup.DoChan(ctx, + requestKey, + singleflight.DoOpts{ + Stop: c.stopper, + InheritCancelation: false, + }, + fn, + ) + res := future.WaitForResult(ctx) + if res.Err != nil { + return nil, res.Err } + return res.Val, nil } // MaybeWriteBackToCache tries to put the key, value into the diff --git a/pkg/sql/catalog/hydrateddesc/hydratedcache.go b/pkg/sql/catalog/hydrateddesc/hydratedcache.go index 3769172264ca..a6d4b734235f 100644 --- a/pkg/sql/catalog/hydrateddesc/hydratedcache.go +++ b/pkg/sql/catalog/hydrateddesc/hydratedcache.go @@ -50,7 +50,7 @@ import ( // and re-construct the immutable descriptors in most cases. type Cache struct { settings *cluster.Settings - g singleflight.Group + g *singleflight.Group metrics Metrics mu struct { syncutil.Mutex @@ -160,6 +160,7 @@ var CacheSize = settings.RegisterIntSetting( func NewCache(settings *cluster.Settings) *Cache { c := &Cache{ settings: settings, + g: singleflight.NewGroup("get-hydrated-descriptor", "key"), metrics: makeMetrics(), } c.mu.cache = cache.NewOrderedCache(cache.Config{ @@ -349,7 +350,7 @@ func (c *Cache) GetHydratedDescriptor( // should return it. 
Other goroutines will have to go back around and // attempt to read from the cache. var called bool - res, _, err := c.g.Do(groupKey, func() (interface{}, error) { + res, _, err := c.g.Do(ctx, groupKey, func(ctx context.Context) (interface{}, error) { called = true cachedRes := cachedTypeDescriptorResolver{ underlying: res, diff --git a/pkg/sql/catalog/lease/lease.go b/pkg/sql/catalog/lease/lease.go index d9aa5af1e87a..cc918d7bf3ef 100644 --- a/pkg/sql/catalog/lease/lease.go +++ b/pkg/sql/catalog/lease/lease.go @@ -15,6 +15,7 @@ import ( "context" "fmt" "sort" + "strconv" "sync" "sync/atomic" "time" @@ -493,55 +494,50 @@ func acquireNodeLease( ctx context.Context, m *Manager, id descpb.ID, typ AcquireType, ) (bool, error) { start := timeutil.Now() - log.VEventf(ctx, 2, "acquiring lease for descriptor %d", id) + log.VEventf(ctx, 2, "acquiring lease for descriptor %d...", id) var toRelease *storedLease - resultChan, didAcquire := m.storage.group.DoChan(fmt.Sprintf("acquire%d", id), func() (interface{}, error) { - // Note that we use a new `context` here to avoid a situation where a cancellation - // of the first context cancels other callers to the `acquireNodeLease()` method, - // because of its use of `singleflight.Group`. See issue #41780 for how this has - // happened. - lt := logtags.FromContext(ctx) - ctx, cancel := m.stopper.WithCancelOnQuiesce(logtags.AddTags(m.ambientCtx.AnnotateCtx(context.Background()), lt)) - defer cancel() - if m.IsDraining() { - return nil, errors.New("cannot acquire lease when draining") - } - newest := m.findNewest(id) - var minExpiration hlc.Timestamp - if newest != nil { - minExpiration = newest.getExpiration() - } - desc, expiration, regionPrefix, err := m.storage.acquire(ctx, minExpiration, id) - if err != nil { - return nil, err - } - t := m.findDescriptorState(id, false /* create */) - t.mu.Lock() - t.mu.takenOffline = false - defer t.mu.Unlock() - var newDescVersionState *descriptorVersionState - newDescVersionState, toRelease, err = t.upsertLeaseLocked(ctx, desc, expiration, regionPrefix) - if err != nil { - return nil, err - } - if newDescVersionState != nil { - m.names.insert(newDescVersionState) - } - if toRelease != nil { - releaseLease(ctx, toRelease, m) - } - return true, nil - }) + future, didAcquire := m.storage.group.DoChan(ctx, + strconv.Itoa(int(id)), + singleflight.DoOpts{ + Stop: m.stopper, + InheritCancelation: false, + }, + func(ctx context.Context) (interface{}, error) { + if m.IsDraining() { + return nil, errors.New("cannot acquire lease when draining") + } + newest := m.findNewest(id) + var minExpiration hlc.Timestamp + if newest != nil { + minExpiration = newest.getExpiration() + } + desc, expiration, regionPrefix, err := m.storage.acquire(ctx, minExpiration, id) + if err != nil { + return nil, err + } + t := m.findDescriptorState(id, false /* create */) + t.mu.Lock() + t.mu.takenOffline = false + defer t.mu.Unlock() + var newDescVersionState *descriptorVersionState + newDescVersionState, toRelease, err = t.upsertLeaseLocked(ctx, desc, expiration, regionPrefix) + if err != nil { + return nil, err + } + if newDescVersionState != nil { + m.names.insert(newDescVersionState) + } + if toRelease != nil { + releaseLease(ctx, toRelease, m) + } + return true, nil + }) if m.testingKnobs.LeaseStoreTestingKnobs.LeaseAcquireResultBlockEvent != nil { m.testingKnobs.LeaseStoreTestingKnobs.LeaseAcquireResultBlockEvent(typ, id) } - select { - case <-ctx.Done(): - return false, ctx.Err() - case result := <-resultChan: - if result.Err != nil { - 
return false, result.Err - } + result := future.WaitForResult(ctx) + if result.Err != nil { + return false, result.Err } log.VEventf(ctx, 2, "acquired lease for descriptor %d, took %v", id, timeutil.Since(start)) return didAcquire, nil @@ -720,7 +716,7 @@ func NewLeaseManager( internalExecutor: internalExecutor, settings: settings, codec: codec, - group: &singleflight.Group{}, + group: singleflight.NewGroup("acquire-lease", "descriptor ID"), testingKnobs: testingKnobs.LeaseStoreTestingKnobs, outstandingLeases: metric.NewGauge(metric.Metadata{ Name: "sql.leases.active", diff --git a/pkg/sql/distsql_physical_planner_test.go b/pkg/sql/distsql_physical_planner_test.go index 7bab41625173..8a30c596187a 100644 --- a/pkg/sql/distsql_physical_planner_test.go +++ b/pkg/sql/distsql_physical_planner_test.go @@ -54,7 +54,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/metric" "github.com/cockroachdb/cockroach/pkg/util/randutil" "github.com/cockroachdb/cockroach/pkg/util/stop" - "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/errors" "github.com/stretchr/testify/require" ) @@ -194,10 +193,9 @@ func TestDistSQLReceiverUpdatesCaches(t *testing.T) { size := func() int64 { return 2 << 10 } st := cluster.MakeTestingClusterSettings() - tr := tracing.NewTracer() stopper := stop.NewStopper() defer stopper.Stop(ctx) - rangeCache := rangecache.NewRangeCache(st, nil /* db */, size, stopper, tr) + rangeCache := rangecache.NewRangeCache(st, nil /* db */, size, stopper) r := MakeDistSQLReceiver( ctx, &errOnlyResultWriter{}, /* resultWriter */ diff --git a/pkg/sql/opt/exec/execbuilder/testdata/upsert b/pkg/sql/opt/exec/execbuilder/testdata/upsert index 855645857de4..8ddd03c4ed09 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/upsert +++ b/pkg/sql/opt/exec/execbuilder/testdata/upsert @@ -672,6 +672,11 @@ statement ok CREATE TABLE tu (a INT PRIMARY KEY, b INT, c INT, d INT, FAMILY (a), FAMILY (b), FAMILY (c,d)); INSERT INTO tu VALUES (1, 2, 3, 4) +# Force the leasing of the `tu` descriptor, otherwise the trace below is +# polluted. 
+statement ok +SELECT 1 FROM tu; + statement ok SET tracing = on,kv,results; UPSERT INTO tu VALUES (1, NULL, NULL, NULL); SET tracing = off diff --git a/pkg/sql/rowexec/tablereader_test.go b/pkg/sql/rowexec/tablereader_test.go index 05e9fa650b05..3de63778ab0b 100644 --- a/pkg/sql/rowexec/tablereader_test.go +++ b/pkg/sql/rowexec/tablereader_test.go @@ -132,8 +132,10 @@ func TestTableReader(t *testing.T) { Mon: evalCtx.TestingMon, Cfg: &execinfra.ServerConfig{ Settings: st, - RangeCache: rangecache.NewRangeCache(s.ClusterSettings(), nil, - func() int64 { return 2 << 10 }, s.Stopper(), s.TracerI().(*tracing.Tracer)), + RangeCache: rangecache.NewRangeCache( + s.ClusterSettings(), nil, + func() int64 { return 2 << 10 }, s.Stopper(), + ), }, Txn: kv.NewTxn(ctx, s.DB(), s.NodeID()), NodeID: evalCtx.NodeID, @@ -383,8 +385,10 @@ func TestLimitScans(t *testing.T) { Mon: evalCtx.TestingMon, Cfg: &execinfra.ServerConfig{ Settings: st, - RangeCache: rangecache.NewRangeCache(s.ClusterSettings(), nil, - func() int64 { return 2 << 10 }, s.Stopper(), s.TracerI().(*tracing.Tracer)), + RangeCache: rangecache.NewRangeCache( + s.ClusterSettings(), nil, + func() int64 { return 2 << 10 }, s.Stopper(), + ), }, Txn: kv.NewTxn(ctx, kvDB, s.NodeID()), NodeID: evalCtx.NodeID, @@ -497,8 +501,10 @@ func BenchmarkTableReader(b *testing.B) { Mon: evalCtx.TestingMon, Cfg: &execinfra.ServerConfig{ Settings: st, - RangeCache: rangecache.NewRangeCache(s.ClusterSettings(), nil, - func() int64 { return 2 << 10 }, s.Stopper(), s.TracerI().(*tracing.Tracer)), + RangeCache: rangecache.NewRangeCache( + s.ClusterSettings(), nil, + func() int64 { return 2 << 10 }, s.Stopper(), + ), }, Txn: kv.NewTxn(ctx, s.DB(), s.NodeID()), NodeID: evalCtx.NodeID, diff --git a/pkg/sql/sessioninit/BUILD.bazel b/pkg/sql/sessioninit/BUILD.bazel index d2a379ba1a74..dcaff275f8a7 100644 --- a/pkg/sql/sessioninit/BUILD.bazel +++ b/pkg/sql/sessioninit/BUILD.bazel @@ -24,7 +24,6 @@ go_library( "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/syncutil/singleflight", - "@com_github_cockroachdb_logtags//:logtags", ], ) diff --git a/pkg/sql/sessioninit/cache.go b/pkg/sql/sessioninit/cache.go index 7c0d48f4fc1d..9a7f2107393c 100644 --- a/pkg/sql/sessioninit/cache.go +++ b/pkg/sql/sessioninit/cache.go @@ -29,7 +29,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight" - "github.com/cockroachdb/logtags" ) // CacheEnabledSettingName is the name of the CacheEnabled cluster setting. @@ -59,7 +58,7 @@ type Cache struct { settingsCache map[SettingsCacheKey][]string // populateCacheGroup is used to ensure that there is at most one in-flight // request for populating each cache entry. - populateCacheGroup singleflight.Group + populateCacheGroup *singleflight.Group stopper *stop.Stopper } @@ -93,8 +92,9 @@ type SettingsCacheEntry struct { // NewCache initializes a new sessioninit.Cache. 
func NewCache(account mon.BoundAccount, stopper *stop.Stopper) *Cache { return &Cache{ - boundAccount: account, - stopper: stopper, + boundAccount: account, + populateCacheGroup: singleflight.NewGroup("load-value", "key"), + stopper: stopper, } } @@ -208,25 +208,19 @@ func (a *Cache) readAuthInfoFromCache( func (a *Cache) loadValueOutsideOfCache( ctx context.Context, requestKey string, fn func(loadCtx context.Context) (interface{}, error), ) (interface{}, error) { - ch, _ := a.populateCacheGroup.DoChan(requestKey, func() (interface{}, error) { - // Use a different context to fetch, so that it isn't possible for - // one query to timeout and cause all the goroutines that are waiting - // to get a timeout error. - loadCtx, cancel := a.stopper.WithCancelOnQuiesce( - logtags.WithTags(context.Background(), logtags.FromContext(ctx)), - ) - defer cancel() - return fn(loadCtx) - }) - select { - case res := <-ch: - if res.Err != nil { - return AuthInfo{}, res.Err - } - return res.Val, nil - case <-ctx.Done(): - return AuthInfo{}, ctx.Err() + future, _ := a.populateCacheGroup.DoChan(ctx, + requestKey, + singleflight.DoOpts{ + Stop: a.stopper, + InheritCancelation: false, + }, + fn, + ) + res := future.WaitForResult(ctx) + if res.Err != nil { + return AuthInfo{}, res.Err } + return res.Val, nil } // maybeWriteAuthInfoBackToCache tries to put the fetched AuthInfo into the diff --git a/pkg/sql/sqlliveness/slstorage/BUILD.bazel b/pkg/sql/sqlliveness/slstorage/BUILD.bazel index 1f8a7d824e38..1c3921e11230 100644 --- a/pkg/sql/sqlliveness/slstorage/BUILD.bazel +++ b/pkg/sql/sqlliveness/slstorage/BUILD.bazel @@ -34,7 +34,6 @@ go_library( "//pkg/util/timeutil", "//pkg/util/uuid", "@com_github_cockroachdb_errors//:errors", - "@com_github_cockroachdb_logtags//:logtags", "@com_github_cockroachdb_redact//:redact", "@com_github_prometheus_client_model//go", ], diff --git a/pkg/sql/sqlliveness/slstorage/slstorage.go b/pkg/sql/sqlliveness/slstorage/slstorage.go index 52ba649d64cb..3f55f6b46bd5 100644 --- a/pkg/sql/sqlliveness/slstorage/slstorage.go +++ b/pkg/sql/sqlliveness/slstorage/slstorage.go @@ -33,7 +33,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" - "github.com/cockroachdb/logtags" "github.com/cockroachdb/redact" ) @@ -79,8 +78,6 @@ var CacheSize = settings.RegisterIntSetting( // Storage deals with reading and writing session records. It implements the // sqlliveness.Reader interface, and the slinstace.Writer interface. type Storage struct { - log.AmbientContext - settings *cluster.Settings stopper *stop.Stopper clock *hlc.Clock @@ -94,7 +91,7 @@ type Storage struct { mu struct { syncutil.Mutex - g singleflight.Group + g *singleflight.Group started bool // liveSessions caches the current view of expirations of live sessions. @@ -124,8 +121,6 @@ func NewTestingStorage( newTimer func() timeutil.TimerI, ) *Storage { s := &Storage{ - AmbientContext: ambientCtx, - settings: settings, stopper: stopper, clock: clock, @@ -149,6 +144,7 @@ func NewTestingStorage( } s.mu.liveSessions = cache.NewUnorderedCache(cacheConfig) s.mu.deadSessions = cache.NewUnorderedCache(cacheConfig) + s.mu.g = singleflight.NewGroup("is-alive", "session ID") return s } @@ -221,7 +217,7 @@ func (s *Storage) isAlive( } // We think that the session is expired; check, and maybe delete it. 
- resChan := s.deleteOrFetchSessionSingleFlightLocked(ctx, sid) + future := s.deleteOrFetchSessionSingleFlightLocked(ctx, sid) // At this point, we know that the singleflight goroutine has been launched. // Releasing the lock here ensures that callers will either join the single- @@ -234,15 +230,11 @@ func (s *Storage) isAlive( if syncOrAsync == async { return true, nil } - select { - case res := <-resChan: - if res.Err != nil { - return false, res.Err - } - return res.Val.(bool), nil - case <-ctx.Done(): - return false, ctx.Err() + res := future.WaitForResult(ctx) + if res.Err != nil { + return false, res.Err } + return res.Val.(bool), nil } // This function will launch a singleflight goroutine for the session which @@ -253,35 +245,25 @@ func (s *Storage) isAlive( // This method assumes that s.mu is held. func (s *Storage) deleteOrFetchSessionSingleFlightLocked( ctx context.Context, sid sqlliveness.SessionID, -) <-chan singleflight.Result { +) singleflight.Future { s.mu.AssertHeld() // If it is found, we can add it and its expiration to the liveSessions // cache. If it isn't found, we know it's dead, and we can add that to the // deadSessions cache. - resChan, _ := s.mu.g.DoChan(string(sid), func() (interface{}, error) { - - // Note that we use a new `context` here to avoid a situation where a cancellation - // of the first context cancels other callers to the `acquireNodeLease()` method, - // because of its use of `singleflight.Group`. See issue #41780 for how this has - // happened. - bgCtx := s.AnnotateCtx(context.Background()) - bgCtx = logtags.AddTags(bgCtx, logtags.FromContext(ctx)) - newCtx, cancel := s.stopper.WithCancelOnQuiesce(bgCtx) - defer cancel() - - // store the result underneath the singleflight to avoid the need - // for additional synchronization. Also, use a stopper task to ensure - // the goroutine is tracked during shutdown. - var live bool - const taskName = "sqlliveness-fetch-or-delete-session" - if err := s.stopper.RunTaskWithErr(newCtx, taskName, func( - ctx context.Context, - ) (err error) { + resChan, _ := s.mu.g.DoChan(ctx, string(sid), singleflight.DoOpts{ + Stop: s.stopper, + InheritCancelation: false, + }, + func(ctx context.Context) (interface{}, error) { + // store the result underneath the singleflight to avoid the need + // for additional synchronization. Also, use a stopper task to ensure + // the goroutine is tracked during shutdown. 
+ var live bool var expiration hlc.Timestamp - live, expiration, err = s.deleteOrFetchSession(newCtx, sid) + live, expiration, err := s.deleteOrFetchSession(ctx, sid) if err != nil { - return err + return nil, err } s.mu.Lock() defer s.mu.Unlock() @@ -290,12 +272,8 @@ func (s *Storage) deleteOrFetchSessionSingleFlightLocked( } else { s.mu.deadSessions.Add(sid, nil) } - return nil - }); err != nil { - return false, err - } - return live, nil - }) + return live, nil + }) return resChan } @@ -385,7 +363,7 @@ func (s *Storage) deleteExpiredSessions(ctx context.Context) { } return } - launchSessionCheck := func(id sqlliveness.SessionID) <-chan singleflight.Result { + launchSessionCheck := func(id sqlliveness.SessionID) singleflight.Future { s.mu.Lock() defer s.mu.Unlock() // We have evidence that the session is expired, so remove any cached @@ -395,12 +373,8 @@ func (s *Storage) deleteExpiredSessions(ctx context.Context) { return s.deleteOrFetchSessionSingleFlightLocked(ctx, id) } checkSession := func(id sqlliveness.SessionID) error { - select { - case r := <-launchSessionCheck(id): - return r.Err - case <-ctx.Done(): - return ctx.Err() - } + future := launchSessionCheck(id) + return future.WaitForResult(ctx).Err } for _, id := range toCheck { if err := checkSession(id); err != nil { diff --git a/pkg/testutils/storageutils/mocking.go b/pkg/testutils/storageutils/mocking.go index 9a4840f715dc..0b39886bc3de 100644 --- a/pkg/testutils/storageutils/mocking.go +++ b/pkg/testutils/storageutils/mocking.go @@ -11,6 +11,7 @@ package storageutils import ( + "context" "fmt" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/kvserverbase" @@ -33,7 +34,7 @@ func (r raftCmdIDAndIndex) String() string { // from Raft replays. type ReplayProtectionFilterWrapper struct { syncutil.Mutex - inFlight singleflight.Group + inFlight *singleflight.Group processedCommands map[raftCmdIDAndIndex]*roachpb.Error filter kvserverbase.ReplicaCommandFilter } @@ -44,6 +45,7 @@ func WrapFilterForReplayProtection( filter kvserverbase.ReplicaCommandFilter, ) kvserverbase.ReplicaCommandFilter { wrapper := ReplayProtectionFilterWrapper{ + inFlight: singleflight.NewGroup("replay-protection", "key"), processedCommands: make(map[raftCmdIDAndIndex]*roachpb.Error), filter: filter, } @@ -78,16 +80,21 @@ func (c *ReplayProtectionFilterWrapper) run(args kvserverbase.FilterArgs) *roach // We use the singleflight.Group to coalesce replayed raft commands onto the // same filter call. This allows concurrent access to the filter for // different raft commands. 
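The test-only wrapper below goes the other way: it has no stopper to hand the flight to and no reason to shield other waiters from cancelation, so it passes Stop: nil and InheritCancelation: true and the flight simply runs on the leader caller's ctx. A hedged sketch of that variant (coalesce is an illustrative name, not part of the patch):

package example

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight"
)

// coalesce runs fn at most once per key among concurrent callers. With
// InheritCancelation set, the flight's ctx is the leader caller's ctx, so
// canceling the leader cancels the work itself.
func coalesce(
	ctx context.Context,
	g *singleflight.Group,
	key string,
	fn func(ctx context.Context) (interface{}, error),
) (interface{}, error) {
	future, _ := g.DoChan(ctx, key,
		singleflight.DoOpts{
			Stop:               nil,
			InheritCancelation: true,
		},
		fn)
	res := future.WaitForResult(ctx)
	return res.Val, res.Err
}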
- resC, _ := c.inFlight.DoChan(mapKey.String(), func() (interface{}, error) { - pErr := c.filter(args) + future, _ := c.inFlight.DoChan(args.Ctx, mapKey.String(), + singleflight.DoOpts{ + Stop: nil, + InheritCancelation: true, + }, + func(ctx context.Context) (interface{}, error) { + pErr := c.filter(args) - c.Lock() - defer c.Unlock() - c.processedCommands[mapKey] = pErr - return pErr, nil - }) + c.Lock() + defer c.Unlock() + c.processedCommands[mapKey] = pErr + return pErr, nil + }) c.Unlock() - res := <-resC + res := future.WaitForResult(args.Ctx) return shallowCloneErrorWithTxn(res.Val.(*roachpb.Error)) } diff --git a/pkg/util/syncutil/singleflight/BUILD.bazel b/pkg/util/syncutil/singleflight/BUILD.bazel index 1a04a7100287..47659aacc55e 100644 --- a/pkg/util/syncutil/singleflight/BUILD.bazel +++ b/pkg/util/syncutil/singleflight/BUILD.bazel @@ -6,7 +6,16 @@ go_library( srcs = ["singleflight.go"], importpath = "github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight", visibility = ["//visibility:public"], - deps = ["//pkg/util/syncutil"], + deps = [ + "//pkg/util/log", + "//pkg/util/stop", + "//pkg/util/syncutil", + "//pkg/util/tracing", + "//pkg/util/tracing/tracingpb", + "@com_github_cockroachdb_errors//:errors", + "@com_github_cockroachdb_logtags//:logtags", + "@io_opentelemetry_go_otel//attribute", + ], ) go_test( @@ -15,7 +24,15 @@ go_test( srcs = ["singleflight_test.go"], args = ["-test.timeout=55s"], embed = [":singleflight"], - deps = ["@com_github_cockroachdb_errors//:errors"], + deps = [ + "//pkg/testutils", + "//pkg/util/leaktest", + "//pkg/util/log", + "//pkg/util/tracing", + "//pkg/util/tracing/tracingpb", + "@com_github_cockroachdb_errors//:errors", + "@com_github_stretchr_testify//require", + ], ) get_x_data(name = "get_x_data") diff --git a/pkg/util/syncutil/singleflight/singleflight.go b/pkg/util/syncutil/singleflight/singleflight.go index 49a2c407a58d..c3b5cf5d92b0 100644 --- a/pkg/util/syncutil/singleflight/singleflight.go +++ b/pkg/util/syncutil/singleflight/singleflight.go @@ -19,40 +19,113 @@ package singleflight import ( - "sync" + "context" + "fmt" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/tracing" + "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" + "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" + "go.opentelemetry.io/otel/attribute" ) -// call is an in-flight or completed singleflight.Do call +// call is an in-flight or completed singleflight.Do/DoChan call. type call struct { - wg sync.WaitGroup + opName, key string + // c is closed when the call completes, signaling all waiters. + c chan struct{} - // These fields are written once before the WaitGroup is done - // and are only read after the WaitGroup is done. + mu struct { + syncutil.Mutex + // sp is the tracing span of the flight leader. Nil if the leader does not + // have a span. Once this span is finished, its recording is captured as + // `rec` below. This span's recording might get copied into the traces of + // other flight members. A non-leader caller with a recording trace will + // enable recording on this span dynamically. + sp *tracing.Span + } + + ///////////////////////////////////////////////////////////////////////////// + // These fields are written once before the channel is closed and are only + // read after the channel is closed. 
+ ///////////////////////////////////////////////////////////////////////////// val interface{} err error + // rec is the call's recording, if any of the callers that joined the call + // requested the trace to be recorded. It is set once mu.sp is set to nil, + // which happens before c is closed. rec is only read after c is closed. + rec tracing.Trace - // These fields are read and written with the singleflight - // mutex held before the WaitGroup is done, and are read but - // not written after the WaitGroup is done. - dups int - chans []chan<- Result + ///////////////////////////////////////////////////////////////////////////// + // These fields are read and written with the singleflight mutex held before + // the channel is closed, and are read but not written after the channel is + // closed. + ///////////////////////////////////////////////////////////////////////////// + dups int +} + +func newCall(opName, key string, sp *tracing.Span) *call { + c := &call{ + opName: opName, + key: key, + c: make(chan struct{}), + } + c.mu.sp = sp + return c +} + +func (c *call) maybeStartRecording(mode tracingpb.RecordingType) { + c.mu.Lock() + defer c.mu.Unlock() + if c.mu.sp.RecordingType() < mode { + c.mu.sp.SetRecordingType(mode) + } } // Group represents a class of work and forms a namespace in // which units of work can be executed with duplicate suppression. type Group struct { - mu syncutil.Mutex // protects m - m map[string]*call // lazily initialized + // opName is used as the operation name of the spans produced by this Group + // for every flight. + opName string + // tagName represents the name of the tag containing the key for each flight. + // If not set, the spans do not get such a tag. + tagName string + mu syncutil.Mutex // protects m + m map[string]*call // lazily initialized +} + +// NoTags can be passed to NewGroup as the tagName to indicate that the tracing +// spans created for operations should not have the operation key as a tag. In +// particular, in cases where a single dummy key is used with a Group, having it +// as a tag is not necessary. +const NoTags = "" + +// NewGroup creates a Group. +// +// opName will be used as the operation name of the spans produced by this Group +// for every flight. +// tagName will be used as the name of the span tag containing the key for each +// flight span. If NoTags is passed, the spans do not get such tags. +func NewGroup(opName, tagName string) *Group { + return &Group{opName: opName, tagName: tagName} } // Result holds the results of Do, so they can be passed // on a channel. type Result struct { - Val interface{} - Err error + // Val represents the call's return value. + Val interface{} + // Err represents the call's error, if any. + Err error + // Shared is set if the result has been shared with multiple callers. Shared bool + // Leader is set if the caller was the flight leader (the caller that + // triggered the flight). + Leader bool } // Do executes and returns the results of the given function, making @@ -60,8 +133,13 @@ type Result struct { // time. If a duplicate comes in, the duplicate caller waits for the // original to complete and receives the same results. // The return value shared indicates whether v was given to multiple callers. +// +// NOTE: If fn responds to ctx cancelation by interrupting the work and +// returning an error, canceling ctx might propagate such error to other callers +// if a flight's leader's ctx is canceled. See DoChan for more control over this +// behavior. 
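Concretely, the reworked Do threads a ctx both into the Group and into fn, so the flight's work is traced and tagged like any other operation. A small, hypothetical caller for illustration (lookupDescriptor and fetchDescriptor are made-up names, not part of the patch):

package example

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight"
)

// fetchDescriptor stands in for whatever expensive lookup is being coalesced.
func fetchDescriptor(ctx context.Context, id string) (interface{}, error) {
	return "descriptor for " + id, nil
}

func lookupDescriptor(
	ctx context.Context, g *singleflight.Group, id string,
) (interface{}, error) {
	// Do blocks the caller until the flight finishes; the closure receives a
	// ctx derived from the leader's, carrying the flight's tracing span.
	v, shared, err := g.Do(ctx, id, func(ctx context.Context) (interface{}, error) {
		return fetchDescriptor(ctx, id)
	})
	_ = shared // true when the result was shared with other callers
	return v, err
}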
func (g *Group) Do( - key string, fn func() (interface{}, error), + ctx context.Context, key string, fn func(context.Context) (interface{}, error), ) (v interface{}, shared bool, err error) { g.mu.Lock() if g.m == nil { @@ -69,24 +147,163 @@ func (g *Group) Do( } if c, ok := g.m[key]; ok { c.dups++ + c.maybeStartRecording(tracing.SpanFromContext(ctx).RecordingType()) g.mu.Unlock() - c.wg.Wait() - return c.val, true, c.err + log.Eventf(ctx, "waiting on singleflight %s:%s owned by another leader. Starting to record the leader's flight.", g.opName, key) + + // Block on the call. + <-c.c + log.Eventf(ctx, "waiting on singleflight %s:%s owned by another leader... done", g.opName, key) + // Get the call's result through result() so that the call's trace gets + // imported into ctx. + res := c.result(ctx, false /* leader */) + return res.Val, true, res.Err + } + + // Open a child span for the flight. Note that this child span might outlive + // its parent if the caller doesn't wait for the result of this flight. It's + // common for the caller to not always wait, particularly if + // opts.InheritCancelation == false (i.e. if the caller can be canceled + // independently of the flight). + ctx, sp := tracing.ChildSpan(ctx, g.opName) + if g.tagName != "" { + sp.SetTag(g.tagName, attribute.StringValue(key)) } - c := new(call) - c.wg.Add(1) + c := newCall(g.opName, key, sp) // c takes ownership of sp g.m[key] = c g.mu.Unlock() - g.doCall(c, key, fn) + g.doCall(ctx, + c, key, + DoOpts{ + Stop: nil, + InheritCancelation: true, + }, + fn) return c.val, c.dups > 0, c.err } -// DoChan is like Do but returns a channel that will receive the results when +// DoOpts groups options for the DoChan() method. +type DoOpts struct { + // Stop, if not nil, is used both to create the flight in a stopper task and + // to cancel the flight's ctx on stopper quiescence. + Stop *stop.Stopper + // InheritCancelation controls whether the cancelation of the caller's ctx + // affects the flight. If set, the flight closure gets the caller's ctx. If + // not set, the closure runs in a different ctx which does not inherit the + // caller's cancelation. It is common to not want the flight to inherit the + // caller's cancelation, so that a canceled leader does not propagate an error + // to everybody else that joined the sane flight. + InheritCancelation bool +} + +// Future is the return type of the DoChan() call. +type Future struct { + call *call + leader bool +} + +func makeFuture(c *call, leader bool) Future { + return Future{ + call: c, + leader: leader, + } +} + +// C returns the channel on which the result of the DoChan call will be +// delivered. +func (f Future) C() <-chan struct{} { + if f.call == nil { + return nil + } + return f.call.c +} + +// WaitForResult delivers the flight's result. If called before the call is done +// (i.e. before channel returned by C() is closed), then it will block; +// canceling ctx unblocks it. +// +// If the ctx has a recording tracing span, and if the context passed to DoChan +// also had a recording span in it (commonly because the same ctx was used), +// then the recording of the flight will be ingested into ctx even if this +// caller was not the flight's leader. +func (f Future) WaitForResult(ctx context.Context) Result { + if f.call == nil { + panic("WaitForResult called on empty Future") + } + return f.call.result(ctx, f.leader) +} + +// Reset resets the future +func (f *Future) Reset() { + *f = Future{} +} + +// result returns the call's results. 
It will block until the call completes. +func (c *call) result(ctx context.Context, leader bool) Result { + // Wait for the call to finish. + select { + // Give priority to c.c to ensure that a context error is not returned if the + // call is done. + case <-c.c: + default: + select { + case <-c.c: + case <-ctx.Done(): + op := fmt.Sprintf("%s:%s", c.opName, c.key) + if !leader { + log.Eventf(ctx, "waiting for singleflight interrupted: %v", ctx.Err()) + // If we're recording, copy over the call's trace. + sp := tracing.SpanFromContext(ctx) + if sp.RecordingType() != tracingpb.RecordingOff { + // We expect the call to still be ongoing, so c.rec is not yet + // populated. So, we'll look in c.mu.sp, and get the span's recording + // so far. This is racy with the leader finishing and resetting + // c.mu.sp, in which case we'll get nothing; we ignore that + // possibiltity as unlikely, making this best-effort. + c.mu.Lock() + rec := c.mu.sp.GetTraceRecording(sp.RecordingType()) + c.mu.Unlock() + sp.ImportTrace(rec) + } + } + return Result{ + Val: nil, + Err: errors.Wrapf(ctx.Err(), "interrupted during singleflight %s", op), + Shared: false, + Leader: leader} + } + } + + // If we got here, c.c has been closed, so we can access c.rec. + if !leader { + // If we're recording, copy over the call's trace. + sp := tracing.SpanFromContext(ctx) + if sp.RecordingType() != tracingpb.RecordingOff { + if rec := c.rec; !rec.Empty() { + tracing.SpanFromContext(ctx).ImportTrace(rec.PartialClone()) + } + } + } + + return Result{ + Val: c.val, + Err: c.err, + Shared: c.dups > 0, + Leader: leader, + } +} + +// DoChan is like Do but returns a Future that will receive the results when // they are ready. The method also returns a boolean specifying whether the // caller's fn function will be called or not. This return value lets callers // identify a unique "leader" for a flight. // +// Close() must be called on the returned Future if the caller does not +// wait for the Future's result. +// +// opts controls details about how the flight is to run. +// // NOTE: DoChan makes it possible to initiate or join a flight while holding a // lock without holding it for the duration of the flight. A common usage // pattern is: @@ -101,38 +318,92 @@ func (g *Group) Do( // before any modifications to the datastructure occurred (relative to the state // observed in step one). Were the lock to be released before calling DoChan(), // a previous flight might modify the datastructure before our flight began. -func (g *Group) DoChan(key string, fn func() (interface{}, error)) (<-chan Result, bool) { - ch := make(chan Result, 1) +// +// In addition to the above, another reason for using DoChan over Do is so that +// the caller can be canceled while not propagating the cancelation to other +// innocent callers that joined the same flight; the caller can listen to its +// cancelation in parallel to the result channel. See DoOpts.InheritCancelation. 
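To make the pattern described in the comment above concrete, here is a rough sketch of a lock-guarded cache whose misses are coalesced through DoChan. The registry type and its fields are hypothetical, but the lock/DoChan/unlock/wait ordering is the one the comment prescribes:

package example

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
	"github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight"
)

// registry is a hypothetical cache whose loads are coalesced per key.
type registry struct {
	mu      syncutil.Mutex
	entries map[string]interface{}
	group   *singleflight.Group
}

func newRegistry() *registry {
	return &registry{
		entries: map[string]interface{}{},
		group:   singleflight.NewGroup("registry-load", "key"),
	}
}

func (r *registry) get(ctx context.Context, key string) (interface{}, error) {
	r.mu.Lock()
	if v, ok := r.entries[key]; ok {
		r.mu.Unlock()
		return v, nil
	}
	// Join or start the flight while still holding the lock, so the flight is
	// guaranteed to observe state at least as new as what we just checked...
	future, _ := r.group.DoChan(ctx, key,
		singleflight.DoOpts{InheritCancelation: true},
		func(ctx context.Context) (interface{}, error) {
			v := "loaded:" + key // stand-in for the real work
			r.mu.Lock()
			defer r.mu.Unlock()
			r.entries[key] = v
			return v, nil
		})
	// ... but don't hold it for the duration of the flight.
	r.mu.Unlock()
	res := future.WaitForResult(ctx)
	return res.Val, res.Err
}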
+func (g *Group) DoChan( + ctx context.Context, key string, opts DoOpts, fn func(context.Context) (interface{}, error), +) (Future, bool) { g.mu.Lock() if g.m == nil { g.m = make(map[string]*call) } + if c, ok := g.m[key]; ok { c.dups++ - c.chans = append(c.chans, ch) + c.maybeStartRecording(tracing.SpanFromContext(ctx).RecordingType()) + g.mu.Unlock() - return ch, false + log.Eventf(ctx, "joining singleflight %s:%s owned by another leader", g.opName, key) + return makeFuture(c, false /* leader */), false } - c := &call{chans: []chan<- Result{ch}} - c.wg.Add(1) + + // Open a child span for the flight. Note that this child span might outlive + // its parent if the caller doesn't wait for the result of this flight. It's + // common for the caller to not always wait, particularly if + // opts.InheritCancelation == false (i.e. if the caller can be canceled + // independently of the flight). + ctx, sp := tracing.ChildSpan(ctx, g.opName) + if g.tagName != "" { + sp.SetTag(g.tagName, attribute.StringValue(key)) + } + c := newCall(g.opName, key, sp) // c takes ownership of sp g.m[key] = c g.mu.Unlock() - go g.doCall(c, key, fn) + go g.doCall(ctx, c, key, opts, fn) - return ch, true + return makeFuture(c, true /* leader */), true } -// doCall handles the single call for a key. -func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { - c.val, c.err = fn() - c.wg.Done() +// doCall handles the single call for a key. At the end of the call, c.c is +// closed, signaling the waiters (if any). +// +// doCall takes ownership of c.sp, which will be finished before returning. +func (g *Group) doCall( + ctx context.Context, + c *call, + key string, + opts DoOpts, + fn func(ctx context.Context) (interface{}, error), +) { + // Prepare the ctx for the call. + if !opts.InheritCancelation { + // Copy the log tags and the span. + ctx = logtags.AddTags(context.Background(), logtags.FromContext(ctx)) + c.mu.Lock() + sp := c.mu.sp + c.mu.Unlock() + ctx = tracing.ContextWithSpan(ctx, sp) + } + if opts.Stop != nil { + var cancel func() + ctx, cancel = opts.Stop.WithCancelOnQuiesce(ctx) + defer cancel() + + if err := opts.Stop.RunTask(ctx, g.opName+":"+key, func(ctx context.Context) { + c.val, c.err = fn(ctx) + }); err != nil { + c.err = err + } + } else { + c.val, c.err = fn(ctx) + } g.mu.Lock() delete(g.m, key) - for _, ch := range c.chans { - ch <- Result{c.val, c.err, c.dups > 0} + { + c.mu.Lock() + sp := c.mu.sp + // Prevent other flyers from observing a finished span. + c.mu.sp = nil + c.rec = sp.FinishAndGetTraceRecording(sp.RecordingType()) + c.mu.Unlock() } + // Publish the results to all waiters. 
+ close(c.c) g.mu.Unlock() } diff --git a/pkg/util/syncutil/singleflight/singleflight_test.go b/pkg/util/syncutil/singleflight/singleflight_test.go index d677b76deefa..73d80bbed380 100644 --- a/pkg/util/syncutil/singleflight/singleflight_test.go +++ b/pkg/util/syncutil/singleflight/singleflight_test.go @@ -17,18 +17,26 @@ package singleflight import ( + "context" "fmt" "sync" "sync/atomic" "testing" "time" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/tracing" + "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" ) func TestDo(t *testing.T) { - var g Group - v, _, err := g.Do("key", func() (interface{}, error) { + ctx := context.Background() + g := NewGroup("test", "test") + v, _, err := g.Do(ctx, "key", func(context.Context) (interface{}, error) { return "bar", nil }) res := Result{Val: v, Err: err} @@ -36,21 +44,23 @@ func TestDo(t *testing.T) { } func TestDoChan(t *testing.T) { - var g Group - resC, leader := g.DoChan("key", func() (interface{}, error) { + ctx := context.Background() + g := NewGroup("test", "test") + res, leader := g.DoChan(ctx, "key", DoOpts{}, func(context.Context) (interface{}, error) { return "bar", nil }) if !leader { t.Errorf("DoChan returned not leader, expected leader") } - res := <-resC - assertRes(t, res, false) + result := res.WaitForResult(ctx) + assertRes(t, result, false) } func TestDoErr(t *testing.T) { - var g Group + ctx := context.Background() + g := NewGroup("test", "test") someErr := errors.New("Some error") - v, _, err := g.Do("key", func() (interface{}, error) { + v, _, err := g.Do(ctx, "key", func(context.Context) (interface{}, error) { return nil, someErr }) if !errors.Is(err, someErr) { @@ -62,11 +72,12 @@ func TestDoErr(t *testing.T) { } func TestDoDupSuppress(t *testing.T) { - var g Group + ctx := context.Background() + g := NewGroup("test", "key") var wg1, wg2 sync.WaitGroup c := make(chan string, 1) var calls int32 - fn := func() (interface{}, error) { + fn := func(context.Context) (interface{}, error) { if atomic.AddInt32(&calls, 1) == 1 { // First invocation. wg1.Done() @@ -87,7 +98,7 @@ func TestDoDupSuppress(t *testing.T) { go func() { defer wg2.Done() wg1.Done() - v, _, err := g.Do("key", fn) + v, _, err := g.Do(ctx, "key", fn) if err != nil { t.Errorf("Do error: %v", err) return @@ -107,45 +118,276 @@ func TestDoDupSuppress(t *testing.T) { } } +// Test that a Do caller that's recording the trace gets the call's recording +// even if it's not the leader. +func TestDoTracing(t *testing.T) { + defer leaktest.AfterTest(t)() + + testutils.RunTrueAndFalse(t, "leader recording", func(t *testing.T, leaderRecording bool) { + ctx := context.Background() + tr := tracing.NewTracerWithOpt(ctx, tracing.WithTracingMode(tracing.TracingModeActiveSpansRegistry)) + opName := "test" + g := NewGroup(opName, NoTags) + leaderSem := make(chan struct{}) + leaderTraceC := make(chan tracingpb.Recording) + const key = "" + go func() { + // The leader needs to have a span, but not necessarily a recording one. 
+ var opt tracing.SpanOption + if leaderRecording { + opt = tracing.WithRecording(tracingpb.RecordingVerbose) + } + sp := tr.StartSpan("leader", opt) + defer sp.Finish() + ctx := tracing.ContextWithSpan(ctx, sp) + _, _, err := g.Do(ctx, key, func(ctx context.Context) (interface{}, error) { + log.Eventf(ctx, "inside the flight 1") + leaderSem <- struct{}{} + <-leaderSem + log.Eventf(ctx, "inside the flight 2") + return nil, nil + }) + if err != nil { + panic(err) + } + leaderTraceC <- sp.GetConfiguredRecording() + }() + // Wait for the call above to become the flight leader. + <-leaderSem + followerTraceC := make(chan tracingpb.Recording) + go func() { + ctx, getRec := tracing.ContextWithRecordingSpan(context.Background(), tr, "follower") + _, _, err := g.Do(ctx, key, func(ctx context.Context) (interface{}, error) { + panic("this should never be called; the leader uses another closure") + }) + if err != nil { + panic(err) + } + followerTraceC <- getRec() + }() + // Wait until the second caller blocks. + testutils.SucceedsSoon(t, func() error { + if g.NumCalls("") != 2 { + return errors.New("second call not blocked yet") + } + return nil + }) + // Unblock the leader. + leaderSem <- struct{}{} + // Wait for the both the leader and the follower. + leaderRec := <-leaderTraceC + followerRec := <-followerTraceC + + // Check that the trace's contains the call. In the leader's case, it should + // only have if the leader was recording. The follower is always recording + // in these tests. + _, ok := leaderRec.FindSpan(opName) + require.Equal(t, leaderRecording, ok) + _, ok = followerRec.FindSpan(opName) + require.True(t, ok) + + // If the leader is not recording, the first message is not in the trace + // because the follower started recording on the leader's span only + // afterwards. + expectedToSeeFirstMsg := leaderRecording + _, ok = followerRec.FindLogMessage("inside the flight 1") + require.Equal(t, expectedToSeeFirstMsg, ok) + // The second message is there. + _, ok = followerRec.FindLogMessage("inside the flight 2") + require.True(t, ok) + }) +} + func TestDoChanDupSuppress(t *testing.T) { c := make(chan struct{}) - fn := func() (interface{}, error) { + fn := func(context.Context) (interface{}, error) { <-c return "bar", nil } - var g Group - resC1, leader1 := g.DoChan("key", fn) + ctx := context.Background() + g := NewGroup("test", "key") + res1, leader1 := g.DoChan(ctx, "key", DoOpts{}, fn) if !leader1 { t.Errorf("DoChan returned not leader, expected leader") } - resC2, leader2 := g.DoChan("key", fn) + res2, leader2 := g.DoChan(ctx, "key", DoOpts{}, fn) if leader2 { t.Errorf("DoChan returned leader, expected not leader") } close(c) - for _, res := range []Result{<-resC1, <-resC2} { + + for _, res := range []Result{res1.WaitForResult(ctx), res2.WaitForResult(ctx)} { assertRes(t, res, true) } } +// Test that a DoChan caller that's recording the trace gets the call's +// recording even if it's not the leader. +func TestDoChanTracing(t *testing.T) { + defer leaktest.AfterTest(t)() + + // We test both the case when a follower waits for the flight completion, and + // the case where it doesn't. In both cases, the follower should have the + // leader's recording (at least part of it in the canceled case). 
+ testutils.RunTrueAndFalse(t, "follower-canceled", func(t *testing.T, cancelFollower bool) { + ctx := context.Background() + tr := tracing.NewTracerWithOpt(ctx, tracing.WithTracingMode(tracing.TracingModeActiveSpansRegistry)) + const opName = "test" + g := NewGroup(opName, NoTags) + leaderSem := make(chan struct{}) + leaderResultC := make(chan Result) + followerResultC := make(chan Result) + go func() { + // The leader needs to have a span, but not necessarily a recording one. + ctx := tracing.ContextWithSpan(ctx, tr.StartSpan("leader")) + future, leader := g.DoChan(ctx, "", + DoOpts{ + Stop: nil, + InheritCancelation: false, + }, + func(ctx context.Context) (interface{}, error) { + log.Eventf(ctx, "inside the flight 1") + // Signal that the leader started. + leaderSem <- struct{}{} + // Wait until the follower joins the flight. + <-leaderSem + log.Eventf(ctx, "inside the flight 2") + // Signal that the leader has logged a message. + leaderSem <- struct{}{} + // Wait for the test to signal one more time. + <-leaderSem + return nil, nil + }) + if !leader { + panic("expected to be leader") + } + + res := future.WaitForResult(ctx) + leaderResultC <- res + }() + // Wait for the call above to become the flight leader. + <-leaderSem + // Call another DoChan, which will get de-duped onto the blocked leader above. + followerTraceC := make(chan tracingpb.Recording) + followerCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + ctx := followerCtx + ctx, getRec := tracing.ContextWithRecordingSpan(ctx, tr, "follower") + future, leader := g.DoChan(ctx, "", DoOpts{}, func(ctx context.Context) (interface{}, error) { + panic("this should never be called; the leader uses another closure") + }) + if leader { + panic("expected to be follower") + } + res := future.WaitForResult(ctx) + followerResultC <- res + followerTraceC <- getRec() + }() + // Wait until the second caller blocks. + testutils.SucceedsSoon(t, func() error { + if g.NumCalls("") != 2 { + return errors.New("second call not blocked yet") + } + return nil + }) + // Now that the follower is blocked, unblock the leader. + leaderSem <- struct{}{} + // Wait for the leader to make a little progress, logging a message. + <-leaderSem + + if cancelFollower { + cancel() + } else { + // Signal the leader that it can terminate. + leaderSem <- struct{}{} + } + + // Wait for the follower. In the `cancelFollower` case, this happens while + // the leader is still running. + followerRes := <-followerResultC + require.False(t, followerRes.Leader) + if !cancelFollower { + require.NoError(t, followerRes.Err) + } else { + require.ErrorContains(t, followerRes.Err, "interrupted during singleflight") + } + + if cancelFollower { + // Signal the leader that it can terminate. In the `!cancelFollower` case, + // we've done it above. + leaderSem <- struct{}{} + } + + // Wait for the leader. + res := <-leaderResultC + require.True(t, res.Leader) + require.NoError(t, res.Err) + + // Check that the follower's trace contains the call. + followerRec := <-followerTraceC + _, ok := followerRec.FindSpan(opName) + require.True(t, ok) + // The first message is not in the trace because the follower started + // recording on the leader's span only afterwards. + _, ok = followerRec.FindLogMessage("inside the flight 1") + require.False(t, ok) + // The second message is there. + _, ok = followerRec.FindLogMessage("inside the flight 2") + require.True(t, ok) + }) +} + +// Test that a DoChan flight is not interrupted by the canceling of the caller's +// ctx. 
+func TestDoChanCtxCancel(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx, cancel := context.WithCancel(context.Background()) + g := NewGroup("test", NoTags) + future, _ := g.DoChan(ctx, "", + DoOpts{ + Stop: nil, + InheritCancelation: false, + }, + func(ctx context.Context) (interface{}, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Millisecond): + } + return nil, nil + }) + cancel() + <-future.C() + // Note that when it doesn't have to block, Result() returns the flight's + // result even when called with a canceled ctx. + res := future.WaitForResult(ctx) + // Expect that the call ended after the timer expired, not because of ctx + // cancelation. + require.NoError(t, res.Err) +} + func TestNumCalls(t *testing.T) { c := make(chan struct{}) - fn := func() (interface{}, error) { + fn := func(ctx context.Context) (interface{}, error) { <-c return "bar", nil } - var g Group + ctx := context.Background() + g := NewGroup("test", "key") assertNumCalls(t, g.NumCalls("key"), 0) - resC1, _ := g.DoChan("key", fn) + resC1, _ := g.DoChan(ctx, "key", DoOpts{}, fn) assertNumCalls(t, g.NumCalls("key"), 1) - resC2, _ := g.DoChan("key", fn) + resC2, _ := g.DoChan(ctx, "key", DoOpts{}, fn) assertNumCalls(t, g.NumCalls("key"), 2) close(c) - <-resC1 - <-resC2 + <-resC1.C() + <-resC2.C() assertNumCalls(t, g.NumCalls("key"), 0) } diff --git a/pkg/util/tracing/crdbspan.go b/pkg/util/tracing/crdbspan.go index 85881b98a450..84bcc3368554 100644 --- a/pkg/util/tracing/crdbspan.go +++ b/pkg/util/tracing/crdbspan.go @@ -484,10 +484,36 @@ func (t *Trace) appendSpansRecursively(buffer []tracingpb.RecordedSpan) []tracin // Flatten flattens the trace into a slice of spans. The root is the first span, // and parents come before children. Otherwise, the spans are not sorted. +// +// See SortSpans() for sorting the result in order to turn it into a +// tracingpb.Recording. func (t *Trace) Flatten() []tracingpb.RecordedSpan { + if t.Empty() { + return nil + } return t.appendSpansRecursively(nil /* buffer */) } +// ToRecording converts the Trace to a tracingpb.Recording by flattening it and +// sorting the spans. +func (t *Trace) ToRecording() tracingpb.Recording { + spans := t.Flatten() + // sortSpans sorts the spans by StartTime, except the first Span (the root of + // this recording) which stays in place. + toSort := sortPoolRecordings.Get().(*tracingpb.Recording) // avoids allocations in sort.Sort + *toSort = spans[1:] + sort.Sort(toSort) + *toSort = nil + sortPoolRecordings.Put(toSort) + return spans +} + +var sortPoolRecordings = sync.Pool{ + New: func() interface{} { + return &tracingpb.Recording{} + }, +} + // PartialClone performs a deep copy of the trace. The immutable slices are not // copied: logs, tags and stats. func (t *Trace) PartialClone() Trace { diff --git a/pkg/util/tracing/span.go b/pkg/util/tracing/span.go index 90aed655ffc9..cba11f3c46c3 100644 --- a/pkg/util/tracing/span.go +++ b/pkg/util/tracing/span.go @@ -290,9 +290,9 @@ func (sp *Span) FinishAndGetConfiguredRecording() tracingpb.Recording { rec := tracingpb.Recording(nil) recType := sp.RecordingType() if recType != tracingpb.RecordingOff { + // Reach directly into sp.i to pass the finishing argument. rec = sp.i.GetRecording(recType, true /* finishing */) } - // Reach directly into sp.i to pass the finishing argument. 
sp.finishInternal() return rec } @@ -352,28 +352,67 @@ func (sp *Span) GetConfiguredRecording() tracingpb.Recording { return sp.i.GetRecording(recType, false /* finishing */) } +// GetTraceRecording returns the span's recording as a Trace. +// +// See also GetRecording(), which returns a tracingpb.Recording. +func (sp *Span) GetTraceRecording(recType tracingpb.RecordingType) Trace { + if sp.detectUseAfterFinish() { + return Trace{} + } + return sp.i.GetTraceRecording(recType, false /* finishing */) +} + +// FinishAndGetTraceRecording finishes the span and returns its recording as a +// Trace. +// +// See also FinishAndGetRecording(), which returns a tracingpb.Recording. +func (sp *Span) FinishAndGetTraceRecording(recType tracingpb.RecordingType) Trace { + if sp.detectUseAfterFinish() { + return Trace{} + } + rec := sp.i.GetTraceRecording(recType, true /* finishing */) + sp.Finish() + return rec +} + // ImportRemoteRecording adds the spans in remoteRecording as children of the // receiver. As a result of this, the imported recording will be a part of the -// GetRecording() output for the receiver. +// GetRecording() output for the receiver. All the structured events from the +// trace are passed to the receiver's event listeners. // // This function is used to import a recording from another node. func (sp *Span) ImportRemoteRecording(remoteRecording tracingpb.Recording) { - if !sp.detectUseAfterFinish() { - // Handle recordings coming from 22.2 nodes that don't have the - // StructuredRecordsSizeBytes field set. - // TODO(andrei): remove this in 23.2. - for i := range remoteRecording { - if len(remoteRecording[i].StructuredRecords) != 0 && remoteRecording[i].StructuredRecordsSizeBytes == 0 { - size := int64(0) - for _, rec := range remoteRecording[i].StructuredRecords { - size += int64(rec.MemorySize()) - } - remoteRecording[i].StructuredRecordsSizeBytes = size + if sp.detectUseAfterFinish() { + return + } + // Handle recordings coming from 22.2 nodes that don't have the + // StructuredRecordsSizeBytes field set. + // TODO(andrei): remove this in 23.2. + for i := range remoteRecording { + if len(remoteRecording[i].StructuredRecords) != 0 && remoteRecording[i].StructuredRecordsSizeBytes == 0 { + size := int64(0) + for _, rec := range remoteRecording[i].StructuredRecords { + size += int64(rec.MemorySize()) } + remoteRecording[i].StructuredRecordsSizeBytes = size } + } - sp.i.ImportRemoteRecording(remoteRecording) + sp.ImportTrace(treeifyRecording(remoteRecording)) +} + +// ImportTrace takes a trace recording and, depending on the receiver's +// recording mode, adds it as a child to sp. All the structured events from the +// trace are passed to the receiver's event listeners. +// +// ImportTrace takes ownership of trace; the caller should not use it anymore. +// The caller can call Trace.PartialClone() to make a sufficient copy for +// passing into ImportTrace. 
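For illustration only, this is roughly how a caller might use the two new helpers together: capture an ongoing span's recording with GetTraceRecording and graft it onto its own span with ImportTrace, mirroring what the singleflight waiter does when it gives up early. copyFlightTrace and leaderSp are hypothetical names, not part of the patch:

package example

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/util/tracing"
	"github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb"
)

// copyFlightTrace copies the (possibly still accumulating) recording of
// leaderSp into the span found in ctx, if that span is recording.
func copyFlightTrace(ctx context.Context, leaderSp *tracing.Span) {
	waiterSp := tracing.SpanFromContext(ctx)
	if waiterSp.RecordingType() == tracingpb.RecordingOff {
		return
	}
	// GetTraceRecording leaves leaderSp open; the returned Trace is ours to
	// pass to ImportTrace, which takes ownership of it.
	rec := leaderSp.GetTraceRecording(waiterSp.RecordingType())
	waiterSp.ImportTrace(rec)
}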
+func (sp *Span) ImportTrace(trace Trace) { + if sp.detectUseAfterFinish() { + return } + sp.i.ImportTrace(trace) } // Meta returns the information which needs to be propagated across process diff --git a/pkg/util/tracing/span_inner.go b/pkg/util/tracing/span_inner.go index d24d4aa8e58f..d5714d27dc29 100644 --- a/pkg/util/tracing/span_inner.go +++ b/pkg/util/tracing/span_inner.go @@ -12,9 +12,7 @@ package tracing import ( "fmt" - "sort" "strings" - "sync" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" @@ -73,11 +71,21 @@ func (s *spanInner) RecordingType() tracingpb.RecordingType { func (s *spanInner) SetRecordingType(to tracingpb.RecordingType) { if s.isNoop() { - panic(errors.AssertionFailedf("SetVerbose called on NoopSpan; use the WithForceRealSpan option for StartSpan")) + panic(errors.AssertionFailedf("SetRecordingType called on NoopSpan; use the WithForceRealSpan option for StartSpan")) } s.crdb.SetRecordingType(to) } +// GetTraceRecording returns the span's recording as a Trace. +// +// See also GetRecording(), which returns it as a tracingpb.Recording. +func (s *spanInner) GetTraceRecording(recType tracingpb.RecordingType, finishing bool) Trace { + if s.isNoop() { + return Trace{} + } + return s.crdb.GetRecording(recType, finishing) +} + // GetRecording returns the span's recording. // // finishing indicates whether s is in the process of finishing. If it isn't, @@ -85,31 +93,12 @@ func (s *spanInner) SetRecordingType(to tracingpb.RecordingType) { func (s *spanInner) GetRecording( recType tracingpb.RecordingType, finishing bool, ) tracingpb.Recording { - if s.isNoop() { - return nil - } - trace := s.crdb.GetRecording(recType, finishing) - spans := trace.Flatten() - - // Sort the spans by StartTime, except the first Span (the root of this - // recording) which stays in place. - toSort := sortPoolRecordings.Get().(*tracingpb.Recording) // avoids allocations in sort.Sort - *toSort = spans[1:] - sort.Sort(toSort) - *toSort = nil - sortPoolRecordings.Put(toSort) - - return spans -} - -var sortPoolRecordings = sync.Pool{ - New: func() interface{} { - return &tracingpb.Recording{} - }, + trace := s.GetTraceRecording(recType, finishing) + return trace.ToRecording() } -func (s *spanInner) ImportRemoteRecording(remoteRecording tracingpb.Recording) { - s.crdb.recordFinishedChildren(treeifyRecording(remoteRecording)) +func (s *spanInner) ImportTrace(trace Trace) { + s.crdb.recordFinishedChildren(trace) } func treeifyRecording(rec tracingpb.Recording) Trace {
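Finally, the channel half of the Future API makes it easy to multiplex a flight's completion with other events, which is the "listen to its cancelation in parallel to the result channel" use case called out in the DoChan comment. A sketch under the same assumptions as the examples above (waitWithTimeout is an illustrative name):

package example

import (
	"context"
	"time"

	"github.com/cockroachdb/cockroach/pkg/util/syncutil/singleflight"
)

// waitWithTimeout waits for a flight started elsewhere via DoChan, but gives
// up after d or when ctx is canceled. Giving up here does not affect the
// flight or its other waiters.
func waitWithTimeout(
	ctx context.Context, future singleflight.Future, d time.Duration,
) (interface{}, error) {
	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case <-future.C():
		// The flight is done; WaitForResult no longer blocks.
		res := future.WaitForResult(ctx)
		return res.Val, res.Err
	case <-t.C:
		return nil, context.DeadlineExceeded
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}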