From 51a8ff3cdf02365eb2aedf77d825e808dd9b3b31 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Thu, 11 Jul 2024 16:00:22 +0100 Subject: [PATCH] [sec_scan][16] add methods to store/retrieve device assertion functions This PR adds methods to store/retrieve functions defined by different teleport services. This PR is part of https://github.com/gravitational/access-graph/issues/637. Signed-off-by: Tiago Silva --- lib/auth/auth.go | 56 ++++++++++++++++++++++++++++++++++++ lib/auth/authclient/clt.go | 7 +++++ lib/services/access_graph.go | 14 +++++++++ lib/services/device.go | 1 + 4 files changed, 78 insertions(+) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 64e03f5e0be1d..83d30556d0443 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -94,6 +94,7 @@ import ( "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/devicetrust/assertserver" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/gcp" "github.com/gravitational/teleport/lib/githubactions" @@ -613,6 +614,8 @@ type Services struct { services.AccessMonitoringRules services.CrownJewels services.BotInstance + services.AccessGraphSecretsGetter + services.DevicesGetter } // SecReportsClient returns the security reports client. @@ -685,6 +688,16 @@ func (r *Services) DatabaseObjectsClient() services.DatabaseObjects { return r } +// GetAccessGraphSecretsGetter returns the AccessGraph secrets service. +func (r *Services) GetAccessGraphSecretsGetter() services.AccessGraphSecretsGetter { + return r.AccessGraphSecretsGetter +} + +// GetDevicesGetter returns the trusted devices service. +func (r *Services) GetDevicesGetter() services.DevicesGetter { + return r.DevicesGetter +} + var ( generateRequestsCount = prometheus.NewCounter( prometheus.CounterOpts{ @@ -829,6 +842,10 @@ type LoginHook func(context.Context, types.User) error // the user has no suitable trusted device. type CreateDeviceWebTokenFunc func(context.Context, *devicepb.DeviceWebToken) (*devicepb.DeviceWebToken, error) +// CreateDeviceAssertionFunc creates a new device assertion ceremony to authenticate +// a trusted device. +type CreateDeviceAssertionFunc func() (assertserver.Ceremony, error) + // ReadOnlyCache is a type alias used to assist with embedding [readonly.Cache] in places // where it would have a naming conflict with other types named Cache. type ReadOnlyCache = readonly.Cache @@ -1005,6 +1022,15 @@ type Server struct { // Is nil on OSS clusters. createDeviceWebTokenFunc CreateDeviceWebTokenFunc + // deviceAssertionServer holds the server-side implementation of device assertions. + // + // It is used to authenticate devices previously enrolled in the cluster. The goal + // is to provide an API for devices to authenticate with the cluster without the need + // for valid user credentials, e.g. when running `tsh scan keys`. + // + // The value is nil on OSS clusters. + deviceAssertionServer CreateDeviceAssertionFunc + // bcryptCostOverride overrides the bcrypt cost for operations executed // directly by [Server]. // Used for testing. @@ -1153,6 +1179,26 @@ func (a *Server) SetHeadlessAuthenticationWatcher(headlessAuthenticationWatcher a.headlessAuthenticationWatcher = headlessAuthenticationWatcher } +// SetDeviceAssertionServer sets the device assertion implementation. +func (a *Server) SetDeviceAssertionServer(f CreateDeviceAssertionFunc) { + a.lock.Lock() + a.deviceAssertionServer = f + a.lock.Unlock() +} + +// GetDeviceAssertionServer returns the device assertion implementation. +// On OSS clusters, this will return a non nil function that returns an error. +func (a *Server) GetDeviceAssertionServer() CreateDeviceAssertionFunc { + a.lock.RLock() + defer a.lock.RUnlock() + if a.deviceAssertionServer == nil { + return func() (assertserver.Ceremony, error) { + return nil, trace.NotImplemented("device assertions are not supported on OSS clusters") + } + } + return a.deviceAssertionServer +} + func (a *Server) SetCreateDeviceWebTokenFunc(f CreateDeviceWebTokenFunc) { a.lock.Lock() a.createDeviceWebTokenFunc = f @@ -1770,6 +1816,16 @@ func (a *Server) SetSCIMService(scim services.SCIM) { a.Services.SCIM = scim } +// SetAccessGraphSecretService sets the server's access graph secret service +func (a *Server) SetAccessGraphSecretService(s services.AccessGraphSecretsGetter) { + a.Services.AccessGraphSecretsGetter = s +} + +// SetDevicesGetter sets the server's device service +func (a *Server) SetDevicesGetter(s services.DevicesGetter) { + a.Services.DevicesGetter = s +} + // SetAuditLog sets the server's audit log func (a *Server) SetAuditLog(auditLog events.AuditLogSessionStreamer) { a.Services.AuditLogSessionStreamer = auditLog diff --git a/lib/auth/authclient/clt.go b/lib/auth/authclient/clt.go index 763500782387b..62ae2dc104b08 100644 --- a/lib/auth/authclient/clt.go +++ b/lib/auth/authclient/clt.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/secreport" apidefaults "github.com/gravitational/teleport/api/defaults" + accessgraphsecretsv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" clusterconfigpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/clusterconfig/v1" dbobjectimportrulev1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/dbobjectimportrule/v1" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" @@ -652,6 +653,10 @@ func (c *Client) AccessGraphClient() accessgraphv1.AccessGraphServiceClient { return accessgraphv1.NewAccessGraphServiceClient(c.APIClient.GetConnection()) } +func (c *Client) AccessGraphSecretsScannerClient() accessgraphsecretsv1pb.SecretsScannerServiceClient { + return accessgraphsecretsv1pb.NewSecretsScannerServiceClient(c.APIClient.GetConnection()) +} + func (c *Client) IntegrationAWSOIDCClient() integrationv1.AWSOIDCServiceClient { return integrationv1.NewAWSOIDCServiceClient(c.APIClient.GetConnection()) } @@ -1454,6 +1459,8 @@ type ClientI interface { // AccessGraphClient returns a client to the Access Graph gRPC service. AccessGraphClient() accessgraphv1.AccessGraphServiceClient + AccessGraphSecretsScannerClient() accessgraphsecretsv1pb.SecretsScannerServiceClient + // IntegrationAWSOIDCClient returns a client to the Integration AWS OIDC gRPC service. IntegrationAWSOIDCClient() integrationv1.AWSOIDCServiceClient diff --git a/lib/services/access_graph.go b/lib/services/access_graph.go index 6ff41221b5a37..aef36297c8128 100644 --- a/lib/services/access_graph.go +++ b/lib/services/access_graph.go @@ -19,12 +19,26 @@ package services import ( + "context" + "github.com/gravitational/trace" accessgraphsecretspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" "github.com/gravitational/teleport/api/types/accessgraph" ) +// AccessGraphSecretsGetter is an interface for getting access graph secrets. +type AccessGraphSecretsGetter interface { + // ListAllAuthorizedKeys lists all authorized keys stored in the backend. + ListAllAuthorizedKeys(ctx context.Context, pageSize int, pageToken string) ([]*accessgraphsecretspb.AuthorizedKey, string, error) + // ListAuthorizedKeysForServer lists all authorized keys for a given hostID. + ListAuthorizedKeysForServer(ctx context.Context, hostID string, pageSize int, pageToken string) ([]*accessgraphsecretspb.AuthorizedKey, string, error) + // ListAllPrivateKeys lists all private keys stored in the backend. + ListAllPrivateKeys(ctx context.Context, pageSize int, pageToken string) ([]*accessgraphsecretspb.PrivateKey, string, error) + // ListPrivateKeysForDevice lists all private keys for a given deviceID. + ListPrivateKeysForDevice(ctx context.Context, deviceID string, pageSize int, pageToken string) ([]*accessgraphsecretspb.PrivateKey, string, error) +} + // MarshalAccessGraphAuthorizedKey marshals a [accessgraphsecretspb.AuthorizedKey] resource to JSON. func MarshalAccessGraphAuthorizedKey(in *accessgraphsecretspb.AuthorizedKey, opts ...MarshalOption) ([]byte, error) { if err := accessgraph.ValidateAuthorizedKey(in); err != nil { diff --git a/lib/services/device.go b/lib/services/device.go index 319fe46bbc788..1456d39324a64 100644 --- a/lib/services/device.go +++ b/lib/services/device.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/teleport/lib/utils" ) +// DevicesGetter allows to list all registered devices from storage. type DevicesGetter interface { ListDevices(ctx context.Context, pageSize int, pageToken string, view devicepb.DeviceView) (devices []*devicepb.Device, nextPageToken string, err error) }