Skip to content

Commit

Permalink
Access Graph: sync AWS SAML Providers
Browse files Browse the repository at this point in the history
  • Loading branch information
justinas committed May 9, 2024
1 parent 3182324 commit fbb4793
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/cloud/aws/policy_statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ func StatementAccessGraphAWSSync() *Statement {
"iam:ListRolePolicies",
"iam:ListAttachedRolePolicies",
"iam:GetRolePolicy",
"iam:ListSAMLProviders",
"iam:GetSAMLProvider",
},
Resources: allResources,
}
Expand Down
27 changes: 27 additions & 0 deletions lib/cloud/mocks/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ type IAMMock struct {
attachedRolePolicies map[string]map[string]string
// attachedUserPolicies maps userName -> policyName -> policyDocument
attachedUserPolicies map[string]map[string]string
// SAMLProviders maps saml provider ARN -> samlProvider
SAMLProviders map[string]*iam.GetSAMLProviderOutput
}

func (m *IAMMock) GetRolePolicyWithContext(ctx aws.Context, input *iam.GetRolePolicyInput, options ...request.Option) (*iam.GetRolePolicyOutput, error) {
Expand Down Expand Up @@ -199,6 +201,31 @@ func (m *IAMMock) DeleteUserPolicyWithContext(ctx aws.Context, input *iam.Delete
return &iam.DeleteUserPolicyOutput{}, nil
}

func (m *IAMMock) ListSAMLProvidersWithContext(ctx aws.Context, input *iam.ListSAMLProvidersInput, options ...request.Option) (*iam.ListSAMLProvidersOutput, error) {
m.mu.RLock()
defer m.mu.RUnlock()
resp := &iam.ListSAMLProvidersOutput{}
for arn := range m.SAMLProviders {
resp.SAMLProviderList = append(resp.SAMLProviderList, &iam.SAMLProviderListEntry{
Arn: aws.String(arn),
})
}
return resp, nil
}

func (m *IAMMock) GetSAMLProviderWithContext(ctx aws.Context, input *iam.GetSAMLProviderInput, options ...request.Option) (*iam.GetSAMLProviderOutput, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if input.SAMLProviderArn == nil {
return nil, trace.BadParameter("SAMLProviderARN must not be nil")
}
provider, ok := m.SAMLProviders[*input.SAMLProviderArn]
if !ok {
return nil, trace.BadParameter("SAML provider %q not found", *input.SAMLProviderArn)
}
return provider, nil
}

// IAMErrorMock is a mock IAM client that returns the provided Error to all
// APIs. If Error is not provided, all APIs returns trace.AccessDenied by
// default.
Expand Down
6 changes: 6 additions & 0 deletions lib/srv/discovery/fetchers/aws-sync/aws-sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ type Resources struct {
AccessEntries []*accessgraphv1alpha.AWSEKSClusterAccessEntryV1
// RDSDatabases is a list of RDS instances and clusters.
RDSDatabases []*accessgraphv1alpha.AWSRDSDatabaseV1
// SAMLProviders is a list of SAML providers.
SAMLProviders []*accessgraphv1alpha.AWSSAMLProviderV1
}

func (r *Resources) count() int {
Expand Down Expand Up @@ -243,6 +245,10 @@ func (a *awsFetcher) poll(ctx context.Context, features Features) (*Resources, e
eGroup.Go(a.pollAWSRDSDatabases(ctx, result, collectErr))
}

if features.SAML {
eGroup.Go(a.pollAWSSAMLProviders(ctx, result, collectErr))
}

if err := eGroup.Wait(); err != nil {
return nil, trace.Wrap(err)
}
Expand Down
5 changes: 5 additions & 0 deletions lib/srv/discovery/fetchers/aws-sync/features.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const (
awsRDS = awsNamespace + "/" + "rds"
awsS3 = awsNamespace + "/" + "s3"
awsRoles = awsNamespace + "/" + "roles"
awsSAML = awsNamespace + "/" + "saml"
)

// Features is the list of supported resources by the server.
Expand All @@ -45,6 +46,8 @@ type Features struct {
EKS bool
// S3 enables AWS S3 sync.
S3 bool
// SAML enables AWS IAM SAML providers sync.
SAML bool
}

// BuildFeatures builds the feature flags based on supported types returned by Access Graph
Expand All @@ -67,6 +70,8 @@ func BuildFeatures(values ...string) Features {
features.S3 = true
case awsRoles:
features.Roles = true
case awsSAML:
features.SAML = true
}
}
return features
Expand Down
121 changes: 121 additions & 0 deletions lib/srv/discovery/fetchers/aws-sync/iam_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package aws_sync

import (
"context"
"sync"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"

accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/mocks"
)

func TestAWSIAMPollSAMLProviders(t *testing.T) {
const (
accountID = "12345678"
)
var (
regions = []string{"eu-west-1"}
)

timestamp1 := time.Date(2024, time.May, 1, 1, 2, 3, 0, time.UTC)
timestamp2 := timestamp1.AddDate(1, 0, 0)

mockedClients := &cloud.TestCloudClients{
IAM: &mocks.IAMMock{
SAMLProviders: samlProviders(timestamp1, timestamp2),
},
}

var (
errs []error
mu sync.Mutex
)

collectErr := func(err error) {
mu.Lock()
defer mu.Unlock()
errs = append(errs, err)
}
a := &awsFetcher{
Config: Config{
AccountID: accountID,
CloudClients: mockedClients,
Regions: regions,
Integration: accountID,
},
}
expected := []*accessgraphv1alpha.AWSSAMLProviderV1{
{
Arn: "arn:aws:iam::1234678:saml-provider/provider1",
CreatedAt: timestamppb.New(timestamp1),
ValidUntil: timestamppb.New(timestamp2),
SamlMetadataDocument: "<foo></foo>",
Tags: []*accessgraphv1alpha.AWSTag{
{Key: "key1", Value: &wrapperspb.StringValue{Value: "value1"}},
{Key: "key2", Value: &wrapperspb.StringValue{Value: "value2"}},
},
AccountId: accountID,
},
{
Arn: "arn:aws:iam::1234678:saml-provider/provider2",
CreatedAt: timestamppb.New(timestamp2),
SamlMetadataDocument: "<bar></bar>",
AccountId: accountID,
},
}
result := &Resources{}
execFunc := a.pollAWSSAMLProviders(context.Background(), result, collectErr)
require.NoError(t, execFunc())
require.Empty(t, errs)
require.Empty(t, cmp.Diff(
expected,
result.SAMLProviders,
protocmp.Transform(),
))
}

func samlProviders(timestamp1, timestamp2 time.Time) map[string]*iam.GetSAMLProviderOutput {
return map[string]*iam.GetSAMLProviderOutput{
"arn:aws:iam::1234678:saml-provider/provider1": {
CreateDate: aws.Time(timestamp1),
SAMLMetadataDocument: aws.String("<foo></foo>"),
ValidUntil: aws.Time(timestamp2),
Tags: []*iam.Tag{
{Key: aws.String("key1"), Value: aws.String("value1")},
{Key: aws.String("key2"), Value: aws.String("value2")},
},
},
"arn:aws:iam::1234678:saml-provider/provider2": {
CreateDate: aws.Time(timestamp2),
SAMLMetadataDocument: aws.String("<bar></bar>"),
},
}
}
1 change: 1 addition & 0 deletions lib/srv/discovery/fetchers/aws-sync/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func MergeResources(results ...*Resources) *Resources {
result.EKSClusters = append(result.EKSClusters, r.EKSClusters...)
result.AccessEntries = append(result.AccessEntries, r.AccessEntries...)
result.RDSDatabases = append(result.RDSDatabases, r.RDSDatabases...)
result.SAMLProviders = append(result.SAMLProviders, r.SAMLProviders...)
}
return result
}
27 changes: 27 additions & 0 deletions lib/srv/discovery/fetchers/aws-sync/reconcile.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func ReconcileResults(old *Resources, new *Resources) (upsert, delete *accessgra
reconcileAssociatedAccessPolicy(old.AssociatedAccessPolicies, new.AssociatedAccessPolicies),
reconcileAccessEntry(old.AccessEntries, new.AccessEntries),
reconcileAWSRDS(old.RDSDatabases, new.RDSDatabases),
reconcileSAMLProviders(old.SAMLProviders, new.SAMLProviders),
} {
upsert.Resources = append(upsert.Resources, results.upsert.Resources...)
delete.Resources = append(delete.Resources, results.delete.Resources...)
Expand Down Expand Up @@ -555,3 +556,29 @@ func reconcileAWSRDS(
}
return &reconcileIntermeditateResult{upsert, delete}
}

