Skip to content

Commit

Permalink
[sec_scan][16] add methods to store/retrieve device assertion functio…
Browse files Browse the repository at this point in the history
…ns (#44081)

This PR adds methods to store/retrieve functions defined by different teleport services.

This PR is part of gravitational/access-graph#637.

Signed-off-by: Tiago Silva <[email protected]>
  • Loading branch information
tigrato committed Jul 30, 2024
1 parent 1c430c5 commit bb53c1c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 0 deletions.
56 changes: 56 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ import (
"github.com/gravitational/teleport/lib/circleci"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/devicetrust/assertserver"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/gcp"
"github.com/gravitational/teleport/lib/githubactions"
Expand Down Expand Up @@ -614,6 +615,8 @@ type Services struct {
services.KubeWaitingContainer
services.AccessMonitoringRules
services.CrownJewels
services.AccessGraphSecretsGetter
services.DevicesGetter
}

// SecReportsClient returns the security reports client.
Expand Down Expand Up @@ -681,6 +684,16 @@ func (r *Services) KubernetesWaitingContainerClient() services.KubeWaitingContai
return r
}

// GetAccessGraphSecretsGetter returns the AccessGraph secrets service.
func (r *Services) GetAccessGraphSecretsGetter() services.AccessGraphSecretsGetter {
return r.AccessGraphSecretsGetter
}

// GetDevicesGetter returns the trusted devices service.
func (r *Services) GetDevicesGetter() services.DevicesGetter {
return r.DevicesGetter
}

var (
generateRequestsCount = prometheus.NewCounter(
prometheus.CounterOpts{
Expand Down Expand Up @@ -816,6 +829,10 @@ var (
// successfully authenticated. An example would be creating objects based on the user.
type LoginHook func(context.Context, types.User) error

// CreateDeviceAssertionFunc creates a new device assertion ceremony to authenticate
// a trusted device.
type CreateDeviceAssertionFunc func() (assertserver.Ceremony, error)

// ReadOnlyCache is a type alias used to assist with embedding [readonly.Cache] in places
// where it would have a naming conflict with other types named Cache.
type ReadOnlyCache = readonly.Cache
Expand Down Expand Up @@ -994,6 +1011,15 @@ type Server struct {
// ulsGenerator is the user login state generator.
ulsGenerator *userloginstate.Generator

// deviceAssertionServer holds the server-side implementation of device assertions.
//
// It is used to authenticate devices previously enrolled in the cluster. The goal
// is to provide an API for devices to authenticate with the cluster without the need
// for valid user credentials, e.g. when running `tsh scan keys`.
//
// The value is nil on OSS clusters.
deviceAssertionServer CreateDeviceAssertionFunc

// bcryptCostOverride overrides the bcrypt cost for operations executed
// directly by [Server].
// Used for testing.
Expand Down Expand Up @@ -1142,6 +1168,26 @@ func (a *Server) SetHeadlessAuthenticationWatcher(headlessAuthenticationWatcher
a.headlessAuthenticationWatcher = headlessAuthenticationWatcher
}

// SetDeviceAssertionServer sets the device assertion implementation.
func (a *Server) SetDeviceAssertionServer(f CreateDeviceAssertionFunc) {
a.lock.Lock()
a.deviceAssertionServer = f
a.lock.Unlock()
}

// GetDeviceAssertionServer returns the device assertion implementation.
// On OSS clusters, this will return a non nil function that returns an error.
func (a *Server) GetDeviceAssertionServer() CreateDeviceAssertionFunc {
a.lock.RLock()
defer a.lock.RUnlock()
if a.deviceAssertionServer == nil {
return func() (assertserver.Ceremony, error) {
return nil, trace.NotImplemented("device assertions are not supported on OSS clusters")
}
}
return a.deviceAssertionServer
}

func (a *Server) bcryptCost() int {
if cost := a.bcryptCostOverride; cost != nil {
return *cost
Expand Down Expand Up @@ -1742,6 +1788,16 @@ func (a *Server) SetSCIMService(scim services.SCIM) {
a.Services.SCIM = scim
}

// SetAccessGraphSecretService sets the server's access graph secret service
func (a *Server) SetAccessGraphSecretService(s services.AccessGraphSecretsGetter) {
a.Services.AccessGraphSecretsGetter = s
}

// SetDevicesGetter sets the server's device service
func (a *Server) SetDevicesGetter(s services.DevicesGetter) {
a.Services.DevicesGetter = s
}

// SetAuditLog sets the server's audit log
func (a *Server) SetAuditLog(auditLog events.AuditLogSessionStreamer) {
a.Services.AuditLogSessionStreamer = auditLog
Expand Down
7 changes: 7 additions & 0 deletions lib/auth/authclient/clt.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/gravitational/teleport/api/client/secreport"
apidefaults "github.com/gravitational/teleport/api/defaults"
assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1"
accessgraphsecretsv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1"
clusterconfigpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/clusterconfig/v1"
dbobjectv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/dbobject/v1"
dbobjectimportrulev1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/dbobjectimportrule/v1"
Expand Down Expand Up @@ -545,6 +546,10 @@ func (c *Client) AccessGraphClient() accessgraphv1.AccessGraphServiceClient {
return accessgraphv1.NewAccessGraphServiceClient(c.APIClient.GetConnection())
}

func (c *Client) AccessGraphSecretsScannerClient() accessgraphsecretsv1pb.SecretsScannerServiceClient {
return accessgraphsecretsv1pb.NewSecretsScannerServiceClient(c.APIClient.GetConnection())
}

func (c *Client) IntegrationAWSOIDCClient() integrationv1.AWSOIDCServiceClient {
return integrationv1.NewAWSOIDCServiceClient(c.APIClient.GetConnection())
}
Expand Down Expand Up @@ -1415,6 +1420,8 @@ type ClientI interface {
// AccessGraphClient returns a client to the Access Graph gRPC service.
AccessGraphClient() accessgraphv1.AccessGraphServiceClient

AccessGraphSecretsScannerClient() accessgraphsecretsv1pb.SecretsScannerServiceClient

// IntegrationAWSOIDCClient returns a client to the Integration AWS OIDC gRPC service.
IntegrationAWSOIDCClient() integrationv1.AWSOIDCServiceClient

Expand Down
14 changes: 14 additions & 0 deletions lib/services/access_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,26 @@
package services

import (
"context"

"github.com/gravitational/trace"

accessgraphsecretspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1"
"github.com/gravitational/teleport/api/types/accessgraph"
)

// AccessGraphSecretsGetter is an interface for getting access graph secrets.
type AccessGraphSecretsGetter interface {
// ListAllAuthorizedKeys lists all authorized keys stored in the backend.
ListAllAuthorizedKeys(ctx context.Context, pageSize int, pageToken string) ([]*accessgraphsecretspb.AuthorizedKey, string, error)
// ListAuthorizedKeysForServer lists all authorized keys for a given hostID.
ListAuthorizedKeysForServer(ctx context.Context, hostID string, pageSize int, pageToken string) ([]*accessgraphsecretspb.AuthorizedKey, string, error)
// ListAllPrivateKeys lists all private keys stored in the backend.
ListAllPrivateKeys(ctx context.Context, pageSize int, pageToken string) ([]*accessgraphsecretspb.PrivateKey, string, error)
// ListPrivateKeysForDevice lists all private keys for a given deviceID.
ListPrivateKeysForDevice(ctx context.Context, deviceID string, pageSize int, pageToken string) ([]*accessgraphsecretspb.PrivateKey, string, error)
}

// MarshalAccessGraphAuthorizedKey marshals a [accessgraphsecretspb.AuthorizedKey] resource to JSON.
func MarshalAccessGraphAuthorizedKey(in *accessgraphsecretspb.AuthorizedKey, opts ...MarshalOption) ([]byte, error) {
if err := accessgraph.ValidateAuthorizedKey(in); err != nil {
Expand Down
1 change: 1 addition & 0 deletions lib/services/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/gravitational/teleport/lib/utils"
)

// DevicesGetter allows to list all registered devices from storage.
type DevicesGetter interface {
ListDevices(ctx context.Context, pageSize int, pageToken string, view devicepb.DeviceView) (devices []*devicepb.Device, nextPageToken string, err error)
}
Expand Down

0 comments on commit bb53c1c

Please sign in to comment.