diff --git a/common/aws/dynamodb/client.go b/common/aws/dynamodb/client.go index a3436a342..4b013cc41 100644 --- a/common/aws/dynamodb/client.go +++ b/common/aws/dynamodb/client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math" + "strconv" "sync" commonaws "github.com/Layr-Labs/eigenda/common/aws" @@ -128,8 +129,8 @@ func (c *Client) PutItems(ctx context.Context, tableName string, items []Item) ( func (c *Client) UpdateItem(ctx context.Context, tableName string, key Key, item Item) (Item, error) { update := expression.UpdateBuilder{} for itemKey, itemValue := range item { + // Ignore primary key updates if _, ok := key[itemKey]; ok { - // Cannot update the key continue } update = update.Set(expression.Name(itemKey), expression.Value(itemValue)) @@ -156,6 +157,36 @@ func (c *Client) UpdateItem(ctx context.Context, tableName string, key Key, item return resp.Attributes, err } +// IncrementBy increments the attribute by the value for item that matches with the key +func (c *Client) IncrementBy(ctx context.Context, tableName string, key Key, attr string, value uint64) (Item, error) { + // ADD numeric values + f, err := strconv.ParseFloat(strconv.FormatUint(value, 10), 64) + if err != nil { + return nil, err + } + + update := expression.UpdateBuilder{} + update = update.Add(expression.Name(attr), expression.Value(aws.Float64(f))) + expr, err := expression.NewBuilder().WithUpdate(update).Build() + if err != nil { + return nil, err + } + + resp, err := c.dynamoClient.UpdateItem(ctx, &dynamodb.UpdateItemInput{ + TableName: aws.String(tableName), + Key: key, + ExpressionAttributeNames: expr.Names(), + ExpressionAttributeValues: expr.Values(), + UpdateExpression: expr.Update(), + ReturnValues: types.ReturnValueUpdatedNew, + }) + if err != nil { + return nil, err + } + + return resp.Attributes, nil +} + func (c *Client) GetItem(ctx context.Context, tableName string, key Key) (Item, error) { resp, err := c.dynamoClient.GetItem(ctx, &dynamodb.GetItemInput{Key: key, TableName: aws.String(tableName)}) if err != nil { @@ -191,6 +222,24 @@ func (c *Client) QueryIndex(ctx context.Context, tableName string, indexName str return response.Items, nil } +// QueryIndexOrderWithLimit returns all items in the index that match the given key +// If forward is true, the items are returned in ascending order +func (c *Client) QueryIndexOrderWithLimit(ctx context.Context, tableName string, indexName string, keyCondition string, expAttributeValues ExpressionValues, forward bool, limit int32) ([]Item, error) { + response, err := c.dynamoClient.Query(ctx, &dynamodb.QueryInput{ + TableName: aws.String(tableName), + IndexName: aws.String(indexName), + KeyConditionExpression: aws.String(keyCondition), + ExpressionAttributeValues: expAttributeValues, + ScanIndexForward: &forward, + Limit: aws.Int32(limit), + }) + if err != nil { + return nil, err + } + + return response.Items, nil +} + // Query returns all items in the primary index that match the given expression func (c *Client) Query(ctx context.Context, tableName string, keyCondition string, expAttributeValues ExpressionValues) ([]Item, error) { response, err := c.dynamoClient.Query(ctx, &dynamodb.QueryInput{ diff --git a/common/aws/dynamodb/client_test.go b/common/aws/dynamodb/client_test.go index 786309085..a67db4e6c 100644 --- a/common/aws/dynamodb/client_test.go +++ b/common/aws/dynamodb/client_test.go @@ -205,6 +205,11 @@ func TestBasicOperations(t *testing.T) { }) assert.NoError(t, err) + _, err = dynamoClient.IncrementBy(ctx, tableName, commondynamodb.Key{ + "MetadataKey": &types.AttributeValueMemberS{Value: "key"}, + }, "BlobSize", 1000) + assert.NoError(t, err) + item, err = dynamoClient.GetItem(ctx, tableName, commondynamodb.Key{ "MetadataKey": &types.AttributeValueMemberS{Value: "key"}, }) @@ -213,6 +218,7 @@ func TestBasicOperations(t *testing.T) { assert.Equal(t, "Confirmed", item["Status"].(*types.AttributeValueMemberS).Value) assert.Equal(t, "0x123", item["BatchHeaderHash"].(*types.AttributeValueMemberS).Value) assert.Equal(t, "0", item["BlobIndex"].(*types.AttributeValueMemberN).Value) + assert.Equal(t, "1123", item["BlobSize"].(*types.AttributeValueMemberN).Value) err = dynamoClient.DeleteTable(ctx, tableName) assert.NoError(t, err) @@ -596,3 +602,62 @@ func TestQueryIndexWithPaginationForBatch(t *testing.T) { assert.Len(t, queryResult.Items, 0) assert.Nil(t, queryResult.LastEvaluatedKey) } + +func TestQueryIndexOrderWithLimit(t *testing.T) { + tableName := "ProcessingQueryIndexOrderWithLimit" + createTable(t, tableName) + indexName := "StatusIndex" + + ctx := context.Background() + numItems := 30 + items := make([]commondynamodb.Item, numItems) + for i := 0; i < numItems; i++ { + requestedAt := time.Now().Add(-time.Duration(i) * time.Minute).Unix() + items[i] = commondynamodb.Item{ + "MetadataKey": &types.AttributeValueMemberS{Value: fmt.Sprintf("key%d", i)}, + "BlobKey": &types.AttributeValueMemberS{Value: fmt.Sprintf("blob%d", i)}, + "BlobSize": &types.AttributeValueMemberN{Value: "123"}, + "BlobStatus": &types.AttributeValueMemberN{Value: "0"}, + "RequestedAt": &types.AttributeValueMemberN{Value: strconv.FormatInt(requestedAt, 10)}, + } + } + unprocessed, err := dynamoClient.PutItems(ctx, tableName, items) + assert.NoError(t, err) + assert.Len(t, unprocessed, 0) + + // Test forward order with limit + queryResult, err := dynamoClient.QueryIndexOrderWithLimit(ctx, tableName, indexName, "BlobStatus = :status", commondynamodb.ExpressionValues{ + ":status": &types.AttributeValueMemberN{Value: "0"}, + }, true, 10) + assert.NoError(t, err) + assert.Len(t, queryResult, 10) + // Check if the items are in ascending order + for i := 0; i < len(queryResult)-1; i++ { + assert.True(t, queryResult[i]["RequestedAt"].(*types.AttributeValueMemberN).Value <= queryResult[i+1]["RequestedAt"].(*types.AttributeValueMemberN).Value) + } + + // Test reverse order with limit + queryResult, err = dynamoClient.QueryIndexOrderWithLimit(ctx, tableName, indexName, "BlobStatus = :status", commondynamodb.ExpressionValues{ + ":status": &types.AttributeValueMemberN{Value: "0"}, + }, false, 10) + assert.NoError(t, err) + assert.Len(t, queryResult, 10) + // Check if the items are in descending order + for i := 0; i < len(queryResult)-1; i++ { + assert.True(t, queryResult[i]["RequestedAt"].(*types.AttributeValueMemberN).Value >= queryResult[i+1]["RequestedAt"].(*types.AttributeValueMemberN).Value) + } + + // Test with a smaller limit + queryResult, err = dynamoClient.QueryIndexOrderWithLimit(ctx, tableName, indexName, "BlobStatus = :status", commondynamodb.ExpressionValues{ + ":status": &types.AttributeValueMemberN{Value: "0"}, + }, true, 5) + assert.NoError(t, err) + assert.Len(t, queryResult, 5) + + // Test with a limit larger than the number of items + queryResult, err = dynamoClient.QueryIndexOrderWithLimit(ctx, tableName, indexName, "BlobStatus = :status", commondynamodb.ExpressionValues{ + ":status": &types.AttributeValueMemberN{Value: "0"}, + }, true, 50) + assert.NoError(t, err) + assert.Len(t, queryResult, 30) // Should return all items +}