Skip to content

Commit

Permalink
[sec_scan][17] add AssertDevice to FakeDeviceService (#44159)
Browse files Browse the repository at this point in the history
* [sec_scan][17] add `AssertDevice` to `FakeDeviceService`

This PR introduces a `AssertDevice` logic into `FakeDeviceService` to authenticate devices during unit tests using device trust credentials.

This PR is part of gravitational/access-graph#637.

Signed-off-by: Tiago Silva <[email protected]>

* simplify assert tests

* Update lib/devicetrust/assert/assert_test.go

Co-authored-by: Alan Parra <[email protected]>

---------

Signed-off-by: Tiago Silva <[email protected]>
Co-authored-by: Alan Parra <[email protected]>
  • Loading branch information
tigrato and codingllama authored Jul 15, 2024
1 parent c669027 commit e05dd7f
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 78 deletions.
155 changes: 79 additions & 76 deletions lib/devicetrust/assert/assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,22 @@ import (
"context"
"testing"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1"
"github.com/gravitational/teleport/lib/devicetrust/assert"
"github.com/gravitational/teleport/lib/devicetrust/authn"
"github.com/gravitational/teleport/lib/devicetrust/enroll"
"github.com/gravitational/teleport/lib/devicetrust/testenv"
)

func TestCeremony(t *testing.T) {
t.Parallel()

env := testenv.MustNew(
testenv.WithAutoCreateDevice(true),
)
deviceID := uuid.NewString()

devicesClient := env.DevicesClient
ctx := context.Background()

macDev, err := testenv.NewFakeMacOSDevice()
Expand All @@ -62,17 +60,17 @@ func TestCeremony(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()

// Enroll the device before we attempt authentication (auto creates device
// as part of ceremony)
dev := test.dev
enrollC := enroll.Ceremony{
GetDeviceOSType: dev.GetDeviceOSType,
EnrollDeviceInit: dev.EnrollDeviceInit,
SignChallenge: dev.SignChallenge,
SolveTPMEnrollChallenge: dev.SolveTPMEnrollChallenge,
}
_, err := enrollC.Run(ctx, devicesClient, false, testenv.FakeEnrollmentToken)
require.NoError(t, err, "EnrollDevice errored")

// Create an enrolled device.
devpb, pubKey, err := testenv.CreateEnrolledDevice(deviceID, dev)
require.NoError(t, err, "CreateEnrolledDevice errored")

env := testenv.MustNew(
testenv.WithAutoCreateDevice(true),
// Register the enrolled device with the service.
testenv.WithPreEnrolledDevice(devpb, pubKey),
)

assertC, err := assert.NewCeremony(assert.WithNewAuthnCeremonyFunc(func() *authn.Ceremony {
return &authn.Ceremony{
Expand All @@ -87,78 +85,83 @@ func TestCeremony(t *testing.T) {
}))
require.NoError(t, err, "NewCeremony errored")

authnStream, err := devicesClient.AuthenticateDevice(ctx)
require.NoError(t, err, "AuthenticateDevice errored")

// Typically this would be some other, non-DeviceTrustService stream, but
// here this is a good way to test it (as it runs actual fake device
// authn)
if err := assertC.Run(ctx, &assertStreamAdapter{
stream: authnStream,
}); err != nil {
t.Errorf("Run returned err=%q, want nil", err)
}
clientToServer := make(chan *devicepb.AssertDeviceRequest)
serverToClient := make(chan *devicepb.AssertDeviceResponse)

group, ctx := errgroup.WithContext(ctx)

// Run the client side of the ceremony.
group.Go(func() error {
err := assertC.Run(ctx, &assertStreamClientAdapter{
ctx: ctx,
clientToServer: clientToServer,
serverToClient: serverToClient,
})
return trace.Wrap(err, "server AssertDevice errored")
})

serverAssertC, err := env.Service.CreateAssertCeremony()
require.NoError(t, err, "CreateAssertCeremony errored")
// Run the server side of the ceremony.
group.Go(func() error {
_, err := serverAssertC.AssertDevice(ctx, &assertStreamServerAdapter{
ctx: ctx,
clientToServer: clientToServer,
serverToClient: serverToClient,
})
return trace.Wrap(err, "server AssertDevice errored")
})

err = group.Wait()
require.NoError(t, err, "group.Wait errored")
})
}
}

type assertStreamAdapter struct {
stream devicepb.DeviceTrustService_AuthenticateDeviceClient
type assertStreamClientAdapter struct {
ctx context.Context
clientToServer chan *devicepb.AssertDeviceRequest
serverToClient chan *devicepb.AssertDeviceResponse
}

func (s *assertStreamAdapter) Recv() (*devicepb.AssertDeviceResponse, error) {
resp, err := s.stream.Recv()
if err != nil {
return nil, err
func (s *assertStreamClientAdapter) Recv() (*devicepb.AssertDeviceResponse, error) {
select {
case resp := <-s.serverToClient:
return resp, nil
case <-s.ctx.Done():
return nil, trace.Wrap(s.ctx.Err())
}
}

switch resp.Payload.(type) {
case *devicepb.AuthenticateDeviceResponse_Challenge:
return &devicepb.AssertDeviceResponse{
Payload: &devicepb.AssertDeviceResponse_Challenge{
Challenge: resp.GetChallenge(),
},
}, nil
case *devicepb.AuthenticateDeviceResponse_TpmChallenge:
return &devicepb.AssertDeviceResponse{
Payload: &devicepb.AssertDeviceResponse_TpmChallenge{
TpmChallenge: resp.GetTpmChallenge(),
},
}, nil
case *devicepb.AuthenticateDeviceResponse_UserCertificates:
// UserCertificates means success.
return &devicepb.AssertDeviceResponse{
Payload: &devicepb.AssertDeviceResponse_DeviceAsserted{
DeviceAsserted: &devicepb.DeviceAsserted{},
},
}, nil
default:
return nil, trace.BadParameter("unexpected authenticate response payload: %T", resp.Payload)
func (s *assertStreamClientAdapter) Send(req *devicepb.AssertDeviceRequest) error {
select {
case s.clientToServer <- req:
return nil
case <-s.ctx.Done():
return trace.Wrap(s.ctx.Err())
}
}

func (s *assertStreamAdapter) Send(req *devicepb.AssertDeviceRequest) error {
authnReq := &devicepb.AuthenticateDeviceRequest{}
switch req.Payload.(type) {
case *devicepb.AssertDeviceRequest_Init:
init := req.GetInit()
authnReq.Payload = &devicepb.AuthenticateDeviceRequest_Init{
Init: &devicepb.AuthenticateDeviceInit{
CredentialId: init.CredentialId,
DeviceData: init.DeviceData,
},
}
case *devicepb.AssertDeviceRequest_ChallengeResponse:
authnReq.Payload = &devicepb.AuthenticateDeviceRequest_ChallengeResponse{
ChallengeResponse: req.GetChallengeResponse(),
}
case *devicepb.AssertDeviceRequest_TpmChallengeResponse:
authnReq.Payload = &devicepb.AuthenticateDeviceRequest_TpmChallengeResponse{
TpmChallengeResponse: req.GetTpmChallengeResponse(),
}
default:
return trace.BadParameter("unexpected assert request payload: %T", req.Payload)
type assertStreamServerAdapter struct {
ctx context.Context
clientToServer chan *devicepb.AssertDeviceRequest
serverToClient chan *devicepb.AssertDeviceResponse
}

func (s *assertStreamServerAdapter) Recv() (*devicepb.AssertDeviceRequest, error) {
select {
case req := <-s.clientToServer:
return req, nil
case <-s.ctx.Done():
return nil, trace.Wrap(s.ctx.Err())
}
}

return s.stream.Send(authnReq)
func (s *assertStreamServerAdapter) Send(resp *devicepb.AssertDeviceResponse) error {
select {
case s.serverToClient <- resp:
return nil
case <-s.ctx.Done():
return trace.Wrap(s.ctx.Err())
}
}
118 changes: 116 additions & 2 deletions lib/devicetrust/testenv/fake_device_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"

devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1"
"github.com/gravitational/teleport/lib/devicetrust/assertserver"
)

// FakeEnrollmentToken is a "free", never spent enrollment token.
Expand Down Expand Up @@ -483,6 +484,58 @@ func enrollMacOS(stream devicepb.DeviceTrustService_EnrollDeviceServer, initReq
}, ecPubKey, nil
}

// CreateAssertCeremony creates a fake, server-side device assertion ceremony.
func (s *FakeDeviceService) CreateAssertCeremony() (assertserver.Ceremony, error) {
return s, nil
}

// AssertDevice implements a fake, server-side device assertion ceremony.
//
// AssertDevice requires an enrolled device, so the challenge signature
// can be verified.
func (s *FakeDeviceService) AssertDevice(ctx context.Context, stream assertserver.AssertDeviceServerStream) (*devicepb.Device, error) {
// 1. Init.
req, err := stream.Recv()
if err != nil {
return nil, trace.Wrap(err)
}
initReq := req.GetInit()
switch {
case initReq == nil:
return nil, trace.BadParameter("init required")
case initReq.CredentialId == "":
return nil, trace.BadParameter("credential ID required")
}
if err := validateCollectedData(initReq.DeviceData); err != nil {
return nil, trace.Wrap(err)
}

s.mu.Lock()
defer s.mu.Unlock()

dev, err := s.findDeviceByCredential(initReq.DeviceData, initReq.CredentialId)
if err != nil {
return nil, trace.Wrap(err)
}

switch dev.pb.OsType {
case devicepb.OSType_OS_TYPE_MACOS:
err = authenticateDeviceMacOS(dev, assertStreamAdapter{stream: stream})
case devicepb.OSType_OS_TYPE_LINUX, devicepb.OSType_OS_TYPE_WINDOWS:
err = authenticateDeviceTPM(assertStreamAdapter{stream: stream})
default:
err = fmt.Errorf("unrecognized os type %q", dev.pb.OsType)
}
if err != nil {
return nil, trace.Wrap(err)
}

// Success.
return dev.pb, trace.Wrap(stream.Send(&devicepb.AssertDeviceResponse{
Payload: &devicepb.AssertDeviceResponse_DeviceAsserted{},
}))
}

// AuthenticateDevice implements a fake, server-side device authentication
// ceremony.
//
Expand Down Expand Up @@ -592,7 +645,7 @@ func (s *FakeDeviceService) spendDeviceWebToken(webToken *devicepb.DeviceWebToke
return nil, trace.AccessDenied(invalidWebTokenMessage)
}

func authenticateDeviceMacOS(dev *storedDevice, stream devicepb.DeviceTrustService_AuthenticateDeviceServer) error {
func authenticateDeviceMacOS(dev *storedDevice, stream authenticateDeviceStream) error {
// 2. Challenge.
chal, err := newChallenge()
if err != nil {
Expand Down Expand Up @@ -623,7 +676,7 @@ func authenticateDeviceMacOS(dev *storedDevice, stream devicepb.DeviceTrustServi
return trace.Wrap(verifyChallenge(chal, chalResp.Signature, dev.pub))
}

func authenticateDeviceTPM(stream devicepb.DeviceTrustService_AuthenticateDeviceServer) error {
func authenticateDeviceTPM(stream authenticateDeviceStream) error {
// Produce a nonce we can send in the challenge that we expect to see in
// the EventLog field of the challenge response.
nonce, err := randomBytes()
Expand Down Expand Up @@ -718,3 +771,64 @@ func verifyChallenge(chal, sig []byte, pub *ecdsa.PublicKey) error {
}
return nil
}

type authenticateDeviceStream interface {
Recv() (*devicepb.AuthenticateDeviceRequest, error)
Send(*devicepb.AuthenticateDeviceResponse) error
}

// assertStreamAdapter adapts an [assertserver.AssertDeviceServerStream] to an
// [authenticateDeviceStream].
type assertStreamAdapter struct {
stream assertserver.AssertDeviceServerStream
}

func (s assertStreamAdapter) Recv() (*devicepb.AuthenticateDeviceRequest, error) {
req, err := s.stream.Recv()
if err != nil {
return nil, trace.Wrap(err)
}

// Convert AssertDeviceRequest to AuthenticateDeviceRequest.
if req == nil || req.Payload == nil {
return nil, trace.BadParameter("assert request payload required")
}
authnReq := &devicepb.AuthenticateDeviceRequest{}
switch req.Payload.(type) {
case *devicepb.AssertDeviceRequest_ChallengeResponse:
authnReq.Payload = &devicepb.AuthenticateDeviceRequest_ChallengeResponse{
ChallengeResponse: req.GetChallengeResponse(),
}
case *devicepb.AssertDeviceRequest_TpmChallengeResponse:
authnReq.Payload = &devicepb.AuthenticateDeviceRequest_TpmChallengeResponse{
TpmChallengeResponse: req.GetTpmChallengeResponse(),
}
default:
return nil, trace.BadParameter("unexpected assert request payload: %T", req.Payload)
}

return authnReq, nil
}

func (s assertStreamAdapter) Send(authnResp *devicepb.AuthenticateDeviceResponse) error {
if authnResp == nil || authnResp.Payload == nil {
return trace.BadParameter("authenticate response payload required")
}

// Convert AuthenticateDeviceResponse to AssertDeviceResponse.
resp := &devicepb.AssertDeviceResponse{}
switch authnResp.Payload.(type) {
case *devicepb.AuthenticateDeviceResponse_Challenge:
resp.Payload = &devicepb.AssertDeviceResponse_Challenge{
Challenge: authnResp.GetChallenge(),
}
case *devicepb.AuthenticateDeviceResponse_TpmChallenge:
resp.Payload = &devicepb.AssertDeviceResponse_TpmChallenge{
TpmChallenge: authnResp.GetTpmChallenge(),
}
default:
return trace.BadParameter("unexpected authentication response payload: %T", authnResp.Payload)
}

return trace.Wrap(s.stream.Send(resp))
}
Loading

0 comments on commit e05dd7f

Please sign in to comment.