diff --git a/api/types/accessgraph/authorized_key.go b/api/types/accessgraph/authorized_key.go new file mode 100644 index 0000000000000..5532f39bc2775 --- /dev/null +++ b/api/types/accessgraph/authorized_key.go @@ -0,0 +1,90 @@ +/* +Copyright 2024 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package accessgraph + +import ( + "time" + + "github.com/gravitational/trace" + "google.golang.org/protobuf/types/known/timestamppb" + + accessgraphv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + "github.com/gravitational/teleport/api/types" +) + +const ( + authorizedKeyDefaultKeyTTL = 8 * time.Hour +) + +// NewAuthorizedKey creates a new SSH authorized key resource. +func NewAuthorizedKey(spec *accessgraphv1pb.AuthorizedKeySpec) (*accessgraphv1pb.AuthorizedKey, error) { + name := authKeyHashNameKey(spec) + authKey := &accessgraphv1pb.AuthorizedKey{ + Kind: types.KindAccessGraphSecretAuthorizedKey, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: name, + Expires: timestamppb.New( + time.Now().Add(authorizedKeyDefaultKeyTTL), + ), + }, + Spec: spec, + } + if err := ValidateAuthorizedKey(authKey); err != nil { + return nil, trace.Wrap(err) + } + + return authKey, nil +} + +// ValidateAuthorizedKey checks that required parameters are set +// for the specified AuthorizedKey +func ValidateAuthorizedKey(k *accessgraphv1pb.AuthorizedKey) error { + if k == nil { + return trace.BadParameter("AuthorizedKey is nil") + } + if k.Metadata == nil { + return trace.BadParameter("Metadata is nil") + } + if k.Spec == nil { + return trace.BadParameter("Spec is nil") + } + + if k.Spec.HostId == "" { + return trace.BadParameter("HostId is unset") + } + if k.Spec.HostUser == "" { + return trace.BadParameter("HostUser is unset") + } + if k.Spec.KeyFingerprint == "" { + return trace.BadParameter("KeyFingerprint is unset") + } + + if k.Metadata.Name == "" { + return trace.BadParameter("Name is unset") + } + if k.Metadata.Name != authKeyHashNameKey(k.Spec) { + return trace.BadParameter("Name must be derived from the key fields") + } + + return nil +} + +func authKeyHashNameKey(k *accessgraphv1pb.AuthorizedKeySpec) string { + return hashComp(k.HostId, k.HostUser, k.KeyFingerprint) +} diff --git a/api/types/accessgraph/authorized_key_test.go b/api/types/accessgraph/authorized_key_test.go new file mode 100644 index 0000000000000..d28d72ebd3e7e --- /dev/null +++ b/api/types/accessgraph/authorized_key_test.go @@ -0,0 +1,90 @@ +/* +Copyright 2024 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package accessgraph + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + accessgraphv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" +) + +func TestAuthorizedKey(t *testing.T) { + tests := []struct { + name string + spec *accessgraphv1pb.AuthorizedKeySpec + errValidation require.ErrorAssertionFunc + }{ + { + name: "valid", + spec: &accessgraphv1pb.AuthorizedKeySpec{ + HostId: uuid.New().String(), + KeyFingerprint: "fingerprint", + HostUser: "user", + }, + errValidation: require.NoError, + }, + { + name: "missing fingerprint", + spec: &accessgraphv1pb.AuthorizedKeySpec{ + HostId: uuid.New().String(), + KeyFingerprint: "", + HostUser: "user", + }, + errValidation: func(t require.TestingT, err error, i ...any) { + require.ErrorContains(t, err, "KeyFingerprint is unset") + }, + }, + { + name: "missing user", + spec: &accessgraphv1pb.AuthorizedKeySpec{ + HostId: uuid.New().String(), + KeyFingerprint: "fingerprint", + HostUser: "", + }, + errValidation: func(t require.TestingT, err error, i ...any) { + require.ErrorContains(t, err, "HostUser is unset") + }, + }, + { + name: "missing HostID", + spec: &accessgraphv1pb.AuthorizedKeySpec{ + KeyFingerprint: "fingerprint", + HostUser: "user", + }, + errValidation: func(t require.TestingT, err error, i ...any) { + require.ErrorContains(t, err, "HostId is unset") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + privKey, err := NewAuthorizedKey(tt.spec) + tt.errValidation(t, err) + if err != nil { + return + } + require.NotEmpty(t, privKey.Metadata.Name) + require.Empty(t, cmp.Diff(tt.spec, privKey.Spec, protocmp.Transform())) + + }) + } +} diff --git a/api/types/accessgraph/private_key.go b/api/types/accessgraph/private_key.go new file mode 100644 index 0000000000000..57e0874ed040a --- /dev/null +++ b/api/types/accessgraph/private_key.go @@ -0,0 +1,108 @@ +/* +Copyright 2024 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package accessgraph + +import ( + "crypto/sha256" + "encoding/hex" + + "github.com/gravitational/trace" + + accessgraphv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + "github.com/gravitational/teleport/api/types" +) + +// NewPrivateKey creates a new SSH Private key resource with a generated name based on the spec. +func NewPrivateKey(spec *accessgraphv1pb.PrivateKeySpec) (*accessgraphv1pb.PrivateKey, error) { + name := privKeyHashNameKey(spec) + v, err := NewPrivateKeyWithName(name, spec) + + return v, trace.Wrap(err) +} + +// NewPrivateKeyWithName creates a new SSH Private key resource. +func NewPrivateKeyWithName(name string, spec *accessgraphv1pb.PrivateKeySpec) (*accessgraphv1pb.PrivateKey, error) { + privKey := &accessgraphv1pb.PrivateKey{ + Kind: types.KindAccessGraphSecretPrivateKey, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: name, + }, + Spec: spec, + } + if err := ValidatePrivateKey(privKey); err != nil { + return nil, trace.Wrap(err) + } + + return privKey, nil +} + +// ValidatePrivateKey checks that required parameters are set +// for the specified PrivateKey +func ValidatePrivateKey(k *accessgraphv1pb.PrivateKey) error { + if k == nil { + return trace.BadParameter("PrivateKey is nil") + } + if k.Metadata == nil { + return trace.BadParameter("Metadata is nil") + } + if k.Spec == nil { + return trace.BadParameter("Spec is nil") + } + + if k.Kind != types.KindAccessGraphSecretPrivateKey { + return trace.BadParameter("Kind is invalid") + } + + if k.Version != types.V1 { + return trace.BadParameter("Version is invalid") + } + + switch k.Spec.PublicKeyMode { + case accessgraphv1pb.PublicKeyMode_PUBLIC_KEY_MODE_PROTECTED, + accessgraphv1pb.PublicKeyMode_PUBLIC_KEY_MODE_PUB_FILE, + accessgraphv1pb.PublicKeyMode_PUBLIC_KEY_MODE_DERIVED: + default: + return trace.BadParameter("PublicKeyMode is invalid") + } + + if k.Spec.DeviceId == "" { + return trace.BadParameter("DeviceId is unset") + } + if k.Spec.PublicKeyFingerprint == "" && k.Spec.PublicKeyMode != accessgraphv1pb.PublicKeyMode_PUBLIC_KEY_MODE_PROTECTED { + return trace.BadParameter("PublicKeyFingerprint is unset") + } + + if k.Metadata.Name == "" { + return trace.BadParameter("Name is unset") + } + + return nil +} + +func privKeyHashNameKey(k *accessgraphv1pb.PrivateKeySpec) string { + return hashComp(k.DeviceId, k.PublicKeyFingerprint) +} + +func hashComp(values ...string) string { + h := sha256.New() + for _, value := range values { + h.Write([]byte(value)) + } + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/api/types/accessgraph/private_key_test.go b/api/types/accessgraph/private_key_test.go new file mode 100644 index 0000000000000..c1c38bd21c49f --- /dev/null +++ b/api/types/accessgraph/private_key_test.go @@ -0,0 +1,119 @@ +/* +Copyright 2024 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package accessgraph + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + accessgraphv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" +) + +func TestPrivateKey(t *testing.T) { + tests := []struct { + name string + spec *accessgraphv1pb.PrivateKeySpec + errValidation require.ErrorAssertionFunc + }{ + { + name: "valid derived", + spec: &accessgraphv1pb.PrivateKeySpec{ + DeviceId: uuid.New().String(), + PublicKeyFingerprint: "fingerprint", + PublicKeyMode: accessgraphv1pb.PublicKeyMode_PUBLIC_KEY_MODE_DERIVED, + }, + errValidation: require.NoError, + }, + { + name: "valid file", + spec: &accessgraphv1pb.PrivateKeySpec{ + DeviceId: uuid.New().String(), + PublicKeyFingerprint: "fingerprint", + PublicKeyMode: accessgraphv1pb.PublicKeyMode_PUBLIC_KEY_MODE_PUB_FILE, + }, + errValidation: require.NoError, + }, + { + name: "missing fingerprint derived", + spec: &accessgraphv1pb.PrivateKeySpec{ + DeviceId: uuid.New().String(), + PublicKeyFingerprint: "", + PublicKeyMode: accessgraphv1pb.PublicKeyMode_PUBLIC_KEY_MODE_DERIVED, + }, + errValidation: func(t require.TestingT, err error, i ...any) { + require.ErrorContains(t, err, "PublicKeyFingerprint is unset") + }, + }, + { + name: "missing fingerprint file", + spec: &accessgraphv1pb.PrivateKeySpec{ + DeviceId: uuid.New().String(), + PublicKeyFingerprint: "", + PublicKeyMode: accessgraphv1pb.PublicKeyMode_PUBLIC_KEY_MODE_PUB_FILE, + }, + errValidation: func(t require.TestingT, err error, i ...any) { + require.ErrorContains(t, err, "PublicKeyFingerprint is unset") + }, + }, + { + name: "valid protected", + spec: &accessgraphv1pb.PrivateKeySpec{ + DeviceId: uuid.New().String(), + PublicKeyFingerprint: "", /* empty fingerprint */ + PublicKeyMode: accessgraphv1pb.PublicKeyMode_PUBLIC_KEY_MODE_PROTECTED, + }, + errValidation: require.NoError, + }, + { + name: "invalid public key ode", + spec: &accessgraphv1pb.PrivateKeySpec{ + DeviceId: uuid.New().String(), + PublicKeyFingerprint: "fingerprint", + PublicKeyMode: 500, + }, + errValidation: func(t require.TestingT, err error, i ...any) { + require.ErrorContains(t, err, "PublicKeyMode is invalid") + }, + }, + { + name: "missing DeviceId", + spec: &accessgraphv1pb.PrivateKeySpec{ + PublicKeyFingerprint: "fingerprint", + PublicKeyMode: accessgraphv1pb.PublicKeyMode_PUBLIC_KEY_MODE_PROTECTED, + }, + errValidation: func(t require.TestingT, err error, i ...any) { + require.ErrorContains(t, err, "DeviceId is unset") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + privKey, err := NewPrivateKey(tt.spec) + tt.errValidation(t, err) + if err != nil { + return + } + require.NotEmpty(t, privKey.Metadata.Name) + require.Empty(t, cmp.Diff(tt.spec, privKey.Spec, protocmp.Transform())) + + }) + } +} diff --git a/api/types/constants.go b/api/types/constants.go index d227c02beea86..a9edc19895139 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -521,6 +521,14 @@ const ( // KindUserNotificationState is a resource which tracks whether a user has clicked on or dismissed a notification. KindUserNotificationState = "user_notification_state" + // KindAccessGraphSecretAuthorizedKey is a authorized key entry found in + // a Teleport SSH node type. + KindAccessGraphSecretAuthorizedKey = "access_graph_authorized_key" + + // KindAccessGraphSecretPrivateKey is a private key entry found in + // a managed device. + KindAccessGraphSecretPrivateKey = "access_graph_private_key" + // KindVnetConfig is a resource which holds cluster-wide configuration for VNet. KindVnetConfig = "vnet_config" diff --git a/lib/services/access_graph.go b/lib/services/access_graph.go new file mode 100644 index 0000000000000..6ff41221b5a37 --- /dev/null +++ b/lib/services/access_graph.go @@ -0,0 +1,68 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package services + +import ( + "github.com/gravitational/trace" + + accessgraphsecretspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" + "github.com/gravitational/teleport/api/types/accessgraph" +) + +// MarshalAccessGraphAuthorizedKey marshals a [accessgraphsecretspb.AuthorizedKey] resource to JSON. +func MarshalAccessGraphAuthorizedKey(in *accessgraphsecretspb.AuthorizedKey, opts ...MarshalOption) ([]byte, error) { + if err := accessgraph.ValidateAuthorizedKey(in); err != nil { + return nil, trace.Wrap(err) + } + + return MarshalProtoResource(in, opts...) +} + +// UnmarshalAccessGraphAuthorizedKey unmarshals a [accessgraphsecretspb.AuthorizedKey] resource from JSON. +func UnmarshalAccessGraphAuthorizedKey(data []byte, opts ...MarshalOption) (*accessgraphsecretspb.AuthorizedKey, error) { + out, err := UnmarshalProtoResource[*accessgraphsecretspb.AuthorizedKey](data, opts...) + if err != nil { + return nil, trace.Wrap(err) + } + if err := accessgraph.ValidateAuthorizedKey(out); err != nil { + return nil, trace.Wrap(err) + } + return out, nil +} + +// MarshalAccessGraphPrivateKey marshals a [accessgraphsecretspb.PrivateKey] resource to JSON. +func MarshalAccessGraphPrivateKey(in *accessgraphsecretspb.PrivateKey, opts ...MarshalOption) ([]byte, error) { + if err := accessgraph.ValidatePrivateKey(in); err != nil { + return nil, trace.Wrap(err) + } + + return MarshalProtoResource(in, opts...) +} + +// UnmarshalAccessGraphPrivateKey unmarshals a [accessgraphsecretspb.PrivateKey] resource from JSON. +func UnmarshalAccessGraphPrivateKey(data []byte, opts ...MarshalOption) (*accessgraphsecretspb.PrivateKey, error) { + out, err := UnmarshalProtoResource[*accessgraphsecretspb.PrivateKey](data, opts...) + if err != nil { + return nil, trace.Wrap(err) + } + if err := accessgraph.ValidatePrivateKey(out); err != nil { + return nil, trace.Wrap(err) + } + return out, nil +} diff --git a/lib/services/local/access_graph.go b/lib/services/local/access_graph.go new file mode 100644 index 0000000000000..4be7ec74ec6c3 --- /dev/null +++ b/lib/services/local/access_graph.go @@ -0,0 +1,164 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package local + +import ( + "context" + + "github.com/gravitational/trace" + + accessgraphsecretspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local/generic" +) + +const ( + authorizedKeysPrefix = "access_graph_ssh_authorized_keys" + privateKeysPrefix = "access_graph_ssh_private_keys" +) + +// AccessGraphSecretsService manages secrets found on Teleport Nodes and +// enrolled devices. +type AccessGraphSecretsService struct { + authorizedKeysSvc *generic.ServiceWrapper[*accessgraphsecretspb.AuthorizedKey] + privateKeysSvc *generic.ServiceWrapper[*accessgraphsecretspb.PrivateKey] +} + +// NewAccessGraphSecretsService returns a new Access Graph Secrets service. +// This service in Teleport is used to keep track of secrets found in Teleport +// Nodes and on enrolled devices. Currently, it only stores secrets related with +// SSH Keys. Future implementations might extend them. +func NewAccessGraphSecretsService(backend backend.Backend) (*AccessGraphSecretsService, error) { + authorizedKeysSvc, err := generic.NewServiceWrapper( + backend, + types.KindAccessGraphSecretAuthorizedKey, + authorizedKeysPrefix, + services.MarshalAccessGraphAuthorizedKey, + services.UnmarshalAccessGraphAuthorizedKey, + ) + if err != nil { + return nil, trace.Wrap(err) + } + + privateKeysSvc, err := generic.NewServiceWrapper( + backend, + types.KindAccessGraphSecretPrivateKey, + privateKeysPrefix, + services.MarshalAccessGraphPrivateKey, + services.UnmarshalAccessGraphPrivateKey, + ) + if err != nil { + return nil, trace.Wrap(err) + } + + return &AccessGraphSecretsService{ + authorizedKeysSvc: authorizedKeysSvc, + privateKeysSvc: privateKeysSvc, + }, nil +} + +// ListAllAuthorizedKeys lists all authorized keys stored in the backend. +func (k *AccessGraphSecretsService) ListAllAuthorizedKeys(ctx context.Context, pageSize int, pageToken string) ([]*accessgraphsecretspb.AuthorizedKey, string, error) { + out, next, err := k.authorizedKeysSvc.ListResources(ctx, pageSize, pageToken) + if err != nil { + return nil, "", trace.Wrap(err) + } + return out, next, nil +} + +// ListAuthorizedKeysForServer lists all authorized keys for a given hostID. +func (k *AccessGraphSecretsService) ListAuthorizedKeysForServer(ctx context.Context, hostID string, pageSize int, pageToken string) ([]*accessgraphsecretspb.AuthorizedKey, string, error) { + if hostID == "" { + return nil, "", trace.BadParameter("server name is required") + } + svc := k.authorizedKeysSvc.WithPrefix(hostID) + out, next, err := svc.ListResources(ctx, pageSize, pageToken) + if err != nil { + return nil, "", trace.Wrap(err) + } + return out, next, nil +} + +// UpsertAuthorizedKey upserts a new authorized key. +func (k *AccessGraphSecretsService) UpsertAuthorizedKey(ctx context.Context, in *accessgraphsecretspb.AuthorizedKey) (*accessgraphsecretspb.AuthorizedKey, error) { + svc := k.authorizedKeysSvc.WithPrefix(in.Spec.HostId) + out, err := svc.UpsertResource(ctx, in) + if err != nil { + return nil, trace.Wrap(err) + } + + return out, nil +} + +// DeleteAuthorizedKey deletes a specific authorized key. +func (k *AccessGraphSecretsService) DeleteAuthorizedKey(ctx context.Context, hostID, name string) error { + svc := k.authorizedKeysSvc.WithPrefix(hostID) + return trace.Wrap(svc.DeleteResource(ctx, name)) +} + +// DeleteAllAuthorizedKeys deletes all authorized keys. +func (k *AccessGraphSecretsService) DeleteAllAuthorizedKeys(ctx context.Context) error { + return trace.Wrap(k.authorizedKeysSvc.DeleteAllResources(ctx)) +} + +// ListAllPrivateKeys lists all private keys stored in the backend. +func (k *AccessGraphSecretsService) ListAllPrivateKeys(ctx context.Context, pageSize int, pageToken string) ([]*accessgraphsecretspb.PrivateKey, string, error) { + out, next, err := k.privateKeysSvc.ListResources(ctx, pageSize, pageToken) + if err != nil { + return nil, "", trace.Wrap(err) + } + return out, next, nil +} + +// ListPrivateKeysForDevice lists all private keys for a given deviceID. +func (k *AccessGraphSecretsService) ListPrivateKeysForDevice(ctx context.Context, deviceID string, pageSize int, pageToken string) ([]*accessgraphsecretspb.PrivateKey, string, error) { + if deviceID == "" { + return nil, "", trace.BadParameter("server name is required") + } + svc := k.privateKeysSvc.WithPrefix(deviceID) + out, next, err := svc.ListResources(ctx, pageSize, pageToken) + if err != nil { + return nil, "", trace.Wrap(err) + } + return out, next, nil +} + +// UpsertPrivateKey upserts a new private key. +func (k *AccessGraphSecretsService) UpsertPrivateKey(ctx context.Context, in *accessgraphsecretspb.PrivateKey) (*accessgraphsecretspb.PrivateKey, error) { + svc := k.privateKeysSvc.WithPrefix(in.Spec.DeviceId) + out, err := svc.UpsertResource(ctx, in) + if err != nil { + return nil, trace.Wrap(err) + } + + return out, nil +} + +// DeletePrivateKey deletes a specific private key. +func (k *AccessGraphSecretsService) DeletePrivateKey(ctx context.Context, deviceID, name string) error { + svc := k.privateKeysSvc.WithPrefix(deviceID) + return trace.Wrap(svc.DeleteResource(ctx, name)) +} + +// DeleteAllPrivateKeys deletes all private keys. +func (k *AccessGraphSecretsService) DeleteAllPrivateKeys(ctx context.Context) error { + return trace.Wrap(k.privateKeysSvc.DeleteAllResources(ctx)) +} diff --git a/lib/services/local/access_graph_test.go b/lib/services/local/access_graph_test.go new file mode 100644 index 0000000000000..2eda122f8c045 --- /dev/null +++ b/lib/services/local/access_graph_test.go @@ -0,0 +1,247 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package local + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + accessgraphsecretspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" + "github.com/gravitational/teleport/api/types/accessgraph" + "github.com/gravitational/teleport/lib/backend/memory" +) + +func TestAccessGraphAuthorizedKeys(t *testing.T) { + backend, err := memory.New(memory.Config{}) + require.NoError(t, err) + service, err := NewAccessGraphSecretsService(backend) + require.NoError(t, err) + + ctx := context.TODO() + pageSize := 10 + pageToken := "" + + // Test case 1: Empty list + keys, nextToken, err := service.ListAllAuthorizedKeys(ctx, pageSize, pageToken) + require.NoError(t, err) + require.Empty(t, keys) + require.Empty(t, nextToken) + + // Test case 2: Non-empty list + authorizedKeys := []*accessgraphsecretspb.AuthorizedKeySpec{ + { + HostId: "host1", + HostUser: "user1", + KeyFingerprint: "AAAAB3NzaC1yc2EAAAADAQABAAABAQC...", + }, + { + HostId: "host1", + HostUser: "user2", + KeyFingerprint: "AAAAB3NzaC1yc2EAAAADAQABAAABAQC...", + }, + { + HostId: "host2", + HostUser: "user1", + KeyFingerprint: "AAAAB3NzaC1yc2EAAAADAQABAAABAQC...", + }, + { + HostId: "host2", + HostUser: "user2", + KeyFingerprint: "AAAAB3NzaC1yc2EAAAADAQABAAABAQC...", + }, + } + var authKeys []*accessgraphsecretspb.AuthorizedKey + for _, key := range authorizedKeys { + authKey, err := accessgraph.NewAuthorizedKey(key) + require.NoError(t, err) + _, err = service.UpsertAuthorizedKey(ctx, authKey) + require.NoError(t, err) + authKeys = append(authKeys, authKey) + } + + keys, nextToken, err = service.ListAllAuthorizedKeys(ctx, pageSize, pageToken) + require.NoError(t, err) + require.Empty(t, cmp.Diff(authKeys, keys, + protocmp.Transform(), + cmpopts.SortSlices(func(a, b *accessgraphsecretspb.AuthorizedKey) bool { + return a.Metadata.Name < b.Metadata.Name + }))) + require.Empty(t, nextToken) + + // Test case 3: Pagination + pageSize = 2 + pageToken = "" + keys, nextToken, err = service.ListAllAuthorizedKeys(ctx, pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, pageSize) + require.NotEmpty(t, nextToken) + + pageToken = nextToken + keys, nextToken, err = service.ListAllAuthorizedKeys(ctx, pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, pageSize) + require.Empty(t, nextToken) + + // Test case 4: List authorized keys for server + pageToken = "" + keysHost1, nextToken, err := service.ListAuthorizedKeysForServer(ctx, "host1", pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, 2) + require.Empty(t, nextToken) + keysHost2, nextToken, err := service.ListAuthorizedKeysForServer(ctx, "host2", pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, 2) + require.Empty(t, nextToken) + require.NotEqual(t, keysHost1, keysHost2) + + // Test case 5: List authorized keys for server with pagination + pageToken = "" + pageSize = 1 + keys, nextToken, err = service.ListAuthorizedKeysForServer(ctx, "host1", pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, 1) + require.NotEmpty(t, nextToken) + + pageToken = nextToken + keys, nextToken, err = service.ListAuthorizedKeysForServer(ctx, "host1", pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Empty(t, nextToken) + + // Test case 6: Delete all + err = service.DeleteAllAuthorizedKeys(ctx) + require.NoError(t, err) + keys, nextToken, err = service.ListAllAuthorizedKeys(ctx, pageSize, pageToken) + require.NoError(t, err) + require.Empty(t, keys) + require.Empty(t, nextToken) +} + +func TestAccessGraphPrivateKeys(t *testing.T) { + backend, err := memory.New(memory.Config{}) + require.NoError(t, err) + service, err := NewAccessGraphSecretsService(backend) + require.NoError(t, err) + + ctx := context.TODO() + pageSize := 10 + pageToken := "" + + // Test case 1: Empty list + keys, nextToken, err := service.ListAllPrivateKeys(ctx, pageSize, pageToken) + require.NoError(t, err) + require.Empty(t, keys) + require.Empty(t, nextToken) + + // Test case 2: Non-empty list + privateKeysSpec := []*accessgraphsecretspb.PrivateKeySpec{ + { + DeviceId: "device1", + PublicKeyMode: accessgraphsecretspb.PublicKeyMode_PUBLIC_KEY_MODE_DERIVED, + PublicKeyFingerprint: "AAAAB3NzaC1yc2EAAAADAQABAAABAQC...", + }, + { + DeviceId: "device1", + PublicKeyMode: accessgraphsecretspb.PublicKeyMode_PUBLIC_KEY_MODE_PUB_FILE, + PublicKeyFingerprint: "AAAAB3NzaC1yc2EAAAADAQABAAABAQC...", + }, + { + DeviceId: "device2", + PublicKeyMode: accessgraphsecretspb.PublicKeyMode_PUBLIC_KEY_MODE_DERIVED, + PublicKeyFingerprint: "AAAAB3NzaC1yc2EAAAADAQABAAABAQC...", + }, + { + DeviceId: "device2", + PublicKeyMode: accessgraphsecretspb.PublicKeyMode_PUBLIC_KEY_MODE_PUB_FILE, + PublicKeyFingerprint: "AAAAB3NzaC1yc2EAAAADAQABAAABAQC...", + }, + } + var authKeys []*accessgraphsecretspb.PrivateKey + for _, key := range privateKeysSpec { + name := uuid.New().String() + prvKey, err := accessgraph.NewPrivateKeyWithName(name, key) + require.NoError(t, err) + _, err = service.UpsertPrivateKey(ctx, prvKey) + require.NoError(t, err) + authKeys = append(authKeys, prvKey) + } + + keys, nextToken, err = service.ListAllPrivateKeys(ctx, pageSize, pageToken) + require.NoError(t, err) + require.Empty(t, cmp.Diff(authKeys, keys, + protocmp.Transform(), + cmpopts.SortSlices(func(a, b *accessgraphsecretspb.PrivateKey) bool { + return a.Metadata.Name < b.Metadata.Name + }))) + require.Empty(t, nextToken) + + // Test case 3: Pagination + pageSize = 2 + pageToken = "" + keys, nextToken, err = service.ListAllPrivateKeys(ctx, pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, pageSize) + require.NotEmpty(t, nextToken) + + pageToken = nextToken + keys, nextToken, err = service.ListAllPrivateKeys(ctx, pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, pageSize) + require.Empty(t, nextToken) + + // Test case 4: List private keys for device + pageToken = "" + keysHost1, nextToken, err := service.ListPrivateKeysForDevice(ctx, "device1", pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, 2) + require.Empty(t, nextToken) + keysHost2, nextToken, err := service.ListPrivateKeysForDevice(ctx, "device2", pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, 2) + require.Empty(t, nextToken) + require.NotEqual(t, keysHost1, keysHost2) + + // Test case 5: List private keys for device with pagination + pageToken = "" + pageSize = 1 + keys, nextToken, err = service.ListPrivateKeysForDevice(ctx, "device1", pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, 1) + require.NotEmpty(t, nextToken) + + pageToken = nextToken + keys, nextToken, err = service.ListPrivateKeysForDevice(ctx, "device1", pageSize, pageToken) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Empty(t, nextToken) + + // Test case 6: Delete all + err = service.DeleteAllPrivateKeys(ctx) + require.NoError(t, err) + keys, nextToken, err = service.ListAllPrivateKeys(ctx, pageSize, pageToken) + require.NoError(t, err) + require.Empty(t, keys) + require.Empty(t, nextToken) +}