Skip to content

Commit

Permalink
Implement GetTaskListSize
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaddoll committed Nov 13, 2023
1 parent 6b85fa1 commit c070128
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 3 deletions.
2 changes: 1 addition & 1 deletion common/persistence/dataManagerInterfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -1279,7 +1279,7 @@ type (
}

GetTaskListSizeResponse struct {
Size int
Size int64
}

// CreateTasksRequest is used to create a new task for a workflow exectution
Expand Down
1 change: 1 addition & 0 deletions common/persistence/dataStoreInterfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type (
UpdateTaskList(ctx context.Context, request *UpdateTaskListRequest) (*UpdateTaskListResponse, error)
ListTaskList(ctx context.Context, request *ListTaskListRequest) (*ListTaskListResponse, error)
DeleteTaskList(ctx context.Context, request *DeleteTaskListRequest) error
GetTaskListSize(ctx context.Context, request *GetTaskListSizeRequest) (*GetTaskListSizeResponse, error)
CreateTasks(ctx context.Context, request *InternalCreateTasksRequest) (*CreateTasksResponse, error)
GetTasks(ctx context.Context, request *GetTasksRequest) (*InternalGetTasksResponse, error)
CompleteTask(ctx context.Context, request *CompleteTaskRequest) error
Expand Down
20 changes: 20 additions & 0 deletions common/persistence/nosql/nosqlTaskStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/uber/cadence/common"
"github.com/uber/cadence/common/config"
"github.com/uber/cadence/common/log"
"github.com/uber/cadence/common/persistence"
p "github.com/uber/cadence/common/persistence"
"github.com/uber/cadence/common/persistence/nosql/nosqlplugin"
"github.com/uber/cadence/common/types"
Expand Down Expand Up @@ -71,6 +72,25 @@ func (t *nosqlTaskStore) GetOrphanTasks(ctx context.Context, request *p.GetOrpha
}
}

func (t *nosqlTaskStore) GetTaskListSize(ctx context.Context, request *p.GetTaskListSizeRequest) (*p.GetTaskListSizeResponse, error) {
storeShard, err := t.GetStoreShardByTaskList(request.DomainID, request.TaskListName, request.TaskListType)
if err != nil {
return nil, err
}
size, err := storeShard.db.GetTasksCount(ctx, &nosqlplugin.TasksFilter{
TaskListFilter: nosqlplugin.TaskListFilter{
DomainID: request.DomainID,
TaskListName: request.TaskListName,
TaskListType: request.TaskListType,
},
MinTaskID: request.AckLevel,
})
if err != nil {
return nil, err
}
return &persistence.GetTaskListSizeResponse{Size: size}, nil
}

