Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor flyteadmin to pass proto structs as pointers #5717

Merged
merged 4 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions flyteadmin/.golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@ linters-settings:
- prefix(github.com/flyteorg)
skip-generated: true
issues:
exclude:
- copylocks
exclude-rules:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine grained ignore

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does this still need to be excluded?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/flyteorg/flyte/pull/5717/files/0f75532e95adab4fb43c5ce380a9b11808eced30#diff-68d7b4389844bb5dec320a0854bdb20122c9fcb4356c32dce5dff80a55ee486eR33-R34

The code copies the a security context to the flyte workflow custom resource, which stores it as a non-pointer so I think its inevitable unless I write some conversion method to traverse the security context and reconstruct a non-pointer version. I felt it was probably less error prone and more future proof to do a deep copy and ignore the warning.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for explaining, leaving this as an exception makes sense!

- path: pkg/workflowengine/impl/prepare_execution.go
text: "copies lock"
32 changes: 16 additions & 16 deletions flyteadmin/dataproxy/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
// Lookup task, node, workflow execution
var nativeURL string
if nodeExecutionIDEnvelope, casted := req.GetSource().(*service.CreateDownloadLinkRequest_NodeExecutionId); casted {
node, err := s.nodeExecutionManager.GetNodeExecution(ctx, admin.NodeExecutionGetRequest{
node, err := s.nodeExecutionManager.GetNodeExecution(ctx, &admin.NodeExecutionGetRequest{
Id: nodeExecutionIDEnvelope.NodeExecutionId,
})

Expand Down Expand Up @@ -309,9 +309,9 @@

// GetCompleteTaskExecutionID returns the task execution identifier for the task execution with the Task ID filled in.
// The one coming from the node execution doesn't have this as this is not data encapsulated in the flyte url.
func (s Service) GetCompleteTaskExecutionID(ctx context.Context, taskExecID core.TaskExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {
func (s Service) GetCompleteTaskExecutionID(ctx context.Context, taskExecID *core.TaskExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {

taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, admin.TaskExecutionListRequest{
taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, &admin.TaskExecutionListRequest{
NodeExecutionId: taskExecID.GetNodeExecutionId(),
Limit: 1,
Filters: fmt.Sprintf("eq(retry_attempt,%s)", strconv.Itoa(int(taskExecID.RetryAttempt))),
Expand All @@ -326,9 +326,9 @@
return taskExec.Id, nil
}

func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID core.NodeExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {
taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, admin.TaskExecutionListRequest{
NodeExecutionId: &nodeExecID,
func (s Service) GetTaskExecutionID(ctx context.Context, attempt int, nodeExecID *core.NodeExecutionIdentifier) (*core.TaskExecutionIdentifier, error) {
taskExecs, err := s.taskExecutionManager.ListTaskExecutions(ctx, &admin.TaskExecutionListRequest{
NodeExecutionId: nodeExecID,
Limit: 1,
Filters: fmt.Sprintf("eq(retry_attempt,%s)", strconv.Itoa(attempt)),
})
Expand All @@ -342,11 +342,11 @@
return taskExec.Id, nil
}

func (s Service) GetDataFromNodeExecution(ctx context.Context, nodeExecID core.NodeExecutionIdentifier, ioType common.ArtifactType, name string) (
func (s Service) GetDataFromNodeExecution(ctx context.Context, nodeExecID *core.NodeExecutionIdentifier, ioType common.ArtifactType, name string) (
*service.GetDataResponse, error) {

resp, err := s.nodeExecutionManager.GetNodeExecutionData(ctx, admin.NodeExecutionGetDataRequest{
Id: &nodeExecID,
resp, err := s.nodeExecutionManager.GetNodeExecutionData(ctx, &admin.NodeExecutionGetDataRequest{
Id: nodeExecID,
})
if err != nil {
return nil, err
Expand All @@ -361,7 +361,7 @@
// Assume deck, and create a download link request
dlRequest := service.CreateDownloadLinkRequest{
ArtifactType: service.ArtifactType_ARTIFACT_TYPE_DECK,
Source: &service.CreateDownloadLinkRequest_NodeExecutionId{NodeExecutionId: &nodeExecID},
Source: &service.CreateDownloadLinkRequest_NodeExecutionId{NodeExecutionId: nodeExecID},

Check warning on line 364 in flyteadmin/dataproxy/service.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/dataproxy/service.go#L364

Added line #L364 was not covered by tests
}
resp, err := s.CreateDownloadLink(ctx, &dlRequest)
if err != nil {
Expand Down Expand Up @@ -391,12 +391,12 @@
}, nil
}

func (s Service) GetDataFromTaskExecution(ctx context.Context, taskExecID core.TaskExecutionIdentifier, ioType common.ArtifactType, name string) (
func (s Service) GetDataFromTaskExecution(ctx context.Context, taskExecID *core.TaskExecutionIdentifier, ioType common.ArtifactType, name string) (
*service.GetDataResponse, error) {

var lm *core.LiteralMap
reqT := admin.TaskExecutionGetDataRequest{
Id: &taskExecID,
reqT := &admin.TaskExecutionGetDataRequest{
Id: taskExecID,
}
resp, err := s.taskExecutionManager.GetTaskExecutionData(ctx, reqT)
if err != nil {
Expand Down Expand Up @@ -445,13 +445,13 @@
}

if execution.NodeExecID != nil {
return s.GetDataFromNodeExecution(ctx, *execution.NodeExecID, execution.IOType, execution.LiteralName)
return s.GetDataFromNodeExecution(ctx, execution.NodeExecID, execution.IOType, execution.LiteralName)
} else if execution.PartialTaskExecID != nil {
taskExecID, err := s.GetCompleteTaskExecutionID(ctx, *execution.PartialTaskExecID)
taskExecID, err := s.GetCompleteTaskExecutionID(ctx, execution.PartialTaskExecID)
if err != nil {
return nil, err
}
return s.GetDataFromTaskExecution(ctx, *taskExecID, execution.IOType, execution.LiteralName)
return s.GetDataFromTaskExecution(ctx, taskExecID, execution.IOType, execution.LiteralName)
}

return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "failed to parse get data request %v", req)
Expand Down
16 changes: 8 additions & 8 deletions flyteadmin/dataproxy/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func TestCreateUploadLocationMore(t *testing.T) {
func TestCreateDownloadLink(t *testing.T) {
dataStore := commonMocks.GetMockStorageClient()
nodeExecutionManager := &mocks.MockNodeExecutionManager{}
nodeExecutionManager.SetGetNodeExecutionFunc(func(ctx context.Context, request admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) {
nodeExecutionManager.SetGetNodeExecutionFunc(func(ctx context.Context, request *admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) {
return &admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
DeckUri: "s3://something/something",
Expand Down Expand Up @@ -282,14 +282,14 @@ func TestService_GetData(t *testing.T) {
}

nodeExecutionManager.SetGetNodeExecutionDataFunc(
func(ctx context.Context, request admin.NodeExecutionGetDataRequest) (*admin.NodeExecutionGetDataResponse, error) {
func(ctx context.Context, request *admin.NodeExecutionGetDataRequest) (*admin.NodeExecutionGetDataResponse, error) {
return &admin.NodeExecutionGetDataResponse{
FullInputs: inputsLM,
FullOutputs: outputsLM,
}, nil
},
)
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return &admin.TaskExecutionList{
TaskExecutions: []*admin.TaskExecution{
{
Expand All @@ -315,7 +315,7 @@ func TestService_GetData(t *testing.T) {
},
}, nil
})
taskExecutionManager.SetGetTaskExecutionDataCallback(func(ctx context.Context, request admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
taskExecutionManager.SetGetTaskExecutionDataCallback(func(ctx context.Context, request *admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) {
return &admin.TaskExecutionGetDataResponse{
FullInputs: inputsLM,
FullOutputs: outputsLM,
Expand Down Expand Up @@ -388,10 +388,10 @@ func TestService_Error(t *testing.T) {
assert.NoError(t, err)

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return nil, errors.NewFlyteAdminErrorf(1, "not found")
})
nodeExecID := core.NodeExecutionIdentifier{
nodeExecID := &core.NodeExecutionIdentifier{
NodeId: "n0",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "proj",
Expand All @@ -404,13 +404,13 @@ func TestService_Error(t *testing.T) {
})

t.Run("get a working set of urls without retry attempt", func(t *testing.T) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
taskExecutionManager.SetListTaskExecutionsCallback(func(ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) {
return &admin.TaskExecutionList{
TaskExecutions: nil,
Token: "",
}, nil
})
nodeExecID := core.NodeExecutionIdentifier{
nodeExecID := &core.NodeExecutionIdentifier{
NodeId: "n0",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "proj",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@

// This is a rough copy of the ListTaskExecutions function in TaskExecutionManager. It can be deprecated once we move the processing out of Admin itself.
// Just return the highest retry attempt.
func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID core.NodeExecutionIdentifier) (*admin.TaskExecution, error) {
ctx = getNodeExecutionContext(ctx, &nodeExecutionID)
func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID *core.NodeExecutionIdentifier) (*admin.TaskExecution, error) {
ctx = getNodeExecutionContext(ctx, nodeExecutionID)

Check warning on line 208 in flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go#L207-L208

Added lines #L207 - L208 were not covered by tests

identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, nodeExecutionID)
if err != nil {
Expand Down Expand Up @@ -283,7 +283,7 @@
var taskExecID *core.TaskExecutionIdentifier
var typedInterface *core.TypedInterface

lte, err := c.getLatestTaskExecutions(ctx, *rawEvent.Id)
lte, err := c.getLatestTaskExecutions(ctx, rawEvent.Id)

Check warning on line 286 in flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go#L286

Added line #L286 was not covered by tests
if err != nil {
logger.Errorf(ctx, "failed to get latest task execution for node exec id [%+v] with err: %v", rawEvent.Id, err)
return nil, err
Expand Down Expand Up @@ -353,7 +353,7 @@
phase = e.Phase.String()
eventTime = e.OccurredAt.AsTime()

dummyNodeExecutionID := core.NodeExecutionIdentifier{
dummyNodeExecutionID := &core.NodeExecutionIdentifier{

Check warning on line 356 in flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go#L356

Added line #L356 was not covered by tests
NodeId: "end-node",
ExecutionId: e.ExecutionId,
}
Expand All @@ -378,7 +378,7 @@
if e.ParentNodeExecutionId == nil {
return fmt.Errorf("parent node execution id is nil for task execution [%+v]", e)
}
eventSource = common.FlyteURLKeyFromNodeExecutionIDRetry(*e.ParentNodeExecutionId,
eventSource = common.FlyteURLKeyFromNodeExecutionIDRetry(e.ParentNodeExecutionId,

Check warning on line 381 in flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go#L381

Added line #L381 was not covered by tests
int(e.RetryAttempt))
finalMsg, err = c.TransformTaskExecutionEvent(ctx, e)
if err != nil {
Expand All @@ -392,7 +392,7 @@
phase = e.Phase.String()
eventTime = e.OccurredAt.AsTime()
eventID = fmt.Sprintf("%v.%v", executionID, phase)
eventSource = common.FlyteURLKeyFromNodeExecutionID(*msgType.Event.Id)
eventSource = common.FlyteURLKeyFromNodeExecutionID(msgType.Event.Id)

Check warning on line 395 in flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go#L395

Added line #L395 was not covered by tests
finalMsg, err = c.TransformNodeExecutionEvent(ctx, e)
if err != nil {
logger.Errorf(ctx, "Failed to transform node execution event with error: %v", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
// events, node execution processing doesn't have to wait on these to be committed.
type nodeExecutionEventWriter struct {
db repositoryInterfaces.Repository
events chan admin.NodeExecutionEventRequest
events chan *admin.NodeExecutionEventRequest
}

func (w *nodeExecutionEventWriter) Write(event admin.NodeExecutionEventRequest) {
func (w *nodeExecutionEventWriter) Write(event *admin.NodeExecutionEventRequest) {
w.events <- event
}

Expand All @@ -40,6 +40,6 @@ func (w *nodeExecutionEventWriter) Run() {
func NewNodeExecutionEventWriter(db repositoryInterfaces.Repository, bufferSize int) interfaces.NodeExecutionEventWriter {
return &nodeExecutionEventWriter{
db: db,
events: make(chan admin.NodeExecutionEventRequest, bufferSize),
events: make(chan *admin.NodeExecutionEventRequest, bufferSize),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func TestNodeExecutionEventWriter(t *testing.T) {
db := mocks.NewMockRepository()

event := admin.NodeExecutionEventRequest{
event := &admin.NodeExecutionEventRequest{
RequestId: "request_id",
Event: &event2.NodeExecutionEvent{
Id: &core.NodeExecutionIdentifier{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
// events, workflow execution processing doesn't have to wait on these to be committed.
type workflowExecutionEventWriter struct {
db repositoryInterfaces.Repository
events chan admin.WorkflowExecutionEventRequest
events chan *admin.WorkflowExecutionEventRequest
}

func (w *workflowExecutionEventWriter) Write(event admin.WorkflowExecutionEventRequest) {
func (w *workflowExecutionEventWriter) Write(event *admin.WorkflowExecutionEventRequest) {
w.events <- event
}

Expand All @@ -40,6 +40,6 @@ func (w *workflowExecutionEventWriter) Run() {
func NewWorkflowExecutionEventWriter(db repositoryInterfaces.Repository, bufferSize int) interfaces.WorkflowExecutionEventWriter {
return &workflowExecutionEventWriter{
db: db,
events: make(chan admin.WorkflowExecutionEventRequest, bufferSize),
events: make(chan *admin.WorkflowExecutionEventRequest, bufferSize),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func TestWorkflowExecutionEventWriter(t *testing.T) {
db := mocks.NewMockRepository()

event := admin.WorkflowExecutionEventRequest{
event := &admin.WorkflowExecutionEventRequest{
RequestId: "request_id",
Event: &event2.WorkflowExecutionEvent{
ExecutionId: &core.WorkflowExecutionIdentifier{
Expand Down
2 changes: 1 addition & 1 deletion flyteadmin/pkg/async/events/interfaces/node_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ import (

type NodeExecutionEventWriter interface {
Run()
Write(nodeExecutionEvent admin.NodeExecutionEventRequest)
Write(nodeExecutionEvent *admin.NodeExecutionEventRequest)
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ import (

type WorkflowExecutionEventWriter interface {
Run()
Write(workflowExecutionEvent admin.WorkflowExecutionEventRequest)
Write(workflowExecutionEvent *admin.WorkflowExecutionEventRequest)
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading