Skip to content

Commit

Permalink
backport of commit a8b84a6
Browse files Browse the repository at this point in the history
  • Loading branch information
tgross authored Nov 5, 2024
1 parent 9d605cd commit 3c81423
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 151 deletions.
37 changes: 23 additions & 14 deletions client/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import (
"github.com/hashicorp/nomad/client/fingerprint"
"github.com/hashicorp/nomad/client/servers"
"github.com/hashicorp/nomad/client/serviceregistration/mock"
"github.com/hashicorp/nomad/client/state"
agentconsul "github.com/hashicorp/nomad/command/agent/consul"
"github.com/hashicorp/nomad/helper/pluginutils/catalog"
"github.com/hashicorp/nomad/helper/pluginutils/singleton"
"github.com/hashicorp/nomad/helper/pool"
"github.com/hashicorp/nomad/helper/testlog"
testing "github.com/mitchellh/go-testing-interface"
"github.com/shoenig/test/must"
)

// TestClient creates an in-memory client for testing purposes and returns a
Expand Down Expand Up @@ -91,45 +93,52 @@ func TestClientWithRPCs(t testing.T, cb func(c *config.Config), rpcs map[string]
// with the server and then returns mock RPC responses for those interfaces
// passed in the `rpcs` parameter. Useful for testing client RPCs from the
// server. Returns the Client, a shutdown function, and any error.
func TestRPCOnlyClient(t testing.T, srvAddr net.Addr, rpcs map[string]interface{}) (*Client, func() error, error) {
var err error
func TestRPCOnlyClient(t testing.T, cb func(c *config.Config), srvAddr net.Addr, rpcs map[string]any) (*Client, func()) {
t.Helper()
conf, cleanup := config.TestClientConfig(t)
conf.StateDBFactory = state.GetStateDBFactory(true)
if cb != nil {
cb(conf)
}

client := &Client{config: conf, logger: testlog.HCLogger(t)}
client := &Client{config: conf, logger: testlog.HCLogger(t), shutdownCh: make(chan struct{})}
client.servers = servers.New(client.logger, client.shutdownCh, client)

client.registeredCh = make(chan struct{})
client.rpcServer = rpc.NewServer()
for name, rpc := range rpcs {
client.rpcServer.RegisterName(name, rpc)
}

client.heartbeatStop = newHeartbeatStop(
client.getAllocRunner, time.Second, client.logger, client.shutdownCh)
client.connPool = pool.NewPool(testlog.HCLogger(t), 10*time.Second, 10, nil)
client.init()

cancelFunc := func() error {
cancelFunc := func() {
ch := make(chan error)

go func() {
defer close(ch)
client.connPool.Shutdown()
close(client.shutdownCh)
client.shutdownGroup.Wait()
cleanup()
}()

select {
case <-ch:
return nil
case <-time.After(1 * time.Minute):
return fmt.Errorf("timed out while shutting down client")
return
case <-time.After(5 * time.Second):
t.Error("timed out while shutting down client")
return
}
}

go client.rpcConnListener()

_, err = client.SetServers([]string{srvAddr.String()})
if err != nil {
return nil, cancelFunc, fmt.Errorf("could not set servers: %v", err)
}
_, err := client.SetServers([]string{srvAddr.String()})
must.NoError(t, err, must.Sprintf("could not set servers: %v", err))

client.shutdownGroup.Go(client.registerAndHeartbeat)

return client, cancelFunc, nil
return client, cancelFunc
}
Loading

0 comments on commit 3c81423

Please sign in to comment.