func reconcileSAMLProviders(
old []*accessgraphv1alpha.AWSSAMLProviderV1,
new []*accessgraphv1alpha.AWSSAMLProviderV1,
) *reconcileIntermeditateResult {
upsert, delete := &accessgraphv1alpha.AWSResourceList{}, &accessgraphv1alpha.AWSResourceList{}
toAdd, toRemove := reconcile(old, new, func(provider *accessgraphv1alpha.AWSSAMLProviderV1) string {
return fmt.Sprintf("%s;%s", provider.AccountId, provider.Arn)
})

for _, provider := range toAdd {
upsert.Resources = append(upsert.Resources, &accessgraphv1alpha.AWSResource{
Resource: &accessgraphv1alpha.AWSResource_SamlProvider{
SamlProvider: provider,
},
})
}
for _, provider := range toRemove {
delete.Resources = append(delete.Resources, &accessgraphv1alpha.AWSResource{
Resource: &accessgraphv1alpha.AWSResource_SamlProvider{
SamlProvider: provider,
},
})
}
return &reconcileIntermeditateResult{upsert, delete}
}
96 changes: 96 additions & 0 deletions lib/srv/discovery/fetchers/aws-sync/saml.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package aws_sync

import (
"context"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/gravitational/trace"

accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
)