func (t *nosqlTaskStore) LeaseTaskList(
ctx context.Context,
request *p.LeaseTaskListRequest,
Expand Down
26 changes: 26 additions & 0 deletions common/persistence/nosql/nosqlplugin/cassandra/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ const (
`and task_id > ? ` +
`and task_id <= ?`

templateGetTasksCountQuery = `SELECT count(1) as count ` +
`FROM tasks ` +
`WHERE domain_id = ? ` +
`and task_list_name = ? ` +
`and task_list_type = ? ` +
`and type = ? ` +
`and task_id > ? `

templateCompleteTasksLessThanQuery = `DELETE FROM tasks ` +
`WHERE domain_id = ? ` +
`AND task_list_name = ? ` +
Expand Down Expand Up @@ -408,6 +416,24 @@ func (db *cdb) InsertTasks(
return handleTaskListAppliedError(applied, previous)
}

// GetTasksCount returns number of tasks from a tasklist
func (db *cdb) GetTasksCount(ctx context.Context, filter *nosqlplugin.TasksFilter) (int64, error) {
query := db.session.Query(templateGetTasksCountQuery,
filter.DomainID,
filter.TaskListName,
filter.TaskListType,
rowTypeTask,
filter.MinTaskID,
).WithContext(ctx)
result := make(map[string]interface{})
if err := query.MapScan(result); err != nil {
return 0, err
}

queueSize := result["count"].(int64)
return queueSize, nil
}

// SelectTasks return tasks that associated to a tasklist
func (db *cdb) SelectTasks(ctx context.Context, filter *nosqlplugin.TasksFilter) ([]*nosqlplugin.TaskRow, error) {
// Reading tasklist tasks need to be quorum level consistent, otherwise we could loose task
Expand Down
2 changes: 2 additions & 0 deletions common/persistence/nosql/nosqlplugin/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ type (
// DeleteTask delete a batch of tasks
// Also return the number of rows deleted -- if it's not supported then ignore the batchSize, and return persistence.UnknownNumRowsAffected
RangeDeleteTasks(ctx context.Context, filter *TasksFilter) (rowsDeleted int, err error)
// GetTasksCount return the number of tasks
GetTasksCount(ctx context.Context, filter *TasksFilter) (int64, error)
}

/**
Expand Down
45 changes: 45 additions & 0 deletions common/persistence/nosql/nosqlplugin/interfaces_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions common/persistence/persistence-tests/matchingPersistenceTest.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,32 @@ func (s *MatchingPersistenceSuite) TestGetDecisionTasks() {
s.Equal(partitionConfig, tasks1Response.Tasks[0].PartitionConfig)
}

func (s *MatchingPersistenceSuite) TestGetTaskListSize() {
ctx, cancel := context.WithTimeout(context.Background(), testContextTimeout)
defer cancel()

domainID := uuid.New()
workflowExecution := types.WorkflowExecution{WorkflowID: "get-decision-task-test",
RunID: "db20f7e2-1a1e-40d9-9278-d8b886738e05"}
taskList := "d8b886738e05"
partitionConfig := map[string]string{"userid": uuid.New()}

size, err1 := s.GetDecisionTaskListSize(ctx, domainID, taskList, 0)
s.NoError(err1)
s.Equal(int64(0), size)

task0, err0 := s.CreateDecisionTask(ctx, domainID, workflowExecution, taskList, 5, partitionConfig)
s.NoError(err0)

size, err1 = s.GetDecisionTaskListSize(ctx, domainID, taskList, task0)
s.NoError(err1)
s.Equal(int64(0), size)

size, err1 = s.GetDecisionTaskListSize(ctx, domainID, taskList, task0-1)
s.NoError(err1)
s.Equal(int64(1), size)
}

// TestGetTasksWithNoMaxReadLevel test
func (s *MatchingPersistenceSuite) TestGetTasksWithNoMaxReadLevel() {
ctx, cancel := context.WithTimeout(context.Background(), testContextTimeout)
Expand Down
13 changes: 13 additions & 0 deletions common/persistence/persistence-tests/persistenceTestBase.go
Original file line number Diff line number Diff line change
Expand Up @@ -1691,6 +1691,19 @@ func (s *TestBase) CreateDecisionTask(ctx context.Context, domainID string, work
return taskID, err
}

func (s *TestBase) GetDecisionTaskListSize(ctx context.Context, domainID, taskList string, ackLevel int64) (int64, error) {
resp, err := s.TaskMgr.GetTaskListSize(ctx, &persistence.GetTaskListSizeRequest{
DomainID: domainID,
TaskListName: taskList,
TaskListType: persistence.TaskListTypeDecision,
AckLevel: ackLevel,
})
if err != nil {
return 0, err
}
return resp.Size, nil
}

// CreateActivityTasks is a utility method to create tasks
func (s *TestBase) CreateActivityTasks(ctx context.Context, domainID string, workflowExecution types.WorkflowExecution,
activities map[int64]string, partitionConfig map[string]string) ([]int64, error) {
Expand Down
16 changes: 16 additions & 0 deletions common/persistence/sql/sqlTaskStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ func newTaskPersistence(
}, nil
}

func (m *sqlTaskStore) GetTaskListSize(ctx context.Context, request *persistence.GetTaskListSizeRequest) (*persistence.GetTaskListSizeResponse, error) {
dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskListName, m.db.GetTotalNumDBShards())
domainID := serialization.MustParseUUID(request.DomainID)
size, err := m.db.GetTasksCount(ctx, &sqlplugin.TasksFilter{
ShardID: dbShardID,
DomainID: domainID,
TaskListName: request.TaskListName,
TaskType: int64(request.TaskListType),
MinTaskID: &request.AckLevel,
})
if err != nil {
return nil, err
}
return &persistence.GetTaskListSizeResponse{Size: size}, nil
}

func (m *sqlTaskStore) LeaseTaskList(
ctx context.Context,
request *persistence.LeaseTaskListRequest,
Expand Down
1 change: 1 addition & 0 deletions common/persistence/sql/sqlplugin/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ type (
// - {domainID, tasklistName, taskType, taskIDLessThanEquals, limit }
// - this will delete up to limit number of tasks less than or equal to the given task id
DeleteFromTasks(ctx context.Context, filter *TasksFilter) (sql.Result, error)
GetTasksCount(ctx context.Context, filter *TasksFilter) (int64, error)
GetOrphanTasks(ctx context.Context, filter *OrphanTasksFilter) ([]TaskKeyRow, error)

InsertIntoTaskLists(ctx context.Context, row *TaskListsRow) (sql.Result, error)
Expand Down
12 changes: 12 additions & 0 deletions common/persistence/sql/sqlplugin/mysql/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ task_type = :task_type
`FROM tasks ` +
`WHERE domain_id = ? AND task_list_name = ? AND task_type = ? AND task_id > ? ORDER BY task_id LIMIT ?`

getTasksCountQry = `SELECT count(1) as count ` +
`FROM tasks ` +
`WHERE domain_id = ? AND task_list_name = ? AND task_type = ? AND task_id > ?`

createTaskQry = `INSERT INTO ` +
`tasks(domain_id, task_list_name, task_type, task_id, data, data_encoding) ` +
`VALUES(:domain_id, :task_list_name, :task_type, :task_id, :data, :data_encoding)`
Expand Down Expand Up @@ -209,6 +213,14 @@ func (mdb *db) LockTaskLists(ctx context.Context, filter *sqlplugin.TaskListsFil
return rangeID, err
}

func (mdb *db) GetTasksCount(ctx context.Context, filter *sqlplugin.TasksFilter) (int64, error) {
var size []int64
if err := mdb.driver.SelectContext(ctx, filter.ShardID, &size, getTasksCountQry, filter.DomainID, filter.TaskListName, filter.TaskType, *filter.MinTaskID); err != nil {
return 0, err
}
return size[0], nil
}

// InsertIntoTasksWithTTL is not supported in MySQL
func (mdb *db) InsertIntoTasksWithTTL(_ context.Context, _ []sqlplugin.TasksRowWithTTL) (sql.Result, error) {
return nil, sqlplugin.ErrTTLNotSupported
Expand Down
12 changes: 12 additions & 0 deletions common/persistence/sql/sqlplugin/postgres/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ task_type = :task_type
`FROM tasks ` +
`WHERE domain_id = $1 AND task_list_name = $2 AND task_type = $3 AND task_id > $4 ORDER BY task_id LIMIT $5`

getTasksCountQry = `SELECT count(1) as count ` +
`FROM tasks ` +
`WHERE domain_id = $1 AND task_list_name = $2 AND task_type = $3 AND task_id > $4`

createTaskQry = `INSERT INTO ` +
`tasks(domain_id, task_list_name, task_type, task_id, data, data_encoding) ` +
`VALUES(:domain_id, :task_list_name, :task_type, :task_id, :data, :data_encoding)`
Expand Down Expand Up @@ -139,6 +143,14 @@ func (pdb *db) DeleteFromTasks(ctx context.Context, filter *sqlplugin.TasksFilte
return pdb.driver.ExecContext(ctx, filter.ShardID, deleteTaskQry, filter.DomainID, filter.TaskListName, filter.TaskType, *filter.TaskID)
}

func (pdb *db) GetTasksCount(ctx context.Context, filter *sqlplugin.TasksFilter) (int64, error) {
var size []int64
if err := pdb.driver.SelectContext(ctx, filter.ShardID, &size, getTasksCountQry, filter.DomainID, filter.TaskListName, filter.TaskType, *filter.MinTaskID); err != nil {
return 0, err
}
return size[0], nil
}

func (pdb *db) GetOrphanTasks(ctx context.Context, filter *sqlplugin.OrphanTasksFilter) ([]sqlplugin.TaskKeyRow, error) {
if filter.Limit == nil || *filter.Limit == 0 {
return nil, fmt.Errorf("missing limit parameter")
Expand Down
3 changes: 1 addition & 2 deletions common/persistence/taskManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"context"

"github.com/uber/cadence/common"
"github.com/uber/cadence/common/types"
)

type (
Expand Down Expand Up @@ -71,7 +70,7 @@ func (t *taskManager) DeleteTaskList(ctx context.Context, request *DeleteTaskLis
}

func (t *taskManager) GetTaskListSize(ctx context.Context, request *GetTaskListSizeRequest) (*GetTaskListSizeResponse, error) {
return nil, &types.InternalServiceError{Message: "Not yet implemented"}
return t.persistence.GetTaskListSize(ctx, request)
}

func (t *taskManager) CreateTasks(ctx context.Context, request *CreateTasksRequest) (*CreateTasksResponse, error) {
Expand Down

0 comments on commit c070128

Please sign in to comment.