Commit

Backport of grpc: fix data race in balancer registration into release/1.15.x (#17351)

Registering gRPC balancers is thread-unsafe because they are stored in a
global map variable that is accessed without holding a lock. Therefore,
it's expected that balancers are registered _once_ at the beginning of
your program (e.g. in a package `init` function) and certainly not after
you've started dialing connections, etc.

> NOTE: this function must only be called during initialization time
> (i.e. in an init() function), and is not thread-safe.

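For reference, the init-time pattern gRPC expects looks roughly like this
(a minimal sketch; the "toy-random" name and the random picker are invented
for illustration and are not part of this change):

    package toybalancer

    import (
        "math/rand"

        gbalancer "google.golang.org/grpc/balancer"
        "google.golang.org/grpc/balancer/base"
    )

    // randomPicker picks a random ready connection for each RPC.
    type randomPicker struct {
        scs []gbalancer.SubConn
    }

    func (p *randomPicker) Pick(gbalancer.PickInfo) (gbalancer.PickResult, error) {
        if len(p.scs) == 0 {
            return gbalancer.PickResult{}, gbalancer.ErrNoSubConnAvailable
        }
        return gbalancer.PickResult{SubConn: p.scs[rand.Intn(len(p.scs))]}, nil
    }

    type randomPickerBuilder struct{}

    func (randomPickerBuilder) Build(info base.PickerBuildInfo) gbalancer.Picker {
        scs := make([]gbalancer.SubConn, 0, len(info.ReadySCs))
        for sc := range info.ReadySCs {
            scs = append(scs, sc)
        }
        return &randomPicker{scs: scs}
    }

    // Register mutates gRPC's global, unlocked builder map, so it must run
    // during package init, before any connections are dialed.
    func init() {
        gbalancer.Register(base.NewBalancerBuilder(
            "toy-random", randomPickerBuilder{}, base.Config{},
        ))
    }
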
While this is fine for us in production, it's challenging for tests that
spin up multiple agents in-memory. We currently register a balancer per
agent, which holds agent-specific state that cannot safely be shared.

This commit introduces our own registry that _is_ thread-safe, and
implements the Builder interface such that we can call gRPC's `Register`
method once, on start-up. It uses the same pattern as our resolver
registry where we use the dial target's host (aka "authority"), which is
unique per-agent, to determine which builder to use.
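
Concretely, per-agent wiring ends up looking roughly like the sketch below
(assembled from the test changes in this diff; the dialAgent helper, the
"/server.dc1" endpoint path, and the null logger are illustrative, and it
assumes a resolver is registered for the same authority):

    package example

    import (
        "fmt"

        "google.golang.org/grpc"
        "google.golang.org/grpc/credentials/insecure"

        "github.com/hashicorp/consul/agent/grpc-internal/balancer"
        "github.com/hashicorp/go-hclog"
    )

    func dialAgent(authority string) (*grpc.ClientConn, error) {
        // One Builder per agent, registered in our own (locked) registry
        // under the agent's unique authority. Deregister at shutdown is now
        // safe at runtime; in tests: t.Cleanup(bb.Deregister).
        bb := balancer.NewBuilder(authority, hclog.NewNullLogger())
        bb.Register()

        return grpc.Dial(
            // The registry picks the right Builder from the target's host.
            "consul-internal://"+authority+"/server.dc1",
            grpc.WithTransportCredentials(insecure.NewCredentials()),
            grpc.WithDefaultServiceConfig(
                // BuilderName names the registry itself, not any one Builder.
                fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, balancer.BuilderName),
            ),
        )
    }
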
hc-github-team-consul-core authored May 15, 2023
1 parent 85fbc59 commit 64c29e5
Showing 12 changed files with 179 additions and 104 deletions.
5 changes: 1 addition & 4 deletions agent/agent.go
@@ -1605,10 +1605,7 @@ func (a *Agent) ShutdownAgent() error {

 	a.stopLicenseManager()

-	// this would be cancelled anyways (by the closing of the shutdown ch) but
-	// this should help them to be stopped more quickly
-	a.baseDeps.AutoConfig.Stop()
-	a.baseDeps.MetricsConfig.Cancel()
+	a.baseDeps.Close()

 	a.stateLock.Lock()
 	defer a.stateLock.Unlock()
5 changes: 4 additions & 1 deletion agent/consul/client_test.go
@@ -526,9 +526,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {

 	resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter, c.Datacenter, "server"))
 	resolver.Register(resolverBuilder)
+	t.Cleanup(func() {
+		resolver.Deregister(resolverBuilder.Authority())
+	})

 	balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
 	balancerBuilder.Register()
+	t.Cleanup(balancerBuilder.Deregister)

 	r := router.NewRouter(
 		logger,
@@ -563,7 +567,6 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
 			UseTLSForDC:           tls.UseTLS,
 			DialingFromServer:     true,
 			DialingFromDatacenter: c.Datacenter,
-			BalancerBuilder:       balancerBuilder,
 		}),
 		LeaderForwarder:        resolverBuilder,
 		NewRequestRecorderFunc: middleware.NewRequestRecorder,
3 changes: 1 addition & 2 deletions agent/consul/rpc_test.go
@@ -1165,7 +1165,7 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {

 	var conn *grpc.ClientConn
 	{
-		client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
+		client, resolverBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
 			c.Datacenter = "dc2"
 			c.PrimaryDatacenter = "dc1"
 			c.RPCConfig.EnableStreaming = true
@@ -1177,7 +1177,6 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
 			Servers:               resolverBuilder,
 			DialingFromServer:     false,
 			DialingFromDatacenter: "dc2",
-			BalancerBuilder:       balancerBuilder,
 		})

 		conn, err = pool.ClientConn("dc2")
15 changes: 6 additions & 9 deletions agent/consul/subscribe_backend_test.go
@@ -39,7 +39,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
 	require.NoError(t, err)
 	defer server.Shutdown()

-	client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
+	client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)

 	// Try to join
 	testrpc.WaitForLeader(t, server.RPC, "dc1")
@@ -71,7 +71,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
 		UseTLSForDC:           client.tlsConfigurator.UseTLS,
 		DialingFromServer:     true,
 		DialingFromDatacenter: "dc1",
-		BalancerBuilder:       balancerBuilder,
 	})
 	conn, err := pool.ClientConn("dc1")
 	require.NoError(t, err)
@@ -116,7 +115,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
 		UseTLSForDC:           client.tlsConfigurator.UseTLS,
 		DialingFromServer:     true,
 		DialingFromDatacenter: "dc1",
-		BalancerBuilder:       balancerBuilder,
 	})
 	conn, err := pool.ClientConn("dc1")
 	require.NoError(t, err)
@@ -191,7 +189,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
 	defer server.Shutdown()

 	// Set up a client with valid certs and verify_outgoing = true
-	client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
+	client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)

 	testrpc.WaitForLeader(t, server.RPC, "dc1")

@@ -204,7 +202,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
 		UseTLSForDC:           client.tlsConfigurator.UseTLS,
 		DialingFromServer:     true,
 		DialingFromDatacenter: "dc1",
-		BalancerBuilder:       balancerBuilder,
 	})
 	conn, err := pool.ClientConn("dc1")
 	require.NoError(t, err)
@@ -284,7 +281,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T) {
 	codec := rpcClient(t, server)
 	defer codec.Close()

-	client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t)
+	client, resolverBuilder := newClientWithGRPCPlumbing(t)

 	// Try to join
 	testrpc.WaitForLeader(t, server.RPC, "dc1")
@@ -346,7 +343,6 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T) {
 		UseTLSForDC:           client.tlsConfigurator.UseTLS,
 		DialingFromServer:     true,
 		DialingFromDatacenter: "dc1",
-		BalancerBuilder:       balancerBuilder,
 	})
 	conn, err := pool.ClientConn("dc1")
 	require.NoError(t, err)
@@ -376,7 +372,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T) {
 		"at least some of the subscribers should have received non-snapshot updates")
 }

-func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder, *balancer.Builder) {
+func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
 	_, config := testClientConfig(t)
 	for _, op := range ops {
 		op(config)
@@ -395,6 +391,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder, *balancer.Builder) {

 	balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
 	balancerBuilder.Register()
+	t.Cleanup(balancerBuilder.Deregister)

 	deps := newDefaultDeps(t, config)
 	deps.Router = router.NewRouter(
@@ -409,7 +406,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder, *balancer.Builder) {
 	t.Cleanup(func() {
 		client.Shutdown()
 	})
-	return client, resolverBuilder, balancerBuilder
+	return client, resolverBuilder
 }

 type testLogger interface {
40 changes: 20 additions & 20 deletions agent/grpc-internal/balancer/balancer.go
@@ -65,21 +65,25 @@ import (
 	"google.golang.org/grpc/status"
 )

-// NewBuilder constructs a new Builder with the given name.
-func NewBuilder(name string, logger hclog.Logger) *Builder {
+// NewBuilder constructs a new Builder. Calling Register will add the Builder
+// to our global registry under the given "authority" such that it will be used
+// when dialing targets in the form "consul-internal://<authority>/...". This
+// allows us to add and remove balancers for different in-memory agents during
+// tests.
+func NewBuilder(authority string, logger hclog.Logger) *Builder {
 	return &Builder{
-		name:     name,
-		logger:   logger,
-		byTarget: make(map[string]*list.List),
-		shuffler: randomShuffler(),
+		authority: authority,
+		logger:    logger,
+		byTarget:  make(map[string]*list.List),
+		shuffler:  randomShuffler(),
 	}
 }

 // Builder implements gRPC's balancer.Builder interface to construct balancers.
 type Builder struct {
-	name     string
-	logger   hclog.Logger
-	shuffler shuffler
+	authority string
+	logger    hclog.Logger
+	shuffler  shuffler

 	mu       sync.Mutex
 	byTarget map[string]*list.List
@@ -129,19 +133,15 @@ func (b *Builder) removeBalancer(targetURL string, elem *list.Element) {
 	}
 }

-// Name implements the gRPC Balancer interface by returning its given name.
-func (b *Builder) Name() string { return b.name }
-
-// gRPC's balancer.Register method is not thread-safe, so we guard our calls
-// with a global lock (as it may be called from parallel tests).
-var registerLock sync.Mutex
-
-// Register the Builder in gRPC's global registry using its given name.
+// Register the Builder in our global registry. Users should call Deregister
+// when finished using the Builder to clean up global state.
 func (b *Builder) Register() {
-	registerLock.Lock()
-	defer registerLock.Unlock()
-
-	gbalancer.Register(b)
+	globalRegistry.register(b.authority, b)
+}
+
+// Deregister the Builder from our global registry to clean up state.
+func (b *Builder) Deregister() {
+	globalRegistry.deregister(b.authority)
 }

 // Rebalance randomizes the priority order of servers for the given target to
41 changes: 25 additions & 16 deletions agent/grpc-internal/balancer/balancer_test.go
@@ -21,6 +21,8 @@ import (
 	"google.golang.org/grpc/stats"
 	"google.golang.org/grpc/status"

+	"github.com/hashicorp/go-uuid"
+
 	"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
 	"github.com/hashicorp/consul/sdk/testutil"
 	"github.com/hashicorp/consul/sdk/testutil/retry"
@@ -34,12 +36,13 @@ func TestBalancer(t *testing.T) {
 		server1 := runServer(t, "server1")
 		server2 := runServer(t, "server2")

-		target, _ := stubResolver(t, server1, server2)
+		target, authority, _ := stubResolver(t, server1, server2)

-		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
 		balancerBuilder.Register()
+		t.Cleanup(balancerBuilder.Deregister)

-		conn := dial(t, target, balancerBuilder)
+		conn := dial(t, target)
 		client := testservice.NewSimpleClient(conn)

 		var serverName string
@@ -78,12 +81,13 @@ func TestBalancer(t *testing.T) {
 		server1 := runServer(t, "server1")
 		server2 := runServer(t, "server2")

-		target, _ := stubResolver(t, server1, server2)
+		target, authority, _ := stubResolver(t, server1, server2)

-		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
 		balancerBuilder.Register()
+		t.Cleanup(balancerBuilder.Deregister)

-		conn := dial(t, target, balancerBuilder)
+		conn := dial(t, target)
 		client := testservice.NewSimpleClient(conn)

 		// Figure out which server we're talking to now, and which we should switch to.
@@ -123,10 +127,11 @@ func TestBalancer(t *testing.T) {
 		server1 := runServer(t, "server1")
 		server2 := runServer(t, "server2")

-		target, _ := stubResolver(t, server1, server2)
+		target, authority, _ := stubResolver(t, server1, server2)

-		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
 		balancerBuilder.Register()
+		t.Cleanup(balancerBuilder.Deregister)

 		// Provide a custom prioritizer that causes Rebalance to choose whichever
 		// server didn't get our first request.
@@ -137,7 +142,7 @@ func TestBalancer(t *testing.T) {
 			})
 		}

-		conn := dial(t, target, balancerBuilder)
+		conn := dial(t, target)
 		client := testservice.NewSimpleClient(conn)

 		// Figure out which server we're talking to now.
@@ -177,12 +182,13 @@ func TestBalancer(t *testing.T) {
 		server1 := runServer(t, "server1")
 		server2 := runServer(t, "server2")

-		target, res := stubResolver(t, server1, server2)
+		target, authority, res := stubResolver(t, server1, server2)

-		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
 		balancerBuilder.Register()
+		t.Cleanup(balancerBuilder.Deregister)

-		conn := dial(t, target, balancerBuilder)
+		conn := dial(t, target)
 		client := testservice.NewSimpleClient(conn)

 		// Figure out which server we're talking to now.
@@ -233,7 +239,7 @@ func TestBalancer(t *testing.T) {
 	})
 }

-func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
+func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
 	t.Helper()

 	addresses := make([]resolver.Address, len(servers))
@@ -249,7 +255,10 @@ func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
 	resolver.Register(r)
 	t.Cleanup(func() { resolver.UnregisterForTesting(scheme) })

-	return fmt.Sprintf("%s://", scheme), r
+	authority, err := uuid.GenerateUUID()
+	require.NoError(t, err)
+
+	return fmt.Sprintf("%s://%s", scheme, authority), authority, r
 }

 func runServer(t *testing.T, name string) *server {
@@ -309,12 +318,12 @@ func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp, error) {
 	return &testservice.Resp{ServerName: s.name}, nil
 }

-func dial(t *testing.T, target string, builder *Builder) *grpc.ClientConn {
+func dial(t *testing.T, target string) *grpc.ClientConn {
 	conn, err := grpc.Dial(
 		target,
 		grpc.WithTransportCredentials(insecure.NewCredentials()),
 		grpc.WithDefaultServiceConfig(
-			fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, builder.Name()),
+			fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, BuilderName),
 		),
 	)
 	t.Cleanup(func() {
69 changes: 69 additions & 0 deletions agent/grpc-internal/balancer/registry.go
@@ -0,0 +1,69 @@
package balancer

import (
	"fmt"
	"sync"

	gbalancer "google.golang.org/grpc/balancer"
)

// BuilderName should be given in gRPC service configuration to enable our
// custom balancer. It refers to this package's global registry, rather than
// an instance of Builder, to enable us to add and remove builders at runtime,
// specifically during tests.
const BuilderName = "consul-internal"

// gRPC's balancer.Register method is thread-unsafe because it mutates a global
// map without holding a lock. As such, it's expected that you register custom
// balancers once at the start of your program (e.g. in a package init function).
//
// In production, this is fine. Agents register a single instance of our builder
// and use it for the duration. Tests are where this becomes problematic, as we
// spin up several agents in-memory and register/deregister a builder for each,
// with its own agent-specific state, logger, etc.
//
// To avoid data races, we call gRPC's Register method once, on package init,
// with a global registry struct that implements the Builder interface but
// delegates the building to N instances of our Builder that are registered and
// deregistered at runtime. We use the dial target's host (aka "authority"),
// which is unique per-agent, to pick the correct builder.
func init() {
	gbalancer.Register(globalRegistry)
}

var globalRegistry = &registry{
	byAuthority: make(map[string]*Builder),
}

type registry struct {
	mu          sync.RWMutex
	byAuthority map[string]*Builder
}

func (r *registry) Build(cc gbalancer.ClientConn, opts gbalancer.BuildOptions) gbalancer.Balancer {
	r.mu.RLock()
	defer r.mu.RUnlock()

	auth := opts.Target.URL.Host
	builder, ok := r.byAuthority[auth]
	if !ok {
		panic(fmt.Sprintf("no gRPC balancer builder registered for authority: %q", auth))
	}
	return builder.Build(cc, opts)
}

func (r *registry) Name() string { return BuilderName }

func (r *registry) register(auth string, builder *Builder) {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.byAuthority[auth] = builder
}

func (r *registry) deregister(auth string) {
	r.mu.Lock()
	defer r.mu.Unlock()

	delete(r.byAuthority, auth)
}