func (a *awsFetcher) pollAWSSAMLProviders(ctx context.Context, result *Resources, collectErr func(error)) func() error {
return func() error {
var err error

iamClient, err := a.CloudClients.GetAWSIAMClient(
ctx,
"", /* region is empty because saml providers are global */
a.getAWSOptions()...,
)
if err != nil {
collectErr(trace.Wrap(err, "failed to get AWS IAM client"))
return nil
}
listResp, err := iamClient.ListSAMLProvidersWithContext(ctx, &iam.ListSAMLProvidersInput{})
if err != nil {
collectErr(trace.Wrap(err, "failed to list AWS SAML identity providers"))
return nil
}

providers := make([]*accessgraphv1alpha.AWSSAMLProviderV1, 0, len(listResp.SAMLProviderList))
for _, providerRef := range listResp.SAMLProviderList {
arn := aws.StringValue(providerRef.Arn)
provider, err := a.fetchAWSSAMLProvider(ctx, iamClient, arn)
if err != nil {
collectErr(trace.Wrap(err, "failed to get info for SAML provider %s", arn))
} else {
providers = append(providers, provider)
}
}

result.SAMLProviders = providers
return nil
}
}

// fetchAWSSAMLProvider fetches data about a single SAML identity provider.
func (a *awsFetcher) fetchAWSSAMLProvider(ctx context.Context, client iamiface.IAMAPI, arn string) (*accessgraphv1alpha.AWSSAMLProviderV1, error) {
providerResp, err := client.GetSAMLProviderWithContext(ctx, &iam.GetSAMLProviderInput{
SAMLProviderArn: aws.String(arn),
})
if err != nil {
return nil, err
}
return awsSAMLProviderOutputToProto(arn, a.AccountID, providerResp), nil
}

// awsSAMLProviderToSAML converts an iam.SAMLProvider to accessgraphv1alpha.SAMLProviderV1 representation.
func awsSAMLProviderOutputToProto(arn string, accountID string, provider *iam.GetSAMLProviderOutput) *accessgraphv1alpha.AWSSAMLProviderV1 {
var tags []*accessgraphv1alpha.AWSTag
for _, v := range provider.Tags {
tags = append(tags, &accessgraphv1alpha.AWSTag{
Key: aws.StringValue(v.Key),
Value: strPtrToWrapper(v.Value),
})
}

return &accessgraphv1alpha.AWSSAMLProviderV1{
Arn: arn,
CreatedAt: awsTimeToProtoTime(provider.CreateDate),
ValidUntil: awsTimeToProtoTime(provider.ValidUntil),
SamlMetadataDocument: aws.StringValue(provider.SAMLMetadataDocument),
Tags: tags,
AccountId: accountID,
}
}

0 comments on commit fbb4793

Please sign in to comment.