Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: send node secret with every client-to-server RPC #16799

Merged
merged 4 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 92 additions & 17 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ type Client struct {
// applied to the node
fpInitialized chan struct{}

// registeredCh is closed when Node.Register has successfully run once.
registeredCh chan struct{}
registeredOnce sync.Once

// serversContactedCh is closed when GetClientAllocs and runAllocs have
// successfully run once.
serversContactedCh chan struct{}
Expand Down Expand Up @@ -376,6 +380,8 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie
invalidAllocs: make(map[string]struct{}),
serversContactedCh: make(chan struct{}),
serversContactedOnce: sync.Once{},
registeredCh: make(chan struct{}),
registeredOnce: sync.Once{},
cpusetManager: cgutil.CreateCPUSetManager(cfg.CgroupParent, cfg.ReservableCores, logger),
getter: getter.New(cfg.Artifact, logger),
EnterpriseClient: newEnterpriseClient(logger),
Expand Down Expand Up @@ -1832,6 +1838,7 @@ func (c *Client) periodicSnapshot() {

// run is a long lived goroutine used to run the client. Shutdown() stops it first
func (c *Client) run() {

// Watch for changes in allocations
allocUpdates := make(chan *allocUpdates, 8)
go c.watchAllocations(allocUpdates)
Expand Down Expand Up @@ -1866,8 +1873,11 @@ func (c *Client) submitNodeEvents(events []*structs.NodeEvent) error {
nodeID: events,
}
req := structs.EmitNodeEventsRequest{
NodeEvents: nodeEvents,
WriteRequest: structs.WriteRequest{Region: c.Region()},
NodeEvents: nodeEvents,
WriteRequest: structs.WriteRequest{
Region: c.Region(),
AuthToken: c.secretNodeID(),
},
}
var resp structs.EmitNodeEventsResponse
if err := c.RPC("Node.EmitEvents", &req, &resp); err != nil {
Expand Down Expand Up @@ -1923,8 +1933,11 @@ func (c *Client) triggerNodeEvent(nodeEvent *structs.NodeEvent) {
// retryRegisterNode is used to register the node or update the registration and
// retry in case of failure.
func (c *Client) retryRegisterNode() {

authToken := c.getRegistrationToken()

for {
err := c.registerNode()
err := c.registerNode(authToken)
if err == nil {
// Registered!
return
Expand All @@ -1935,6 +1948,13 @@ func (c *Client) retryRegisterNode() {
c.logger.Debug("registration waiting on servers")
c.triggerDiscovery()
retryIntv = noServerRetryIntv
} else if structs.IsErrPermissionDenied(err) {
// any previous cluster state we have here is invalid (ex. client
// has been assigned to a new region), so clear the token and local
// state for next pass.
authToken = ""
c.stateDB.PutNodeRegistration(&cstructs.NodeRegistration{HasRegistered: false})
c.logger.Error("error registering", "error", err)
} else {
c.logger.Error("error registering", "error", err)
}
Expand All @@ -1947,17 +1967,61 @@ func (c *Client) retryRegisterNode() {
}
}

// getRegistrationToken returns the auth token to attach to the Node.Register
// call: the node secret when this node is known to have registered before,
// and the empty string otherwise. Registration is trust-on-first-use, so the
// initial request can't carry the token, but later requests include it so
// that metrics can be captured.
func (c *Client) getRegistrationToken() string {
	// Fast path: registeredCh has already been closed.
	select {
	case <-c.registeredCh:
		return c.secretNodeID()
	default:
	}

	// registeredCh is still open, so we're either starting for the first
	// time or have just restarted. Consult the local state for a previously
	// recorded successful registration so that we don't block allocrunner
	// operations on disconnected clients.
	prior, err := c.stateDB.GetNodeRegistration()
	if err != nil {
		// Best effort: log and fall through to inspect whatever was returned.
		c.logger.Error("could not determine previous node registration", "error", err)
	}
	if prior == nil || !prior.HasRegistered {
		return ""
	}

	c.registeredOnce.Do(func() { close(c.registeredCh) })
	return c.secretNodeID()
}

// registerNode is used to register the node or update the registration
func (c *Client) registerNode() error {
func (c *Client) registerNode(authToken string) error {
req := structs.NodeRegisterRequest{
Node: c.Node(),
WriteRequest: structs.WriteRequest{Region: c.Region()},
Node: c.Node(),
WriteRequest: structs.WriteRequest{
Region: c.Region(),
AuthToken: authToken,
},
}

var resp structs.NodeUpdateResponse
if err := c.RPC("Node.Register", &req, &resp); err != nil {
if err := c.UnauthenticatedRPC("Node.Register", &req, &resp); err != nil {
return err
}

// Signal that we've registered once so that RPCs sent from the client can
// send authenticated requests. Persist this information in the state so
// that we don't block restoring running allocs when restarting while
// disconnected
c.registeredOnce.Do(func() {
err := c.stateDB.PutNodeRegistration(&cstructs.NodeRegistration{
HasRegistered: true,
})
if err != nil {
c.logger.Error("could not write node registration", "error", err)
}
close(c.registeredCh)
})

err := c.handleNodeUpdateResponse(resp)
if err != nil {
return err
Expand Down Expand Up @@ -1985,9 +2049,12 @@ func (c *Client) registerNode() error {
func (c *Client) updateNodeStatus() error {
start := time.Now()
req := structs.NodeUpdateStatusRequest{
NodeID: c.NodeID(),
Status: structs.NodeStatusReady,
WriteRequest: structs.WriteRequest{Region: c.Region()},
NodeID: c.NodeID(),
Status: structs.NodeStatusReady,
WriteRequest: structs.WriteRequest{
Region: c.Region(),
AuthToken: c.secretNodeID(),
},
}
var resp structs.NodeUpdateResponse
if err := c.RPC("Node.UpdateStatus", &req, &resp); err != nil {
Expand Down Expand Up @@ -2129,8 +2196,11 @@ func (c *Client) allocSync() {

// Send to server.
args := structs.AllocUpdateRequest{
Alloc: toSync,
WriteRequest: structs.WriteRequest{Region: c.Region()},
Alloc: toSync,
WriteRequest: structs.WriteRequest{
Region: c.Region(),
AuthToken: c.secretNodeID(),
schmichael marked this conversation as resolved.
Show resolved Hide resolved
},
}

var resp structs.GenericResponse
Expand Down Expand Up @@ -2204,6 +2274,7 @@ func (c *Client) watchAllocations(updates chan *allocUpdates) {
// After the first request, only require monotonically
// increasing state.
AllowStale: false,
AuthToken: c.secretNodeID(),
},
}
var resp structs.NodeClientAllocsResponse
Expand Down Expand Up @@ -2721,6 +2792,7 @@ func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vcli
Region: c.Region(),
AllowStale: false,
MinQueryIndex: alloc.CreateIndex,
AuthToken: c.secretNodeID(),
},
}

Expand Down Expand Up @@ -2794,11 +2866,14 @@ func (c *Client) deriveSIToken(alloc *structs.Allocation, taskNames []string) (m
}

req := &structs.DeriveSITokenRequest{
NodeID: c.NodeID(),
SecretID: c.secretNodeID(),
AllocID: alloc.ID,
Tasks: tasks,
QueryOptions: structs.QueryOptions{Region: c.Region()},
NodeID: c.NodeID(),
SecretID: c.secretNodeID(),
AllocID: alloc.ID,
Tasks: tasks,
QueryOptions: structs.QueryOptions{
Region: c.Region(),
AuthToken: c.secretNodeID(),
},
}

// Nicely ask Nomad Server for the tokens.
Expand Down
15 changes: 12 additions & 3 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ func TestClient_MixedTLS(t *testing.T) {
})
defer cleanup()

// tell the client we've registered to unblock the RPC we test below
c1.registeredOnce.Do(func() { close(c1.registeredCh) })

req := structs.NodeSpecificRequest{
NodeID: c1.Node().ID,
QueryOptions: structs.QueryOptions{Region: "global"},
Expand All @@ -288,7 +291,7 @@ func TestClient_MixedTLS(t *testing.T) {
testutil.AssertUntil(100*time.Millisecond,
func() (bool, error) {
err := c1.RPC("Node.GetNode", &req, &out)
if err == nil {
if err == nil || structs.IsErrPermissionDenied(err) {
return false, fmt.Errorf("client RPC succeeded when it should have failed:\n%+v", out)
}
return true, nil
Expand Down Expand Up @@ -339,6 +342,9 @@ func TestClient_BadTLS(t *testing.T) {
})
defer cleanupC1()

// tell the client we've registered to unblock the RPC we test below
c1.registeredOnce.Do(func() { close(c1.registeredCh) })

req := structs.NodeSpecificRequest{
NodeID: c1.Node().ID,
QueryOptions: structs.QueryOptions{Region: "global"},
Expand All @@ -347,7 +353,7 @@ func TestClient_BadTLS(t *testing.T) {
testutil.AssertUntil(100*time.Millisecond,
func() (bool, error) {
err := c1.RPC("Node.GetNode", &req, &out)
if err == nil {
if err == nil || structs.IsErrPermissionDenied(err) {
return false, fmt.Errorf("client RPC succeeded when it should have failed:\n%+v", out)
}
return true, nil
Expand Down Expand Up @@ -1276,6 +1282,9 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) {
})
defer cleanup()

// tell the client we've registered to unblock the RPC we test below
c1.registeredOnce.Do(func() { close(c1.registeredCh) })

// assert that when one node is running in encrypted mode, a RPC request to a
// node running in plaintext mode should fail
{
Expand All @@ -1286,7 +1295,7 @@ func TestClient_ReloadTLS_DowngradeTLSToPlaintext(t *testing.T) {
testutil.WaitForResult(func() (bool, error) {
var out structs.SingleNodeResponse
err := c1.RPC("Node.GetNode", &req, &out)
if err == nil {
if err == nil || structs.IsErrPermissionDenied(err) {
return false, fmt.Errorf("client RPC succeeded when it should have failed :\n%+v", err)
}
return true, nil
Expand Down
31 changes: 26 additions & 5 deletions client/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,27 @@ func (c *Client) StreamingRpcHandler(method string) (structs.StreamingRpcHandler
}

// RPC is used to forward an RPC call to a nomad server, or fail if no servers.
func (c *Client) RPC(method string, args interface{}, reply interface{}) error {
func (c *Client) RPC(method string, args any, reply any) error {
	// Wait until the node has registered so that only authenticated calls
	// are sent after registration, or until the client shuts down.
	select {
	case <-c.shutdownCh:
		// Shutting down: return nil without forwarding the call; reply is
		// left untouched.
		return nil
	case <-c.registeredCh:
		return c.rpc(method, args, reply)
	}
}

// UnauthenticatedRPC special-cases the Node.Register RPC call, forwarding the
// call to a nomad server without blocking on the initial node registration.
// It delegates directly to c.rpc and returns that call's error unchanged.
func (c *Client) UnauthenticatedRPC(method string, args any, reply any) error {
return c.rpc(method, args, reply)
}

// rpc implements the forwarding of an RPC call to a nomad server, or fails if
// there are no servers.
func (c *Client) rpc(method string, args any, reply any) error {

conf := c.GetConfig()

// Invoke the RPCHandler if it exists
Expand Down Expand Up @@ -437,13 +457,14 @@ func resolveServer(s string) (net.Addr, error) {
return net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port))
}

// Ping never mutates the request, so reuse a singleton to avoid the extra
// malloc
var pingRequest = &structs.GenericRequest{}

// Ping is used to ping a particular server and returns whether it is healthy or
// a potential error.
func (c *Client) Ping(srv net.Addr) error {
pingRequest := &structs.GenericRequest{
QueryOptions: structs.QueryOptions{
AuthToken: c.secretNodeID(),
},
}
var reply struct{}
err := c.connPool.RPC(c.Region(), srv, "Status.Ping", pingRequest, &reply)
return err
Expand Down
38 changes: 38 additions & 0 deletions client/state/db_bolt.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/hashicorp/nomad/client/dynamicplugins"
driverstate "github.com/hashicorp/nomad/client/pluginmanager/drivermanager/state"
"github.com/hashicorp/nomad/client/serviceregistration/checks"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/boltdd"
"github.com/hashicorp/nomad/nomad/structs"
"go.etcd.io/bbolt"
Expand Down Expand Up @@ -51,6 +52,9 @@ dynamicplugins/

nodemeta/
|--> meta -> map[string]*string

node/
|--> node_registration -> *cstructs.NodeRegistration
*/

var (
Expand Down Expand Up @@ -120,6 +124,12 @@ var (

// nodeMetaKey is the key at which dynamic node metadata is stored.
nodeMetaKey = []byte("meta")

// nodeBucket is the bucket name in which data about the node is stored.
nodeBucket = []byte("node")

// nodeRegistrationKey is the key at which node registration data is stored.
nodeRegistrationKey = []byte("node_registration")
)

// taskBucketName returns the bucket name for the given task name.
Expand Down Expand Up @@ -897,6 +907,34 @@ func getNodeMeta(b *boltdd.Bucket) (map[string]*string, error) {
return m, nil
}

// PutNodeRegistration stores reg under nodeRegistrationKey in the node
// bucket, creating the bucket first if it does not already exist. Both steps
// run inside a single update transaction.
func (s *BoltStateDB) PutNodeRegistration(reg *cstructs.NodeRegistration) error {
	write := func(tx *boltdd.Tx) error {
		bucket, err := tx.CreateBucketIfNotExists(nodeBucket)
		if err == nil {
			err = bucket.Put(nodeRegistrationKey, reg)
		}
		return err
	}
	return s.db.Update(write)
}

// GetNodeRegistration reads the value stored under nodeRegistrationKey in the
// node bucket. It returns (nil, nil) when the read reports a not-found error.
// When the node bucket itself is absent nothing is read, and the result is a
// pointer to a zero-value NodeRegistration. Any other error is returned
// together with the (possibly unpopulated) registration pointer.
func (s *BoltStateDB) GetNodeRegistration() (*cstructs.NodeRegistration, error) {
	found := new(cstructs.NodeRegistration)

	read := func(tx *boltdd.Tx) error {
		if bucket := tx.Bucket(nodeBucket); bucket != nil {
			return bucket.Get(nodeRegistrationKey, found)
		}
		// No node bucket yet: leave found at its zero value.
		return nil
	}

	err := s.db.View(read)
	if boltdd.IsErrNotFound(err) {
		return nil, nil
	}
	return found, err
}

// init initializes metadata entries in a newly created state database.
func (s *BoltStateDB) init() error {
return s.db.Update(func(tx *boltdd.Tx) error {
Expand Down
9 changes: 9 additions & 0 deletions client/state/db_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/hashicorp/nomad/client/dynamicplugins"
driverstate "github.com/hashicorp/nomad/client/pluginmanager/drivermanager/state"
"github.com/hashicorp/nomad/client/serviceregistration/checks"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/nomad/structs"
)

Expand Down Expand Up @@ -129,6 +130,14 @@ func (m *ErrDB) GetNodeMeta() (map[string]*string, error) {
return nil, fmt.Errorf("Error!")
}

// PutNodeRegistration ignores reg and always returns a fixed error.
func (m *ErrDB) PutNodeRegistration(reg *cstructs.NodeRegistration) error {
return fmt.Errorf("Error!")
}

// GetNodeRegistration always returns a nil registration and a fixed error.
func (m *ErrDB) GetNodeRegistration() (*cstructs.NodeRegistration, error) {
return nil, fmt.Errorf("Error!")
}

func (m *ErrDB) Close() error {
return fmt.Errorf("Error!")
}
Loading