Skip to content

Commit

Permalink
[Artifacts] Use new node event information (#320)
Browse files Browse the repository at this point in the history
This is the second of two PRs to support subworkflow artifact creation.  The first is an [event change](unionai/flyte#315), which was needed to get additional information added to the event.  This one uses those new events so that the artifacts event handler knows what to do with subworkflow events.

[Cloud PR](https://github.com/unionai/cloud/pull/7930/files) to pull this in.

Changes:
* The CloudEvent event publisher amends events with additional information (information used by the Artifacts event processor).  But we skip certain nodes.  When skipping "start-node"s and "end-node"s, now also check the `SpecNodeId`.
* Refactor the `TransformNodeExecutionEvent` function (which does the aforementioned information-adding) to get the additional information in different functions. 
* Change pointers/values in one place to remove a lint error.
* Add an event publisher for local testing.  This is the `DebugOnlyDiskEventPublisher`.  Given its name, I think it's okay to commit this but I can always delete it and keep it elsewhere too.

Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Jun 13, 2024
1 parent 5ec9fe3 commit 2616d4c
Show file tree
Hide file tree
Showing 2 changed files with 271 additions and 25 deletions.
131 changes: 106 additions & 25 deletions flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,10 @@ func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecution

// This is a rough copy of the ListTaskExecutions function in TaskExecutionManager. It can be deprecated once we move the processing out of Admin itself.
// Just return the highest retry attempt.
func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID core.NodeExecutionIdentifier) (*admin.TaskExecution, error) {
ctx = getNodeExecutionContext(ctx, &nodeExecutionID)
func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID *core.NodeExecutionIdentifier) (*admin.TaskExecution, error) {
ctx = getNodeExecutionContext(ctx, nodeExecutionID)

identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, nodeExecutionID)
identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, *nodeExecutionID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -251,17 +251,27 @@ func (c *CloudEventWrappedPublisher) TransformNodeExecutionEvent(ctx context.Con
return nil, fmt.Errorf("nothing to publish, NodeExecution event or ID is nil")
}

// Skip nodes unless they're succeeded and not start nodes
// Don't bother trying to add additional information to events unless it's useful to
var skip = false
if rawEvent.Phase != core.NodeExecution_SUCCEEDED {
return &event.CloudEventNodeExecution{
RawEvent: rawEvent,
}, nil
} else if rawEvent.Id.NodeId == "start-node" {
// Skip non-succeeded nodes
logger.Debugf(ctx, "skipping non-succeeded event [%+v]", rawEvent.Id)
skip = true
} else if rawEvent.Id.NodeId == "start-node" || rawEvent.SpecNodeId == "start-node" ||
rawEvent.Id.NodeId == "end-node" || rawEvent.SpecNodeId == "end-node" {
// Skip start and end nodes
logger.Debugf(ctx, "skipping start/end node event [%+v]", rawEvent.Id)
skip = true
} else if rawEvent.IsInDynamicChain {
// Skip nodes that came from a dynamic task for now because it's a bit too tricky to get at the right task/wf definition
logger.Debugf(ctx, "skipping dynamic chain event [%+v]", rawEvent.Id)
skip = true
}
if skip {
return &event.CloudEventNodeExecution{
RawEvent: rawEvent,
}, nil
}
// metric

// This gets the parent workflow execution metadata
executionModel, err := c.db.ExecutionRepo().Get(ctx, repositoryInterfaces.Identifier{
Expand All @@ -281,15 +291,94 @@ func (c *CloudEventWrappedPublisher) TransformNodeExecutionEvent(ctx context.Con
fmt.Printf("there was an error with spec %v %v", err, executionModel.Spec)
}

// Fetch the latest task execution if any, and pull out the task interface, if applicable.
// These are optional fields... if the node execution doesn't have a task execution then these will be empty.
// Multiple cases here that this function should handle.

if rawEvent.GetTaskNodeMetadata() != nil || rawEvent.GetIsDynamic() {
// Existence of task node metadata implies this is a task node
// Fetch the latest task execution if any, and pull out the task interface, if applicable.
// These are optional fields... if the node execution doesn't have a task execution then these will be empty.
taskExecID, typedInterface, err := c.getTaskExecutionSupplemental(ctx, rawEvent.Id)
if err != nil {
logger.Errorf(ctx, "failed to get additional task information for node exec id [%+v] with err: %v", rawEvent.Id, err)
return &event.CloudEventNodeExecution{
RawEvent: rawEvent,
}, nil
}

return &event.CloudEventNodeExecution{
RawEvent: rawEvent,
TaskExecId: taskExecID,
OutputInterface: typedInterface,
ArtifactIds: spec.GetMetadata().GetArtifactIds(),
Principal: spec.GetMetadata().Principal,
LaunchPlanId: spec.LaunchPlan,
}, nil

} else if rawEvent.GetIsParent() && rawEvent.GetTargetEntity() != nil &&
rawEvent.GetTargetEntity().ResourceType == core.ResourceType_WORKFLOW {
// This is a sub workflow node

typedInterface, err := c.getSubWorkflowExecutionSupplemental(ctx, rawEvent.GetTargetEntity())
if err != nil {
logger.Errorf(ctx, "failed to get additional subwf information for node exec id [%+v] with err: %v", rawEvent.Id, err)
return &event.CloudEventNodeExecution{
RawEvent: rawEvent,
}, nil
}

return &event.CloudEventNodeExecution{
RawEvent: rawEvent,
OutputInterface: typedInterface,
ArtifactIds: spec.GetMetadata().GetArtifactIds(),
Principal: spec.GetMetadata().Principal,
LaunchPlanId: spec.LaunchPlan,
}, nil

} else {
// Unhandled case, just return it
logger.Debugf(ctx, "unhandled node execution event, sending as raw [%+v]", rawEvent.Id)
return &event.CloudEventNodeExecution{
RawEvent: rawEvent,
}, nil
}

}

func (c *CloudEventWrappedPublisher) getSubWorkflowExecutionSupplemental(ctx context.Context, subWorkflowID *core.Identifier) (*core.TypedInterface, error) {
workflowModel, err := c.db.WorkflowRepo().Get(ctx, repositoryInterfaces.Identifier{
Project: subWorkflowID.Project,
Domain: subWorkflowID.Domain,
Name: subWorkflowID.Name,
Version: subWorkflowID.Version,
Org: subWorkflowID.Org,
})
if err != nil {
logger.Infof(ctx, "couldn't find workflow [%+v] for cloud event processing", subWorkflowID)
return nil, err
}

var workflowInterface core.TypedInterface
if workflowModel.TypedInterface != nil && len(workflowModel.TypedInterface) > 0 {
err = proto.Unmarshal(workflowModel.TypedInterface, &workflowInterface)
if err != nil {
return nil, fmt.Errorf(
"artifact eventing - failed to unmarshal TypedInterface for workflow [%+v] with err: %v",
workflowModel.ID, err)
}
}
return &workflowInterface, nil
}

func (c *CloudEventWrappedPublisher) getTaskExecutionSupplemental(ctx context.Context, nodeExecutionID *core.NodeExecutionIdentifier) (
*core.TaskExecutionIdentifier, *core.TypedInterface, error) {

var taskExecID *core.TaskExecutionIdentifier
var typedInterface *core.TypedInterface

lte, err := c.getLatestTaskExecutions(ctx, *rawEvent.Id)
lte, err := c.getLatestTaskExecutions(ctx, nodeExecutionID)
if err != nil {
logger.Errorf(ctx, "failed to get latest task execution for node exec id [%+v] with err: %v", rawEvent.Id, err)
return nil, err
logger.Errorf(ctx, "failed to get latest task execution for node exec id [%+v] with err: %v", nodeExecutionID, err)
return nil, nil, err
}
if lte != nil {
taskModel, err := c.db.TaskRepo().Get(ctx, repositoryInterfaces.Identifier{
Expand All @@ -301,27 +390,19 @@ func (c *CloudEventWrappedPublisher) TransformNodeExecutionEvent(ctx context.Con
})
if err != nil {
// TODO: metric this
// metric
logger.Debugf(ctx, "Failed to get task with task id [%+v] with err %v", lte.Id.TaskId, err)
return nil, err
return nil, nil, err
}
task, err := transformers.FromTaskModel(taskModel)
if err != nil {
logger.Debugf(ctx, "Failed to transform task model with err %v", err)
return nil, err
return nil, nil, err
}
typedInterface = task.Closure.CompiledTask.Template.Interface
taskExecID = lte.Id
}

return &event.CloudEventNodeExecution{
RawEvent: rawEvent,
TaskExecId: taskExecID,
OutputInterface: typedInterface,
ArtifactIds: spec.GetMetadata().GetArtifactIds(),
Principal: spec.GetMetadata().Principal,
LaunchPlanId: spec.LaunchPlan,
}, nil
return taskExecID, typedInterface, nil
}

func (c *CloudEventWrappedPublisher) TransformTaskExecutionEvent(ctx context.Context, rawEvent *event.TaskExecutionEvent) (*event.CloudEventTaskExecution, error) {
Expand Down
165 changes: 165 additions & 0 deletions flyteadmin/pkg/async/notifications/implementations/disk_publisher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package implementations

import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"os"
"strings"
"sync"

pbcloudevents "github.com/cloudevents/sdk-go/binding/format/protobuf/v2"
cloudevents "github.com/cloudevents/sdk-go/v2"
"github.com/golang/protobuf/proto"

"github.com/flyteorg/flyte/flytestdlib/logger"
)

// FunctionCounter Create a global thread-safe counter that can be incremented with every publish
type FunctionCounter struct {
mu sync.Mutex
count int
}

// Method on the struct that increments the count and performs the function's actions
func (fc *FunctionCounter) increment() int {
fc.mu.Lock()
defer func() {
fc.count++
fc.mu.Unlock()
}()
return fc.count
}

// DebugOnlyDiskEventPublisher implements pubsub.Publisher by writing to a local tmp dir
// To use, add:
//
// case "disk":
// publisher := implementations.NewDebugOnlyDiskEventPublisher()
// return implementations.NewEventsPublisher(publisher, scope, []string{"all"})
//
// to factory.go and update your local configuration
type DebugOnlyDiskEventPublisher struct {
Root string
c FunctionCounter
}

func (d *DebugOnlyDiskEventPublisher) generateFileName(key string) string {
c := d.c.increment()
return fmt.Sprintf("%s/evt_%03d_%s.pb", d.Root, c, key)
}

// Publish Implement the pubsub.Publisher interface
func (d *DebugOnlyDiskEventPublisher) Publish(ctx context.Context, key string, msg proto.Message) error {
mb, err := proto.Marshal(msg)
if err != nil {
return err
}
fname := d.generateFileName(key)
logger.Warningf(ctx, "DSKPUB: Publish [%s - %s] [%+v]", fname, key, msg.String())

// #nosec
return os.WriteFile(fname, mb, 0666)
}

func (d *DebugOnlyDiskEventPublisher) PublishRaw(ctx context.Context, key string, msg []byte) error {
fname := d.generateFileName(key)
logger.Warningf(ctx, "DSKPUB: PublishRaw [%s - %s] [%+v]", fname, key, msg)

// #nosec
return os.WriteFile(fname, msg, 0666)
}

func generateRandomString(n int) (string, error) {
bytes := make([]byte, n)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}

func GetTempFolder() string {
tmpDir := os.Getenv("TMPDIR")
// make a tmp folder
if _, err := os.Stat(tmpDir); os.IsNotExist(err) {
err := os.Mkdir(tmpDir, 0755)
if err != nil {
panic(err)
}
}
r, err := generateRandomString(3)
if err != nil {
panic(err)
}

// Has a / already
folder := fmt.Sprintf("%s%s", tmpDir, r)
if _, err := os.Stat(folder); os.IsNotExist(err) {
err := os.Mkdir(folder, 0755)
if err != nil {
panic(err)
}
}
return folder
}

// NewDebugOnlyDiskEventPublisher Create a new DebugOnlyDiskEventPublisher
func NewDebugOnlyDiskEventPublisher() *DebugOnlyDiskEventPublisher {
// make a random temp sub folder
folder := GetTempFolder()
logger.Warningf(context.Background(), "DSKPUB: Using disk publisher with root [%s] -----", folder)

return &DebugOnlyDiskEventPublisher{
Root: folder,
c: FunctionCounter{},
}
}

// DebugOnlyDiskSenderForCloudEvents implements the Sender interface
// Combines the Sender and Publisher into one, not sure why there are two.
// To use, add:
//
// case "disksender":
// sender = implementations.NewDebugOnlyDiskSenderForCloudEvents()
//
// to cloudevents/factory.go and update your local configuration. This captures all the transformations that the cloud
// event publisher does to the various events.
type DebugOnlyDiskSenderForCloudEvents struct {
Root string
c FunctionCounter
}

func (s *DebugOnlyDiskSenderForCloudEvents) Send(ctx context.Context, notificationType string, event cloudevents.Event) error {
fname := s.generateFileName(notificationType, event.Source())
logger.Warningf(ctx, "DSKPUB: Send [%s - %s] [%+v]", fname, notificationType, event.Source())

eventBytes, err := pbcloudevents.Protobuf.Marshal(&event)
if err != nil {
logger.Errorf(ctx, "Failed to marshal cloudevent with error: %v", err)
panic(err)
}
// #nosec
return os.WriteFile(fname, eventBytes, 0666)

//return nil
}

func (s *DebugOnlyDiskSenderForCloudEvents) generateFileName(notificationType, key string) string {
// The source comes in the form of fce772610fcfc4442a34/n0-0-n0-0-start-node
// This strips out everything before the /
nodeID := strings.Split(key, "/")[1]
c := s.c.increment()
return fmt.Sprintf("%s/cloud_%03d_%s_%s.pb", s.Root, c, notificationType, nodeID)
}

func NewDebugOnlyDiskSenderForCloudEvents() *DebugOnlyDiskSenderForCloudEvents {
// make a random temp sub folder
folder := GetTempFolder()
logger.Warningf(context.Background(), "DSKPUB: Using disk sender with root [%s] -----", folder)

return &DebugOnlyDiskSenderForCloudEvents{
Root: folder,
c: FunctionCounter{},
}
}

0 comments on commit 2616d4c

Please sign in to comment.