Skip to content

Commit

Permalink
fix context
Browse files Browse the repository at this point in the history
  • Loading branch information
masashiy22 committed Apr 26, 2024
1 parent 6b43094 commit 57b5c0e
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 29 deletions.
15 changes: 10 additions & 5 deletions controller/training_item.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ func ListTraningItem(c *gin.Context) {
return
}

trainingItems, err := service.GetTraningItems(userId)
ctx := c.Request.Context()
trainingItems, err := service.GetTraningItems(ctx, userId)
if err != nil {
logger.Logger.Error("ListTraningItem Failed.", logger.ErrAttr(err))
return
Expand Down Expand Up @@ -55,7 +56,8 @@ func GetTraningItem(c *gin.Context) {
return
}

trainingItem, err := service.GetTraningItem(trainingItemId, userId)
ctx := c.Request.Context()
trainingItem, err := service.GetTraningItem(ctx, trainingItemId, userId)
if err != nil {
logger.Logger.Error("GetTraningItem Failed.", logger.ErrAttr(err))
return
Expand Down Expand Up @@ -84,7 +86,8 @@ func CreateTraningItem(c *gin.Context) {
return
}

err = service.CreateTraningItem(&requestBody)
ctx := c.Request.Context()
err = service.CreateTraningItem(ctx, &requestBody)
if err != nil {
logger.Logger.Error("CreateTraningItem Failed.", logger.ErrAttr(err))
return
Expand Down Expand Up @@ -120,7 +123,8 @@ func UpdateTraningItem(c *gin.Context) {
return
}

err = service.UpdateTraningItem(&requestBody, userId)
ctx := c.Request.Context()
err = service.UpdateTraningItem(ctx, &requestBody, userId)
if err != nil {
if customErr, ok := err.(*customerror.Error404); ok {
c.JSON(customErr.ErrorCode, customErr.Body)
Expand Down Expand Up @@ -148,7 +152,8 @@ func DeleteTraningItem(c *gin.Context) {
return
}

err = service.DeleteTraningItem(trainingItemId, userId)
ctx := c.Request.Context()
err = service.DeleteTraningItem(ctx, trainingItemId, userId)
if err != nil {
if customErr, ok := err.(*customerror.Error404); ok {
c.JSON(customErr.ErrorCode, customErr.Body)
Expand Down
2 changes: 1 addition & 1 deletion dynamodb/dynamodb_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ var DynamoDbClient *dynamodb.Client
func InitDynamoDbClient() {
logger.Logger.Debug("Init DynamoDB client.")

cfg, err := config.LoadDefaultConfig(context.TODO())
cfg, err := config.LoadDefaultConfig(context.Background())
if err != nil {
logger.Logger.Error("Load aws config error.", logger.ErrAttr(err))
panic("Failed to start anytore.")
Expand Down
32 changes: 16 additions & 16 deletions service/training_item.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/openkrafter/anytore-backend/model"
)

func GetTraningItems(userId int) ([]*model.TrainingItem, error) {
func GetTraningItems(ctx context.Context, userId int) ([]*model.TrainingItem, error) {
basics, err := NewTableBasics("TrainingItem")
if err != nil {
logger.Logger.Error("DynamoDB client init error.", logger.ErrAttr(err))
Expand Down Expand Up @@ -52,7 +52,7 @@ func GetTraningItems(userId int) ([]*model.TrainingItem, error) {
IndexName: aws.String("UserIdIndex"),
}

response, err := basics.DynamoDbClient.Query(context.TODO(), queryInput)
response, err := basics.DynamoDbClient.Query(ctx, queryInput)
if err != nil {
logger.Logger.Error("Failed to get TrainingItems.", logger.ErrAttr(err))
return nil, err
Expand All @@ -78,7 +78,7 @@ func GetTraningItems(userId int) ([]*model.TrainingItem, error) {
return trainingItems, nil
}

func GetTraningItem(id int, userId int) (*model.TrainingItem, error) {
func GetTraningItem(ctx context.Context, id int, userId int) (*model.TrainingItem, error) {
logger.Logger.Debug("GetTraningItem process", slog.Int("id", id))

logger.Logger.Debug("Init DynamoDB client.")
Expand Down Expand Up @@ -120,7 +120,7 @@ func GetTraningItem(id int, userId int) (*model.TrainingItem, error) {
IndexName: aws.String("UserIdIndex"),
}

response, err := basics.DynamoDbClient.Query(context.TODO(), queryInput)
response, err := basics.DynamoDbClient.Query(ctx, queryInput)
if err != nil {
logger.Logger.Error("Failed to get TrainingItems.", logger.ErrAttr(err))
return nil, err
Expand All @@ -143,14 +143,14 @@ func GetTraningItem(id int, userId int) (*model.TrainingItem, error) {
return trainingItem, nil
}

func CreateTraningItem(trainingItem *model.TrainingItem) error {
func CreateTraningItem(ctx context.Context, trainingItem *model.TrainingItem) error {
basics, err := NewTableBasics("TrainingItem")
if err != nil {
logger.Logger.Error("DynamoDB client init error.", logger.ErrAttr(err))
return err
}

trainingItem.Id, err = getIncrementId()
trainingItem.Id, err = getIncrementId(ctx)
if err != nil {
logger.Logger.Error("getIncrementId Failed.", logger.ErrAttr(err))
return err
Expand All @@ -160,7 +160,7 @@ func CreateTraningItem(trainingItem *model.TrainingItem) error {
if err != nil {
return err
}
_, err = basics.DynamoDbClient.PutItem(context.TODO(), &dynamodb.PutItemInput{
_, err = basics.DynamoDbClient.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(basics.TableName),
Item: av,
})
Expand All @@ -171,14 +171,14 @@ func CreateTraningItem(trainingItem *model.TrainingItem) error {
return nil
}

func UpdateTraningItem(trainingItem *model.TrainingItem, userId int) error {
func UpdateTraningItem(ctx context.Context, trainingItem *model.TrainingItem, userId int) error {
basics, err := NewTableBasics("TrainingItem")
if err != nil {
logger.Logger.Error("DynamoDB client init error.", logger.ErrAttr(err))
return err
}

getTraningItemResult, err := GetTraningItem(trainingItem.Id, userId)
getTraningItemResult, err := GetTraningItem(ctx, trainingItem.Id, userId)
if err != nil {
return err
}
Expand All @@ -191,7 +191,7 @@ func UpdateTraningItem(trainingItem *model.TrainingItem, userId int) error {
if err != nil {
return err
}
_, err = basics.DynamoDbClient.PutItem(context.TODO(), &dynamodb.PutItemInput{
_, err = basics.DynamoDbClient.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(basics.TableName),
Item: av,
})
Expand All @@ -202,14 +202,14 @@ func UpdateTraningItem(trainingItem *model.TrainingItem, userId int) error {
return nil
}

func DeleteTraningItem(id int, userId int) error {
func DeleteTraningItem(ctx context.Context, id int, userId int) error {
basics, err := NewTableBasics("TrainingItem")
if err != nil {
logger.Logger.Error("DynamoDB client init error.", logger.ErrAttr(err))
return err
}

getTraningItemResult, err := GetTraningItem(id, userId)
getTraningItemResult, err := GetTraningItem(ctx, id, userId)
if err != nil {
return err
}
Expand All @@ -227,15 +227,15 @@ func DeleteTraningItem(id int, userId int) error {
TableName: aws.String(basics.TableName),
}

_, err = basics.DynamoDbClient.DeleteItem(context.TODO(), deleteInput)
_, err = basics.DynamoDbClient.DeleteItem(ctx, deleteInput)
if err != nil {
return err
}

return nil
}

func getIncrementId() (int, error) {
func getIncrementId(ctx context.Context) (int, error) {
basics, err := NewTableBasics("TrainingItemCounter")
if err != nil {
logger.Logger.Error("DynamoDB client init error.", logger.ErrAttr(err))
Expand All @@ -261,7 +261,7 @@ func getIncrementId() (int, error) {
},
}

result, err := basics.DynamoDbClient.UpdateItem(context.TODO(), &dynamodb.UpdateItemInput{
result, err := basics.DynamoDbClient.UpdateItem(ctx, &dynamodb.UpdateItemInput{
TableName: aws.String("TrainingItemCounter"),
Key: countKey,
UpdateExpression: aws.String(updateExpression),
Expand All @@ -274,7 +274,7 @@ func getIncrementId() (int, error) {
if ok := errors.As(err, &apiErr); ok {
if apiErr.ErrorCode() == "ConditionalCheckFailedException" {
logger.Logger.Info("No item in TrainingItemCounter table, put initial item.", logger.ErrAttr(err))
_, err = basics.DynamoDbClient.PutItem(context.TODO(), &dynamodb.PutItemInput{
_, err = basics.DynamoDbClient.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String("TrainingItemCounter"),
Item: map[string]types.AttributeValue{
"CountKey": &types.AttributeValueMemberS{Value: "key"},
Expand Down
9 changes: 5 additions & 4 deletions service/training_item_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package service

import (
"context"
"log"
"reflect"
"sort"
Expand Down Expand Up @@ -79,7 +80,7 @@ func TestGetTraningItems(t *testing.T) {
}
defer testenvironment.TeardownTraningItemTestData()

got, err := GetTraningItems(tt.args.userId)
got, err := GetTraningItems(context.Background(), tt.args.userId)
if (err != nil) != tt.wantErr {
t.Errorf("GetTraningItems() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -155,7 +156,7 @@ func TestGetTraningItem(t *testing.T) {
}
defer testenvironment.TeardownTraningItemTestData()

got, err := GetTraningItem(tt.args.id, tt.args.userId)
got, err := GetTraningItem(context.Background(), tt.args.id, tt.args.userId)
if (err != nil) != tt.wantErr {
t.Errorf("GetTraningItem() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -218,7 +219,7 @@ func TestUpdateTraningItem(t *testing.T) {
}
defer testenvironment.TeardownTraningItemTestData()

if err := UpdateTraningItem(tt.args.input, tt.args.userId); (err != nil) != tt.wantErr {
if err := UpdateTraningItem(context.Background(), tt.args.input, tt.args.userId); (err != nil) != tt.wantErr {
t.Errorf("UpdateTraningItem() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down Expand Up @@ -266,7 +267,7 @@ func TestDeleteTraningItem(t *testing.T) {
}
defer testenvironment.TeardownTraningItemTestData()

if err := DeleteTraningItem(tt.args.id, tt.args.userId); (err != nil) != tt.wantErr {
if err := DeleteTraningItem(context.Background(), tt.args.id, tt.args.userId); (err != nil) != tt.wantErr {
t.Errorf("DeleteTraningItem() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down
6 changes: 3 additions & 3 deletions test/environment/dynamodb_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (*resolverV2) ResolveEndpoint(ctx context.Context, params dynamodb.Endpoint
}

func SetupDynamoDbClient() error {
cfg, err := config.LoadDefaultConfig(context.TODO())
cfg, err := config.LoadDefaultConfig(context.Background())

if err != nil {
logger.Logger.Error("Load aws config error.", logger.ErrAttr(err))
Expand Down Expand Up @@ -96,7 +96,7 @@ func deleteAllItems(tableName string, keyName string) {
params := &dynamodb.ScanInput{
TableName: aws.String(tableName),
}
deleteItems, err := anytoreDynamodb.DynamoDbClient.Scan(context.TODO(), params)
deleteItems, err := anytoreDynamodb.DynamoDbClient.Scan(context.Background(), params)
if err != nil {
logger.Logger.Error("Failed to scan items.", logger.ErrAttr(err))
return
Expand All @@ -111,7 +111,7 @@ func deleteAllItems(tableName string, keyName string) {
},
}

_, err := anytoreDynamodb.DynamoDbClient.DeleteItem(context.TODO(), deleteParams)
_, err := anytoreDynamodb.DynamoDbClient.DeleteItem(context.Background(), deleteParams)
if err != nil {
logger.Logger.Error("Failed to delete item.", logger.ErrAttr(err))
return
Expand Down

0 comments on commit 57b5c0e

Please sign in to comment.