diff --git a/lib/devicetrust/assert/assert_test.go b/lib/devicetrust/assert/assert_test.go index 273b71b814eb9..b9068b222df2e 100644 --- a/lib/devicetrust/assert/assert_test.go +++ b/lib/devicetrust/assert/assert_test.go @@ -20,24 +20,22 @@ import ( "context" "testing" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" "github.com/gravitational/teleport/lib/devicetrust/assert" "github.com/gravitational/teleport/lib/devicetrust/authn" - "github.com/gravitational/teleport/lib/devicetrust/enroll" "github.com/gravitational/teleport/lib/devicetrust/testenv" ) func TestCeremony(t *testing.T) { t.Parallel() - env := testenv.MustNew( - testenv.WithAutoCreateDevice(true), - ) + deviceID := uuid.NewString() - devicesClient := env.DevicesClient ctx := context.Background() macDev, err := testenv.NewFakeMacOSDevice() @@ -62,17 +60,17 @@ func TestCeremony(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - // Enroll the device before we attempt authentication (auto creates device - // as part of ceremony) dev := test.dev - enrollC := enroll.Ceremony{ - GetDeviceOSType: dev.GetDeviceOSType, - EnrollDeviceInit: dev.EnrollDeviceInit, - SignChallenge: dev.SignChallenge, - SolveTPMEnrollChallenge: dev.SolveTPMEnrollChallenge, - } - _, err := enrollC.Run(ctx, devicesClient, false, testenv.FakeEnrollmentToken) - require.NoError(t, err, "EnrollDevice errored") + + // Create an enrolled device. + devpb, pubKey, err := testenv.CreateEnrolledDevice(deviceID, dev) + require.NoError(t, err, "CreateEnrolledDevice errored") + + env := testenv.MustNew( + testenv.WithAutoCreateDevice(true), + // Register the enrolled device with the service. + testenv.WithPreEnrolledDevice(devpb, pubKey), + ) assertC, err := assert.NewCeremony(assert.WithNewAuthnCeremonyFunc(func() *authn.Ceremony { return &authn.Ceremony{ @@ -87,78 +85,83 @@ func TestCeremony(t *testing.T) { })) require.NoError(t, err, "NewCeremony errored") - authnStream, err := devicesClient.AuthenticateDevice(ctx) - require.NoError(t, err, "AuthenticateDevice errored") - - // Typically this would be some other, non-DeviceTrustService stream, but - // here this is a good way to test it (as it runs actual fake device - // authn) - if err := assertC.Run(ctx, &assertStreamAdapter{ - stream: authnStream, - }); err != nil { - t.Errorf("Run returned err=%q, want nil", err) - } + clientToServer := make(chan *devicepb.AssertDeviceRequest) + serverToClient := make(chan *devicepb.AssertDeviceResponse) + + group, ctx := errgroup.WithContext(ctx) + + // Run the client side of the ceremony. + group.Go(func() error { + err := assertC.Run(ctx, &assertStreamClientAdapter{ + ctx: ctx, + clientToServer: clientToServer, + serverToClient: serverToClient, + }) + return trace.Wrap(err, "server AssertDevice errored") + }) + + serverAssertC, err := env.Service.CreateAssertCeremony() + require.NoError(t, err, "CreateAssertCeremony errored") + // Run the server side of the ceremony. + group.Go(func() error { + _, err := serverAssertC.AssertDevice(ctx, &assertStreamServerAdapter{ + ctx: ctx, + clientToServer: clientToServer, + serverToClient: serverToClient, + }) + return trace.Wrap(err, "server AssertDevice errored") + }) + + err = group.Wait() + require.NoError(t, err, "group.Wait errored") }) } } -type assertStreamAdapter struct { - stream devicepb.DeviceTrustService_AuthenticateDeviceClient +type assertStreamClientAdapter struct { + ctx context.Context + clientToServer chan *devicepb.AssertDeviceRequest + serverToClient chan *devicepb.AssertDeviceResponse } -func (s *assertStreamAdapter) Recv() (*devicepb.AssertDeviceResponse, error) { - resp, err := s.stream.Recv() - if err != nil { - return nil, err +func (s *assertStreamClientAdapter) Recv() (*devicepb.AssertDeviceResponse, error) { + select { + case resp := <-s.serverToClient: + return resp, nil + case <-s.ctx.Done(): + return nil, trace.Wrap(s.ctx.Err()) } +} - switch resp.Payload.(type) { - case *devicepb.AuthenticateDeviceResponse_Challenge: - return &devicepb.AssertDeviceResponse{ - Payload: &devicepb.AssertDeviceResponse_Challenge{ - Challenge: resp.GetChallenge(), - }, - }, nil - case *devicepb.AuthenticateDeviceResponse_TpmChallenge: - return &devicepb.AssertDeviceResponse{ - Payload: &devicepb.AssertDeviceResponse_TpmChallenge{ - TpmChallenge: resp.GetTpmChallenge(), - }, - }, nil - case *devicepb.AuthenticateDeviceResponse_UserCertificates: - // UserCertificates means success. - return &devicepb.AssertDeviceResponse{ - Payload: &devicepb.AssertDeviceResponse_DeviceAsserted{ - DeviceAsserted: &devicepb.DeviceAsserted{}, - }, - }, nil - default: - return nil, trace.BadParameter("unexpected authenticate response payload: %T", resp.Payload) +func (s *assertStreamClientAdapter) Send(req *devicepb.AssertDeviceRequest) error { + select { + case s.clientToServer <- req: + return nil + case <-s.ctx.Done(): + return trace.Wrap(s.ctx.Err()) } } -func (s *assertStreamAdapter) Send(req *devicepb.AssertDeviceRequest) error { - authnReq := &devicepb.AuthenticateDeviceRequest{} - switch req.Payload.(type) { - case *devicepb.AssertDeviceRequest_Init: - init := req.GetInit() - authnReq.Payload = &devicepb.AuthenticateDeviceRequest_Init{ - Init: &devicepb.AuthenticateDeviceInit{ - CredentialId: init.CredentialId, - DeviceData: init.DeviceData, - }, - } - case *devicepb.AssertDeviceRequest_ChallengeResponse: - authnReq.Payload = &devicepb.AuthenticateDeviceRequest_ChallengeResponse{ - ChallengeResponse: req.GetChallengeResponse(), - } - case *devicepb.AssertDeviceRequest_TpmChallengeResponse: - authnReq.Payload = &devicepb.AuthenticateDeviceRequest_TpmChallengeResponse{ - TpmChallengeResponse: req.GetTpmChallengeResponse(), - } - default: - return trace.BadParameter("unexpected assert request payload: %T", req.Payload) +type assertStreamServerAdapter struct { + ctx context.Context + clientToServer chan *devicepb.AssertDeviceRequest + serverToClient chan *devicepb.AssertDeviceResponse +} + +func (s *assertStreamServerAdapter) Recv() (*devicepb.AssertDeviceRequest, error) { + select { + case req := <-s.clientToServer: + return req, nil + case <-s.ctx.Done(): + return nil, trace.Wrap(s.ctx.Err()) } +} - return s.stream.Send(authnReq) +func (s *assertStreamServerAdapter) Send(resp *devicepb.AssertDeviceResponse) error { + select { + case s.serverToClient <- resp: + return nil + case <-s.ctx.Done(): + return trace.Wrap(s.ctx.Err()) + } } diff --git a/lib/devicetrust/testenv/fake_device_service.go b/lib/devicetrust/testenv/fake_device_service.go index f9ee9bc63e14a..346830b77178c 100644 --- a/lib/devicetrust/testenv/fake_device_service.go +++ b/lib/devicetrust/testenv/fake_device_service.go @@ -35,6 +35,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" + "github.com/gravitational/teleport/lib/devicetrust/assertserver" ) // FakeEnrollmentToken is a "free", never spent enrollment token. @@ -483,6 +484,58 @@ func enrollMacOS(stream devicepb.DeviceTrustService_EnrollDeviceServer, initReq }, ecPubKey, nil } +// CreateAssertCeremony creates a fake, server-side device assertion ceremony. +func (s *FakeDeviceService) CreateAssertCeremony() (assertserver.Ceremony, error) { + return s, nil +} + +// AssertDevice implements a fake, server-side device assertion ceremony. +// +// AssertDevice requires an enrolled device, so the challenge signature +// can be verified. +func (s *FakeDeviceService) AssertDevice(ctx context.Context, stream assertserver.AssertDeviceServerStream) (*devicepb.Device, error) { + // 1. Init. + req, err := stream.Recv() + if err != nil { + return nil, trace.Wrap(err) + } + initReq := req.GetInit() + switch { + case initReq == nil: + return nil, trace.BadParameter("init required") + case initReq.CredentialId == "": + return nil, trace.BadParameter("credential ID required") + } + if err := validateCollectedData(initReq.DeviceData); err != nil { + return nil, trace.Wrap(err) + } + + s.mu.Lock() + defer s.mu.Unlock() + + dev, err := s.findDeviceByCredential(initReq.DeviceData, initReq.CredentialId) + if err != nil { + return nil, trace.Wrap(err) + } + + switch dev.pb.OsType { + case devicepb.OSType_OS_TYPE_MACOS: + err = authenticateDeviceMacOS(dev, assertStreamAdapter{stream: stream}) + case devicepb.OSType_OS_TYPE_LINUX, devicepb.OSType_OS_TYPE_WINDOWS: + err = authenticateDeviceTPM(assertStreamAdapter{stream: stream}) + default: + err = fmt.Errorf("unrecognized os type %q", dev.pb.OsType) + } + if err != nil { + return nil, trace.Wrap(err) + } + + // Success. + return dev.pb, trace.Wrap(stream.Send(&devicepb.AssertDeviceResponse{ + Payload: &devicepb.AssertDeviceResponse_DeviceAsserted{}, + })) +} + // AuthenticateDevice implements a fake, server-side device authentication // ceremony. // @@ -592,7 +645,7 @@ func (s *FakeDeviceService) spendDeviceWebToken(webToken *devicepb.DeviceWebToke return nil, trace.AccessDenied(invalidWebTokenMessage) } -func authenticateDeviceMacOS(dev *storedDevice, stream devicepb.DeviceTrustService_AuthenticateDeviceServer) error { +func authenticateDeviceMacOS(dev *storedDevice, stream authenticateDeviceStream) error { // 2. Challenge. chal, err := newChallenge() if err != nil { @@ -623,7 +676,7 @@ func authenticateDeviceMacOS(dev *storedDevice, stream devicepb.DeviceTrustServi return trace.Wrap(verifyChallenge(chal, chalResp.Signature, dev.pub)) } -func authenticateDeviceTPM(stream devicepb.DeviceTrustService_AuthenticateDeviceServer) error { +func authenticateDeviceTPM(stream authenticateDeviceStream) error { // Produce a nonce we can send in the challenge that we expect to see in // the EventLog field of the challenge response. nonce, err := randomBytes() @@ -718,3 +771,64 @@ func verifyChallenge(chal, sig []byte, pub *ecdsa.PublicKey) error { } return nil } + +type authenticateDeviceStream interface { + Recv() (*devicepb.AuthenticateDeviceRequest, error) + Send(*devicepb.AuthenticateDeviceResponse) error +} + +// assertStreamAdapter adapts an [assertserver.AssertDeviceServerStream] to an +// [authenticateDeviceStream]. +type assertStreamAdapter struct { + stream assertserver.AssertDeviceServerStream +} + +func (s assertStreamAdapter) Recv() (*devicepb.AuthenticateDeviceRequest, error) { + req, err := s.stream.Recv() + if err != nil { + return nil, trace.Wrap(err) + } + + // Convert AssertDeviceRequest to AuthenticateDeviceRequest. + if req == nil || req.Payload == nil { + return nil, trace.BadParameter("assert request payload required") + } + authnReq := &devicepb.AuthenticateDeviceRequest{} + switch req.Payload.(type) { + case *devicepb.AssertDeviceRequest_ChallengeResponse: + authnReq.Payload = &devicepb.AuthenticateDeviceRequest_ChallengeResponse{ + ChallengeResponse: req.GetChallengeResponse(), + } + case *devicepb.AssertDeviceRequest_TpmChallengeResponse: + authnReq.Payload = &devicepb.AuthenticateDeviceRequest_TpmChallengeResponse{ + TpmChallengeResponse: req.GetTpmChallengeResponse(), + } + default: + return nil, trace.BadParameter("unexpected assert request payload: %T", req.Payload) + } + + return authnReq, nil +} + +func (s assertStreamAdapter) Send(authnResp *devicepb.AuthenticateDeviceResponse) error { + if authnResp == nil || authnResp.Payload == nil { + return trace.BadParameter("authenticate response payload required") + } + + // Convert AuthenticateDeviceResponse to AssertDeviceResponse. + resp := &devicepb.AssertDeviceResponse{} + switch authnResp.Payload.(type) { + case *devicepb.AuthenticateDeviceResponse_Challenge: + resp.Payload = &devicepb.AssertDeviceResponse_Challenge{ + Challenge: authnResp.GetChallenge(), + } + case *devicepb.AuthenticateDeviceResponse_TpmChallenge: + resp.Payload = &devicepb.AssertDeviceResponse_TpmChallenge{ + TpmChallenge: authnResp.GetTpmChallenge(), + } + default: + return trace.BadParameter("unexpected authentication response payload: %T", authnResp.Payload) + } + + return trace.Wrap(s.stream.Send(resp)) +} diff --git a/lib/devicetrust/testenv/testenv.go b/lib/devicetrust/testenv/testenv.go index 7dae3c97a07cb..a0226ee030b0e 100644 --- a/lib/devicetrust/testenv/testenv.go +++ b/lib/devicetrust/testenv/testenv.go @@ -20,6 +20,8 @@ package testenv import ( "context" + "crypto/ecdsa" + "crypto/x509" "net" "time" @@ -27,6 +29,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" + "google.golang.org/protobuf/types/known/timestamppb" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" "github.com/gravitational/teleport/api/utils/grpc/interceptors" @@ -45,6 +48,24 @@ func WithAutoCreateDevice(b bool) Opt { } } +// WithPreEnrolledDevice registers a device with the service without having to enroll it. +// This is useful for testing device authentication flows. +// [pub] is the public key of the macOS device and is used to verify the device. TPM devices +// do not require a public key and should pass nil. +func WithPreEnrolledDevice(dev *devicepb.Device, pub *ecdsa.PublicKey) Opt { + return func(e *E) { + e.Service.mu.Lock() + defer e.Service.mu.Unlock() + e.Service.devices = append(e.Service.devices, + storedDevice{ + pb: dev, + enrollToken: FakeEnrollmentToken, + pub: pub, + }, + ) + } +} + // E is an integrated test environment for device trust. type E struct { DevicesClient devicepb.DeviceTrustServiceClient @@ -150,3 +171,39 @@ type FakeDevice interface { SolveTPMAuthnDeviceChallenge(challenge *devicepb.TPMAuthenticateDeviceChallenge) (*devicepb.TPMAuthenticateDeviceChallengeResponse, error) GetDeviceCredential() *devicepb.DeviceCredential } + +// CreateEnrolledDevice converts a FakeDevice into a [*devicepb.Device] whose EnrollStatus is +// DEVICE_ENROLL_STATUS_ENROLLED and Id set to deviceID. It also returns the public key of the +// device if the device is a macOS device, otherwise it returns nil. +func CreateEnrolledDevice(deviceID string, d FakeDevice) (*devicepb.Device, *ecdsa.PublicKey, error) { + now := timestamppb.Now() + initReq, err := d.EnrollDeviceInit() + if err != nil { + return nil, nil, trace.Wrap(err) + } + + var pub *ecdsa.PublicKey + if d.GetDeviceOSType() == devicepb.OSType_OS_TYPE_MACOS { + pubKey, err := x509.ParsePKIXPublicKey(initReq.Macos.PublicKeyDer) + if err != nil { + return nil, nil, trace.Wrap(err) + } + var ok bool + pub, ok = pubKey.(*ecdsa.PublicKey) + if !ok { + return nil, nil, trace.BadParameter("expected ECDSA public key, got %T", pubKey) + } + } + + dev := &devicepb.Device{ + ApiVersion: "v1", + Id: deviceID, + OsType: d.GetDeviceOSType(), + AssetTag: initReq.DeviceData.SerialNumber, + CreateTime: now, + UpdateTime: now, + Credential: d.GetDeviceCredential(), + EnrollStatus: devicepb.DeviceEnrollStatus_DEVICE_ENROLL_STATUS_ENROLLED, + } + return dev, pub, nil +}