diff --git a/pkg/sql/flowinfra/flow_registry.go b/pkg/sql/flowinfra/flow_registry.go index fa8944cf1c20..920ad3fe649f 100644 --- a/pkg/sql/flowinfra/flow_registry.go +++ b/pkg/sql/flowinfra/flow_registry.go @@ -366,13 +366,13 @@ func (fr *FlowRegistry) cancelPendingStreams( // ConnectInboundStream calls for the flow will fail to find it and time out. func (fr *FlowRegistry) UnregisterFlow(id execinfrapb.FlowID) { fr.Lock() + defer fr.Unlock() entry := fr.flows[id] if entry.streamTimer != nil { entry.streamTimer.Stop() entry.streamTimer = nil } fr.releaseEntryLocked(id) - fr.Unlock() } // waitForFlow waits until the flow with the given id gets registered - up to diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index 119f011d4798..33e97e333526 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -450,6 +450,14 @@ func (s *Server) setDrainingLocked(drain bool) bool { return true } +// setDraining sets the server's draining state and returns whether the +// state changed (i.e. drain != s.mu.draining). +func (s *Server) setDraining(drain bool) bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.setDrainingLocked(drain) +} + // setRejectNewConnectionsLocked sets the server's rejectNewConnections state. // s.mu must be locked when setRejectNewConnectionsLocked is called. func (s *Server) setRejectNewConnectionsLocked(rej bool) { @@ -567,13 +575,10 @@ func (s *Server) drainImpl( stopper *stop.Stopper, ) error { - s.mu.Lock() - if !s.setDrainingLocked(true) { + if !s.setDraining(true) { // We are already draining. - s.mu.Unlock() return nil } - s.mu.Unlock() // If there is no open SQL connections to drain, just return. if s.GetConnCancelMapLen() == 0 { diff --git a/pkg/sql/sqlliveness/slstorage/slstorage.go b/pkg/sql/sqlliveness/slstorage/slstorage.go index d2d370afb0d1..a444c6eae6d5 100644 --- a/pkg/sql/sqlliveness/slstorage/slstorage.go +++ b/pkg/sql/sqlliveness/slstorage/slstorage.go @@ -202,33 +202,41 @@ const ( func (s *Storage) isAlive( ctx context.Context, sid sqlliveness.SessionID, syncOrAsync readType, ) (alive bool, _ error) { - s.mu.Lock() - if !s.mu.started { - s.mu.Unlock() - return false, sqlliveness.NotStartedError - } - if _, ok := s.mu.deadSessions.Get(sid); ok { - s.mu.Unlock() - s.metrics.IsAliveCacheHits.Inc(1) - return false, nil - } - if expiration, ok := s.mu.liveSessions.Get(sid); ok { - expiration := expiration.(hlc.Timestamp) - // The record exists and is valid. - if s.clock.Now().Less(expiration) { - s.mu.Unlock() + + // If wait is false, alive is set and future is unset. + // If wait is true, alive is unset and future is set. + alive, wait, future, err := func() (bool, bool, singleflight.Future, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.mu.started { + return false, false, singleflight.Future{}, sqlliveness.NotStartedError + } + if _, ok := s.mu.deadSessions.Get(sid); ok { s.metrics.IsAliveCacheHits.Inc(1) - return true, nil + return false, false, singleflight.Future{}, nil } - } + if expiration, ok := s.mu.liveSessions.Get(sid); ok { + expiration := expiration.(hlc.Timestamp) + // The record exists and is valid. + if s.clock.Now().Less(expiration) { + s.metrics.IsAliveCacheHits.Inc(1) + return true, false, singleflight.Future{}, nil + } + } + + // We think that the session is expired; check, and maybe delete it. + future := s.deleteOrFetchSessionSingleFlightLocked(ctx, sid) - // We think that the session is expired; check, and maybe delete it. - future := s.deleteOrFetchSessionSingleFlightLocked(ctx, sid) + // At this point, we know that the singleflight goroutine has been launched. + // Releasing the lock when we return ensures that callers will either join + // the singleflight or see the result. + return false, true, future, nil + }() + if err != nil || !wait { + return alive, err + } - // At this point, we know that the singleflight goroutine has been launched. - // Releasing the lock here ensures that callers will either join the single- - // flight or see the result. - s.mu.Unlock() s.metrics.IsAliveCacheMisses.Inc(1) // If we do not want to wait for the result, assume that the session is diff --git a/pkg/sql/sqlstats/ssmemstorage/ss_mem_storage.go b/pkg/sql/sqlstats/ssmemstorage/ss_mem_storage.go index 7c29ba7d2076..9d6105cbca61 100644 --- a/pkg/sql/sqlstats/ssmemstorage/ss_mem_storage.go +++ b/pkg/sql/sqlstats/ssmemstorage/ss_mem_storage.go @@ -664,9 +664,11 @@ func (s *Container) SaveToLog(ctx context.Context, appName string) { } var buf bytes.Buffer for key, stats := range s.mu.stmts { - stats.mu.Lock() - json, err := json.Marshal(stats.mu.data) - stats.mu.Unlock() + json, err := func() ([]byte, error) { + stats.mu.Lock() + defer stats.mu.Unlock() + return json.Marshal(stats.mu.data) + }() if err != nil { log.Errorf(ctx, "error while marshaling stats for %q // %q: %v", appName, key.String(), err) continue diff --git a/pkg/sql/sqlstats/ssmemstorage/ss_mem_writer.go b/pkg/sql/sqlstats/ssmemstorage/ss_mem_writer.go index 361c5a74bce2..6e2d00bc53b8 100644 --- a/pkg/sql/sqlstats/ssmemstorage/ss_mem_writer.go +++ b/pkg/sql/sqlstats/ssmemstorage/ss_mem_writer.go @@ -335,17 +335,21 @@ func (s *Container) RecordTransaction( if created { estimatedMemAllocBytes := stats.sizeUnsafe() + key.Size() + 8 /* hash of transaction key */ - s.mu.Lock() - - // If the monitor is nil, we do not track memory usage. - if s.mu.acc.Monitor() != nil { - if err := s.mu.acc.Grow(ctx, estimatedMemAllocBytes); err != nil { - delete(s.mu.txns, key) - s.mu.Unlock() - return ErrMemoryPressure + if err := func() error { + s.mu.Lock() + defer s.mu.Unlock() + + // If the monitor is nil, we do not track memory usage. + if s.mu.acc.Monitor() != nil { + if err := s.mu.acc.Grow(ctx, estimatedMemAllocBytes); err != nil { + delete(s.mu.txns, key) + return ErrMemoryPressure + } } + return nil + }(); err != nil { + return err } - s.mu.Unlock() } stats.mu.data.Count++ diff --git a/pkg/util/tracing/tracer.go b/pkg/util/tracing/tracer.go index b48742805456..208d457163f0 100644 --- a/pkg/util/tracing/tracer.go +++ b/pkg/util/tracing/tracer.go @@ -527,17 +527,19 @@ func (r *SpanRegistry) testingAll() []*crdbSpan { // concurrently with this call. swap takes ownership of the spanRefs, and will // release() them. func (r *SpanRegistry) swap(parentID tracingpb.SpanID, children []spanRef) { - r.mu.Lock() - r.removeSpanLocked(parentID) - for _, c := range children { - sp := c.Span.i.crdb - sp.withLock(func() { - if !sp.mu.finished { - r.addSpanLocked(sp) - } - }) - } - r.mu.Unlock() + func() { + r.mu.Lock() + defer r.mu.Unlock() + r.removeSpanLocked(parentID) + for _, c := range children { + sp := c.Span.i.crdb + sp.withLock(func() { + if !sp.mu.finished { + r.addSpanLocked(sp) + } + }) + } + }() for _, c := range children { c.release() } diff --git a/pkg/workload/histogram/histogram.go b/pkg/workload/histogram/histogram.go index ce32ff414021..a29e59db2861 100644 --- a/pkg/workload/histogram/histogram.go +++ b/pkg/workload/histogram/histogram.go @@ -74,6 +74,8 @@ func (w *Registry) newNamedHistogramLocked(name string) *NamedHistogram { func (w *NamedHistogram) Record(elapsed time.Duration) { w.prometheusHistogram.Observe(float64(elapsed.Nanoseconds()) / float64(time.Second)) w.mu.Lock() + defer w.mu.Unlock() + maxLatency := time.Duration(w.mu.current.HighestTrackableValue()) if elapsed < minLatency { elapsed = minLatency @@ -81,10 +83,7 @@ func (w *NamedHistogram) Record(elapsed time.Duration) { elapsed = maxLatency } - err := w.mu.current.RecordValue(elapsed.Nanoseconds()) - w.mu.Unlock() - - if err != nil { + if err := w.mu.current.RecordValue(elapsed.Nanoseconds()); err != nil { // Note that a histogram only drops recorded values that are out of range, // but we clamp the latency value to the configured range to prevent such // drops. This code path should never happen. diff --git a/pkg/workload/ycsb/zipfgenerator.go b/pkg/workload/ycsb/zipfgenerator.go index cd84cfc48f32..c16ea09ca21c 100644 --- a/pkg/workload/ycsb/zipfgenerator.go +++ b/pkg/workload/ycsb/zipfgenerator.go @@ -129,6 +129,7 @@ func computeZetaFromScratch(n uint64, theta float64) (float64, error) { // according to the Zipf distribution. func (z *ZipfGenerator) Uint64() uint64 { z.zipfGenMu.mu.Lock() + defer z.zipfGenMu.mu.Unlock() u := z.zipfGenMu.r.Float64() uz := u * z.zipfGenMu.zetaN var result uint64 @@ -143,7 +144,6 @@ func (z *ZipfGenerator) Uint64() uint64 { if z.verbose { fmt.Printf("Uint64[%d, %d] -> %d\n", z.iMin, z.zipfGenMu.iMax, result) } - z.zipfGenMu.mu.Unlock() return result } @@ -151,16 +151,15 @@ func (z *ZipfGenerator) Uint64() uint64 { // that depend on it. It throws an error if the recomputation failed. func (z *ZipfGenerator) IncrementIMax(count uint64) error { z.zipfGenMu.mu.Lock() + defer z.zipfGenMu.mu.Unlock() zetaN, err := computeZetaIncrementally( z.zipfGenMu.iMax+1-z.iMin, z.zipfGenMu.iMax+count+1-z.iMin, z.theta, z.zipfGenMu.zetaN) if err != nil { - z.zipfGenMu.mu.Unlock() return errors.Wrap(err, "Could not incrementally compute zeta") } z.zipfGenMu.iMax += count eta := (1 - math.Pow(2.0/float64(z.zipfGenMu.iMax+1-z.iMin), 1.0-z.theta)) / (1.0 - z.zeta2/zetaN) z.zipfGenMu.eta = eta z.zipfGenMu.zetaN = zetaN - z.zipfGenMu.mu.Unlock() return nil }