From 682b7a03383910ba730b690c49bcf3c3b9b81c7b Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 2 May 2022 18:16:54 +0200 Subject: [PATCH 1/2] Handle DynamoDB pay-per-request mode correctly (#12295) --- lib/events/dynamoevents/dynamoevents.go | 30 +++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/lib/events/dynamoevents/dynamoevents.go b/lib/events/dynamoevents/dynamoevents.go index 7446fd9ef6290..e030737c9a58a 100644 --- a/lib/events/dynamoevents/dynamoevents.go +++ b/lib/events/dynamoevents/dynamoevents.go @@ -189,6 +189,9 @@ type Log struct { // readyForQuery is used to determine if all indexes are in place // for event queries. readyForQuery *atomic.Bool + + // isBillingModeProvisioned tracks if the table has provisioned capacity or not. + isBillingModeProvisioned bool } type event struct { @@ -301,6 +304,11 @@ func New(ctx context.Context, cfg Config, backend backend.Backend) (*Log, error) return nil, trace.Wrap(err) } + b.isBillingModeProvisioned, err = b.getBillingModeIsProvisioned(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + // Migrate the table. go b.migrateWithRetry(ctx, []migrationTask{ {b.migrateRFD24, "migrateRFD24"}, @@ -1158,6 +1166,17 @@ func (l *Log) getTableStatus(ctx context.Context, tableName string) (tableStatus return tableStatusOK, nil } +func (l *Log) getBillingModeIsProvisioned(ctx context.Context) (bool, error) { + table, err := l.svc.DescribeTableWithContext(ctx, &dynamodb.DescribeTableInput{ + TableName: aws.String(l.Tablename), + }) + if err != nil { + return false, trace.Wrap(err) + } + + return *table.Table.BillingModeSummary.BillingMode == dynamodb.BillingModeProvisioned, nil +} + // indexExists checks if a given index exists on a given table and that it is active or updating. func (l *Log) indexExists(ctx context.Context, tableName, indexName string) (bool, error) { tableDescription, err := l.svc.DescribeTableWithContext(ctx, &dynamodb.DescribeTableInput{ @@ -1195,9 +1214,12 @@ func (l *Log) createV2GSI(ctx context.Context) error { return nil } - provisionedThroughput := dynamodb.ProvisionedThroughput{ - ReadCapacityUnits: aws.Int64(l.ReadCapacityUnits), - WriteCapacityUnits: aws.Int64(l.WriteCapacityUnits), + var provisionedThroughput *dynamodb.ProvisionedThroughput + if l.isBillingModeProvisioned { + provisionedThroughput = &dynamodb.ProvisionedThroughput{ + ReadCapacityUnits: aws.Int64(l.ReadCapacityUnits), + WriteCapacityUnits: aws.Int64(l.WriteCapacityUnits), + } } // This defines the update event we send to DynamoDB. @@ -1224,7 +1246,7 @@ func (l *Log) createV2GSI(ctx context.Context) error { Projection: &dynamodb.Projection{ ProjectionType: aws.String("ALL"), }, - ProvisionedThroughput: &provisionedThroughput, + ProvisionedThroughput: provisionedThroughput, }, }, }, From 2c04be5645d85c5fae0937d2d60621f11b3f7e52 Mon Sep 17 00:00:00 2001 From: Joel Date: Thu, 5 May 2022 16:34:04 +0200 Subject: [PATCH 2/2] Add nil check for billing mode in AWS DynamoDB events driver (#12445) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Check for nil in the billing mode getter + fix panic if eventtype is not set * use a better strategy to check if the interface is nil * add godoc * simplify switch * update godoc * simplify unknown field extraction code * ensure underlying type is correct * Update lib/events/dynamic_test.go Co-authored-by: Krzysztof Skrzętnicki * fix lint * Update AWS dynamo tests as a result of the dynamic.go changes, I don't think they've ever worked in some time as they appeared quite broke. * simplify * simplify 2 Co-authored-by: Krzysztof Skrzętnicki --- lib/events/dynamic.go | 20 ++++-- lib/events/dynamic_test.go | 69 ++++++++++++++++++++ lib/events/dynamoevents/dynamoevents.go | 11 +++- lib/events/dynamoevents/dynamoevents_test.go | 27 ++++---- lib/events/test/suite.go | 57 +++++++++++++--- 5 files changed, 155 insertions(+), 29 deletions(-) create mode 100644 lib/events/dynamic_test.go diff --git a/lib/events/dynamic.go b/lib/events/dynamic.go index 4894bf43c5fbb..0e320b4501470 100644 --- a/lib/events/dynamic.go +++ b/lib/events/dynamic.go @@ -18,7 +18,6 @@ package events import ( "github.com/gravitational/teleport/api/types/events" - apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" @@ -32,13 +31,22 @@ import ( // // This is mainly used to convert from the backend format used by // our various event backends. -func FromEventFields(fields EventFields) (apievents.AuditEvent, error) { +func FromEventFields(fields EventFields) (events.AuditEvent, error) { data, err := json.Marshal(fields) if err != nil { return nil, trace.Wrap(err) } - eventType := fields.GetString(EventType) + getFieldEmpty := func(field string) string { + i, ok := fields[field] + if !ok { + return "" + } + s, _ := i.(string) + return s + } + + var eventType = getFieldEmpty(EventType) switch eventType { case SessionPrintEvent: @@ -475,7 +483,7 @@ func FromEventFields(fields EventFields) (apievents.AuditEvent, error) { unknown.Type = UnknownEvent unknown.Code = UnknownCode unknown.UnknownType = eventType - unknown.UnknownCode = fields.GetString(EventCode) + unknown.UnknownCode = getFieldEmpty(EventCode) unknown.Data = string(data) return unknown, nil } @@ -483,7 +491,7 @@ func FromEventFields(fields EventFields) (apievents.AuditEvent, error) { // GetSessionID pulls the session ID from the events that have a // SessionMetadata. For other events an empty string is returned. -func GetSessionID(event apievents.AuditEvent) string { +func GetSessionID(event events.AuditEvent) string { var sessionID string if g, ok := event.(SessionMetadataGetter); ok { @@ -496,7 +504,7 @@ func GetSessionID(event apievents.AuditEvent) string { // ToEventFields converts from the typed interface-style event representation // to the old dynamic map style representation in order to provide outer compatibility // with existing public API routes when the backend is updated with the typed events. -func ToEventFields(event apievents.AuditEvent) (EventFields, error) { +func ToEventFields(event events.AuditEvent) (EventFields, error) { var fields EventFields if err := apiutils.ObjectToStruct(event, &fields); err != nil { return nil, trace.Wrap(err) diff --git a/lib/events/dynamic_test.go b/lib/events/dynamic_test.go new file mode 100644 index 0000000000000..e0cee4741eead --- /dev/null +++ b/lib/events/dynamic_test.go @@ -0,0 +1,69 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package events + +import ( + "testing" + + "github.com/gravitational/teleport/api/types/events" + "github.com/stretchr/testify/require" +) + +// TestDynamicTypeUnknown checks that we correctly translate unknown events strings into the correct proto type. +func TestDynamicUnknownType(t *testing.T) { + fields := EventFields{ + EventType: "suspicious-cert-event", + EventCode: "foobar", + } + + event, err := FromEventFields(fields) + require.NoError(t, err) + + require.Equal(t, UnknownEvent, event.GetType()) + require.Equal(t, UnknownCode, event.GetCode()) + unknownEvent := event.(*events.Unknown) + require.Equal(t, "suspicious-cert-event", unknownEvent.UnknownType) + require.Equal(t, "foobar", unknownEvent.UnknownCode) +} + +// TestDynamicNotSet checks that we properly handle cases where the event type is not set. +func TestDynamicTypeNotSet(t *testing.T) { + fields := EventFields{ + "foo": "bar", + } + + event, err := FromEventFields(fields) + require.NoError(t, err) + + require.Equal(t, UnknownEvent, event.GetType()) + require.Equal(t, UnknownCode, event.GetCode()) + unknownEvent := event.(*events.Unknown) + require.Equal(t, "", unknownEvent.UnknownType) + require.Equal(t, "", unknownEvent.UnknownCode) +} + +// TestDynamicTypeUnknown checks that we correctly translate known events into the correct proto type. +func TestDynamicKnownType(t *testing.T) { + fields := EventFields{ + EventType: "print", + } + + event, err := FromEventFields(fields) + require.NoError(t, err) + printEvent := event.(*events.SessionPrint) + require.Equal(t, SessionPrintEvent, printEvent.GetType()) +} diff --git a/lib/events/dynamoevents/dynamoevents.go b/lib/events/dynamoevents/dynamoevents.go index e030737c9a58a..9addd41d6661d 100644 --- a/lib/events/dynamoevents/dynamoevents.go +++ b/lib/events/dynamoevents/dynamoevents.go @@ -1167,14 +1167,21 @@ func (l *Log) getTableStatus(ctx context.Context, tableName string) (tableStatus } func (l *Log) getBillingModeIsProvisioned(ctx context.Context) (bool, error) { - table, err := l.svc.DescribeTableWithContext(ctx, &dynamodb.DescribeTableInput{ + res, err := l.svc.DescribeTableWithContext(ctx, &dynamodb.DescribeTableInput{ TableName: aws.String(l.Tablename), }) if err != nil { return false, trace.Wrap(err) } - return *table.Table.BillingModeSummary.BillingMode == dynamodb.BillingModeProvisioned, nil + // Guaranteed to be set. + table := res.Table + + // Perform pessimistic nil-checks, assume the table is provisioned if they are true. + // Otherwise, actually check the billing mode. + return table.BillingModeSummary == nil || + table.BillingModeSummary.BillingMode == nil || + *table.BillingModeSummary.BillingMode == dynamodb.BillingModeProvisioned, nil } // indexExists checks if a given index exists on a given table and that it is active or updating. diff --git a/lib/events/dynamoevents/dynamoevents_test.go b/lib/events/dynamoevents/dynamoevents_test.go index 6a759082a80ba..093fdfe7b74d2 100644 --- a/lib/events/dynamoevents/dynamoevents_test.go +++ b/lib/events/dynamoevents/dynamoevents_test.go @@ -123,12 +123,15 @@ func (s *DynamoeventsSuite) TestSizeBreak(c *check.C) { const eventCount int = 10 for i := 0; i < eventCount; i++ { - err := s.Log.EmitAuditEventLegacy(events.UserLocalLoginE, events.EventFields{ - events.LoginMethod: events.LoginMethodSAML, - events.AuthAttemptSuccess: true, - events.EventUser: "bob", - events.EventTime: s.Clock.Now().UTC().Add(time.Second * time.Duration(i)), - "test.data": blob, + err := s.Log.EmitAuditEvent(context.Background(), &apievents.UserLogin{ + Method: events.LoginMethodSAML, + Status: apievents.Status{Success: true}, + UserMetadata: apievents.UserMetadata{User: "bob"}, + Metadata: apievents.Metadata{ + Type: events.UserLoginEvent, + Time: s.Clock.Now().UTC().Add(time.Second * time.Duration(i)), + }, + IdentityAttributes: apievents.MustEncodeMap(map[string]interface{}{"test.data": blob}), }) c.Assert(err, check.IsNil) } @@ -309,11 +312,13 @@ var _ = check.Suite(&DynamoeventsLargeTableSuite{}) func (s *DynamoeventsLargeTableSuite) TestLargeTableRetrieve(c *check.C) { const eventCount = 4000 for i := 0; i < eventCount; i++ { - err := s.Log.EmitAuditEventLegacy(events.UserLocalLoginE, events.EventFields{ - events.LoginMethod: events.LoginMethodSAML, - events.AuthAttemptSuccess: true, - events.EventUser: "bob", - events.EventTime: s.Clock.Now().UTC(), + err := s.Log.EmitAuditEvent(context.Background(), &apievents.UserLogin{ + Method: events.LoginMethodSAML, + Status: apievents.Status{Success: true}, + UserMetadata: apievents.UserMetadata{User: "bob"}, + Metadata: apievents.Metadata{ + Type: events.UserLoginEvent, + Time: s.Clock.Now().UTC()}, }) c.Assert(err, check.IsNil) } diff --git a/lib/events/test/suite.go b/lib/events/test/suite.go index 5f0c95d47998a..e3a5c5b6cff17 100644 --- a/lib/events/test/suite.go +++ b/lib/events/test/suite.go @@ -90,11 +90,14 @@ func (s *EventsSuite) EventPagination(c *check.C) { names := []string{"bob", "jack", "daisy", "evan"} for i, name := range names { - err := s.Log.EmitAuditEventLegacy(events.UserLocalLoginE, events.EventFields{ - events.LoginMethod: events.LoginMethodSAML, - events.AuthAttemptSuccess: true, - events.EventUser: name, - events.EventTime: baseTime.Add(time.Second * time.Duration(i)), + err := s.Log.EmitAuditEvent(context.Background(), &apievents.UserLogin{ + Method: events.LoginMethodSAML, + Status: apievents.Status{Success: true}, + UserMetadata: apievents.UserMetadata{User: name}, + Metadata: apievents.Metadata{ + Type: events.UserLoginEvent, + Time: baseTime.Add(time.Second * time.Duration(i)), + }, }) c.Assert(err, check.IsNil) } @@ -166,11 +169,14 @@ func (s *EventsSuite) EventPagination(c *check.C) { // SessionEventsCRUD covers session events func (s *EventsSuite) SessionEventsCRUD(c *check.C) { // Bob has logged in - err := s.Log.EmitAuditEventLegacy(events.UserLocalLoginE, events.EventFields{ - events.LoginMethod: events.LoginMethodSAML, - events.AuthAttemptSuccess: true, - events.EventUser: "bob", - events.EventTime: s.Clock.Now().UTC(), + err := s.Log.EmitAuditEvent(context.Background(), &apievents.UserLogin{ + Method: events.LoginMethodSAML, + Status: apievents.Status{Success: true}, + UserMetadata: apievents.UserMetadata{User: "bob"}, + Metadata: apievents.Metadata{ + Type: events.UserLoginEvent, + Time: s.Clock.Now().UTC(), + }, }) c.Assert(err, check.IsNil) @@ -213,6 +219,37 @@ func (s *EventsSuite) SessionEventsCRUD(c *check.C) { }) c.Assert(err, check.IsNil) + err = s.Log.EmitAuditEvent(context.Background(), &apievents.SessionStart{ + Metadata: apievents.Metadata{ + Time: s.Clock.Now().UTC(), + Index: 0, + Type: events.SessionStartEvent, + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: string(sessionID), + }, + UserMetadata: apievents.UserMetadata{ + Login: "bob", + }, + }) + c.Assert(err, check.IsNil) + + err = s.Log.EmitAuditEvent(context.Background(), &apievents.SessionEnd{ + Metadata: apievents.Metadata{ + Time: s.Clock.Now().Add(time.Hour).UTC(), + Index: 4, + Type: events.SessionEndEvent, + }, + UserMetadata: apievents.UserMetadata{ + Login: "bob", + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: string(sessionID), + }, + Participants: []string{"bob", "alice"}, + }) + c.Assert(err, check.IsNil) + // read the session event historyEvents, err := s.Log.GetSessionEvents(apidefaults.Namespace, sessionID, 0, false) c.Assert(err, check.IsNil)