diff --git a/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go b/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go
index 01a1e67b176..52b6bff4f93 100644
--- a/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go
+++ b/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go
@@ -206,10 +206,10 @@ func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecution
 
 // This is a rough copy of the ListTaskExecutions function in TaskExecutionManager. It can be deprecated once we move the processing out of Admin itself.
 // Just return the highest retry attempt.
-func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID core.NodeExecutionIdentifier) (*admin.TaskExecution, error) {
-	ctx = getNodeExecutionContext(ctx, &nodeExecutionID)
+func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID *core.NodeExecutionIdentifier) (*admin.TaskExecution, error) {
+	ctx = getNodeExecutionContext(ctx, nodeExecutionID)
 
-	identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, nodeExecutionID)
+	identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, *nodeExecutionID)
 	if err != nil {
 		return nil, err
 	}
@@ -251,17 +251,27 @@ func (c *CloudEventWrappedPublisher) TransformNodeExecutionEvent(ctx context.Con
 		return nil, fmt.Errorf("nothing to publish, NodeExecution event or ID is nil")
 	}
 
-	// Skip nodes unless they're succeeded and not start nodes
+	// Don't bother trying to add additional information to events unless it's useful to do so
+	var skip = false
 	if rawEvent.Phase != core.NodeExecution_SUCCEEDED {
-		return &event.CloudEventNodeExecution{
-			RawEvent: rawEvent,
-		}, nil
-	} else if rawEvent.Id.NodeId == "start-node" {
+		// Skip non-succeeded nodes
+		logger.Debugf(ctx, "skipping non-succeeded event [%+v]", rawEvent.Id)
+		skip = true
+	} else if rawEvent.Id.NodeId == "start-node" || rawEvent.SpecNodeId == "start-node" ||
+		rawEvent.Id.NodeId == "end-node" || rawEvent.SpecNodeId == "end-node" {
+		// Skip start and end nodes
+		logger.Debugf(ctx, "skipping start/end node event [%+v]", rawEvent.Id)
+		skip = true
+	} else if rawEvent.IsInDynamicChain {
+		// Skip nodes that came from a dynamic task for now because it's a bit too tricky to get at the right task/wf definition
+		logger.Debugf(ctx, "skipping dynamic chain event [%+v]", rawEvent.Id)
+		skip = true
+	}
+	if skip {
 		return &event.CloudEventNodeExecution{
 			RawEvent: rawEvent,
 		}, nil
 	}
-	// metric
 
 	// This gets the parent workflow execution metadata
 	executionModel, err := c.db.ExecutionRepo().Get(ctx, repositoryInterfaces.Identifier{
@@ -281,15 +291,94 @@ func (c *CloudEventWrappedPublisher) TransformNodeExecutionEvent(ctx context.Con
 		fmt.Printf("there was an error with spec %v %v", err, executionModel.Spec)
 	}
 
-	// Fetch the latest task execution if any, and pull out the task interface, if applicable.
-	// These are optional fields... if the node execution doesn't have a task execution then these will be empty.
+	// Multiple cases here that this function should handle.
+
+	if rawEvent.GetTaskNodeMetadata() != nil || rawEvent.GetIsDynamic() {
+		// Existence of task node metadata implies this is a task node
+		// Fetch the latest task execution if any, and pull out the task interface, if applicable.
+		// These are optional fields... if the node execution doesn't have a task execution then these will be empty.
+		taskExecID, typedInterface, err := c.getTaskExecutionSupplemental(ctx, rawEvent.Id)
+		if err != nil {
+			logger.Errorf(ctx, "failed to get additional task information for node exec id [%+v] with err: %v", rawEvent.Id, err)
+			return &event.CloudEventNodeExecution{
+				RawEvent: rawEvent,
+			}, nil
+		}
+
+		return &event.CloudEventNodeExecution{
+			RawEvent:        rawEvent,
+			TaskExecId:      taskExecID,
+			OutputInterface: typedInterface,
+			ArtifactIds:     spec.GetMetadata().GetArtifactIds(),
+			Principal:       spec.GetMetadata().Principal,
+			LaunchPlanId:    spec.LaunchPlan,
+		}, nil
+
+	} else if rawEvent.GetIsParent() && rawEvent.GetTargetEntity() != nil &&
+		rawEvent.GetTargetEntity().ResourceType == core.ResourceType_WORKFLOW {
+		// This is a sub workflow node
+
+		typedInterface, err := c.getSubWorkflowExecutionSupplemental(ctx, rawEvent.GetTargetEntity())
+		if err != nil {
+			logger.Errorf(ctx, "failed to get additional subwf information for node exec id [%+v] with err: %v", rawEvent.Id, err)
+			return &event.CloudEventNodeExecution{
+				RawEvent: rawEvent,
+			}, nil
+		}
+
+		return &event.CloudEventNodeExecution{
+			RawEvent:        rawEvent,
+			OutputInterface: typedInterface,
+			ArtifactIds:     spec.GetMetadata().GetArtifactIds(),
+			Principal:       spec.GetMetadata().Principal,
+			LaunchPlanId:    spec.LaunchPlan,
+		}, nil
+
+	} else {
+		// Unhandled case, just return it
+		logger.Debugf(ctx, "unhandled node execution event, sending as raw [%+v]", rawEvent.Id)
+		return &event.CloudEventNodeExecution{
+			RawEvent: rawEvent,
+		}, nil
+	}
+
+}
+
+func (c *CloudEventWrappedPublisher) getSubWorkflowExecutionSupplemental(ctx context.Context, subWorkflowID *core.Identifier) (*core.TypedInterface, error) {
+	workflowModel, err := c.db.WorkflowRepo().Get(ctx, repositoryInterfaces.Identifier{
+		Project: subWorkflowID.Project,
+		Domain:  subWorkflowID.Domain,
+		Name:    subWorkflowID.Name,
+		Version: subWorkflowID.Version,
+		Org:     subWorkflowID.Org,
+	})
+	if err != nil {
+		logger.Infof(ctx, "couldn't find workflow [%+v] for cloud event processing", subWorkflowID)
+		return nil, err
+	}
+
+	var workflowInterface core.TypedInterface
+	if workflowModel.TypedInterface != nil && len(workflowModel.TypedInterface) > 0 {
+		err = proto.Unmarshal(workflowModel.TypedInterface, &workflowInterface)
+		if err != nil {
+			return nil, fmt.Errorf(
+				"artifact eventing - failed to unmarshal TypedInterface for workflow [%+v] with err: %v",
+				workflowModel.ID, err)
+		}
+	}
+	return &workflowInterface, nil
+}
+
+func (c *CloudEventWrappedPublisher) getTaskExecutionSupplemental(ctx context.Context, nodeExecutionID *core.NodeExecutionIdentifier) (
+	*core.TaskExecutionIdentifier, *core.TypedInterface, error) {
+
 	var taskExecID *core.TaskExecutionIdentifier
 	var typedInterface *core.TypedInterface
 
-	lte, err := c.getLatestTaskExecutions(ctx, *rawEvent.Id)
+	lte, err := c.getLatestTaskExecutions(ctx, nodeExecutionID)
 	if err != nil {
-		logger.Errorf(ctx, "failed to get latest task execution for node exec id [%+v] with err: %v", rawEvent.Id, err)
-		return nil, err
+		logger.Errorf(ctx, "failed to get latest task execution for node exec id [%+v] with err: %v", nodeExecutionID, err)
+		return nil, nil, err
 	}
 	if lte != nil {
 		taskModel, err := c.db.TaskRepo().Get(ctx, repositoryInterfaces.Identifier{
@@ -301,27 +390,19 @@ func (c *CloudEventWrappedPublisher) TransformNodeExecutionEvent(ctx context.Con
 		})
 		if err != nil {
 			// TODO: metric this
-			// metric
 			logger.Debugf(ctx, "Failed to get task with task id [%+v] with err %v", lte.Id.TaskId, err)
-			return nil, err
+			return nil, nil, err
 		}
 		task, err := transformers.FromTaskModel(taskModel)
 		if err != nil {
 			logger.Debugf(ctx, "Failed to transform task model with err %v", err)
-			return nil, err
+			return nil, nil, err
 		}
 		typedInterface = task.Closure.CompiledTask.Template.Interface
 		taskExecID = lte.Id
 	}
 
-	return &event.CloudEventNodeExecution{
-		RawEvent:        rawEvent,
-		TaskExecId:      taskExecID,
-		OutputInterface: typedInterface,
-		ArtifactIds:     spec.GetMetadata().GetArtifactIds(),
-		Principal:       spec.GetMetadata().Principal,
-		LaunchPlanId:    spec.LaunchPlan,
-	}, nil
+	return taskExecID, typedInterface, nil
 }
 
 func (c *CloudEventWrappedPublisher) TransformTaskExecutionEvent(ctx context.Context, rawEvent *event.TaskExecutionEvent) (*event.CloudEventTaskExecution, error) {
diff --git a/flyteadmin/pkg/async/notifications/implementations/disk_publisher.go b/flyteadmin/pkg/async/notifications/implementations/disk_publisher.go
new file mode 100644
index 00000000000..1445410805b
--- /dev/null
+++ b/flyteadmin/pkg/async/notifications/implementations/disk_publisher.go
@@ -0,0 +1,165 @@
+package implementations
+
+import (
+	"context"
+	"crypto/rand"
+	"encoding/hex"
+	"fmt"
+	"os"
+	"strings"
+	"sync"
+
+	pbcloudevents "github.com/cloudevents/sdk-go/binding/format/protobuf/v2"
+	cloudevents "github.com/cloudevents/sdk-go/v2"
+	"github.com/golang/protobuf/proto"
+
+	"github.com/flyteorg/flyte/flytestdlib/logger"
+)
+
+// FunctionCounter Create a global thread-safe counter that can be incremented with every publish
+type FunctionCounter struct {
+	mu    sync.Mutex
+	count int
+}
+
+// Method on the struct that increments the count and performs the function's actions
+func (fc *FunctionCounter) increment() int {
+	fc.mu.Lock()
+	defer func() {
+		fc.count++
+		fc.mu.Unlock()
+	}()
+	return fc.count
+}
+
+// DebugOnlyDiskEventPublisher implements pubsub.Publisher by writing to a local tmp dir
+// To use, add:
+//
+//	case "disk":
+//	  publisher := implementations.NewDebugOnlyDiskEventPublisher()
+//	  return implementations.NewEventsPublisher(publisher, scope, []string{"all"})
+//
+// to factory.go and update your local configuration
+type DebugOnlyDiskEventPublisher struct {
+	Root string
+	c    FunctionCounter
+}
+
+func (d *DebugOnlyDiskEventPublisher) generateFileName(key string) string {
+	c := d.c.increment()
+	return fmt.Sprintf("%s/evt_%03d_%s.pb", d.Root, c, key)
+}
+
+// Publish Implement the pubsub.Publisher interface
+func (d *DebugOnlyDiskEventPublisher) Publish(ctx context.Context, key string, msg proto.Message) error {
+	mb, err := proto.Marshal(msg)
+	if err != nil {
+		return err
+	}
+	fname := d.generateFileName(key)
+	logger.Warningf(ctx, "DSKPUB: Publish [%s - %s] [%+v]", fname, key, msg.String())
+
+	// #nosec
+	return os.WriteFile(fname, mb, 0666)
+}
+
+func (d *DebugOnlyDiskEventPublisher) PublishRaw(ctx context.Context, key string, msg []byte) error {
+	fname := d.generateFileName(key)
+	logger.Warningf(ctx, "DSKPUB: PublishRaw [%s - %s] [%+v]", fname, key, msg)
+
+	// #nosec
+	return os.WriteFile(fname, msg, 0666)
+}
+
+func generateRandomString(n int) (string, error) {
+	bytes := make([]byte, n)
+	if _, err := rand.Read(bytes); err != nil {
+		return "", err
+	}
+	return hex.EncodeToString(bytes), nil
+}
+
+func GetTempFolder() string {
+	tmpDir := os.Getenv("TMPDIR")
+	// make a tmp folder
+	if _, err := os.Stat(tmpDir); os.IsNotExist(err) {
+		err := os.Mkdir(tmpDir, 0755)
+		if err != nil {
+			panic(err)
+		}
+	}
+	r, err := generateRandomString(3)
+	if err != nil {
+		panic(err)
+	}
+
+	// Has a / already
+	folder := fmt.Sprintf("%s%s", tmpDir, r)
+	if _, err := os.Stat(folder); os.IsNotExist(err) {
+		err := os.Mkdir(folder, 0755)
+		if err != nil {
+			panic(err)
+		}
+	}
+	return folder
+}
+
+// NewDebugOnlyDiskEventPublisher Create a new DebugOnlyDiskEventPublisher
+func NewDebugOnlyDiskEventPublisher() *DebugOnlyDiskEventPublisher {
+	// make a random temp sub folder
+	folder := GetTempFolder()
+	logger.Warningf(context.Background(), "DSKPUB: Using disk publisher with root [%s] -----", folder)
+
+	return &DebugOnlyDiskEventPublisher{
+		Root: folder,
+		c:    FunctionCounter{},
+	}
+}
+
+// DebugOnlyDiskSenderForCloudEvents implements the Sender interface
+// Combines the Sender and Publisher into one, not sure why there are two.
+// To use, add:
+//
+//	case "disksender":
+//	  sender = implementations.NewDebugOnlyDiskSenderForCloudEvents()
+//
+// to cloudevents/factory.go and update your local configuration. This captures all the transformations that the cloud
+// event publisher does to the various events.
+type DebugOnlyDiskSenderForCloudEvents struct {
+	Root string
+	c    FunctionCounter
+}
+
+func (s *DebugOnlyDiskSenderForCloudEvents) Send(ctx context.Context, notificationType string, event cloudevents.Event) error {
+	fname := s.generateFileName(notificationType, event.Source())
+	logger.Warningf(ctx, "DSKPUB: Send [%s - %s] [%+v]", fname, notificationType, event.Source())
+
+	eventBytes, err := pbcloudevents.Protobuf.Marshal(&event)
+	if err != nil {
+		logger.Errorf(ctx, "Failed to marshal cloudevent with error: %v", err)
+		panic(err)
+	}
+	// #nosec
+	return os.WriteFile(fname, eventBytes, 0666)
+
+	//return nil
+}
+
+func (s *DebugOnlyDiskSenderForCloudEvents) generateFileName(notificationType, key string) string {
+	// The source comes in the form of fce772610fcfc4442a34/n0-0-n0-0-start-node
+	// This strips out everything before the /
+	nodeID := strings.Split(key, "/")[1]
+	c := s.c.increment()
+	return fmt.Sprintf("%s/cloud_%03d_%s_%s.pb", s.Root, c, notificationType, nodeID)
+}
+
+func NewDebugOnlyDiskSenderForCloudEvents() *DebugOnlyDiskSenderForCloudEvents {
+	// make a random temp sub folder
+	folder := GetTempFolder()
+	logger.Warningf(context.Background(), "DSKPUB: Using disk sender with root [%s] -----", folder)
+
+	return &DebugOnlyDiskSenderForCloudEvents{
+		Root: folder,
+		c:    FunctionCounter{},
+	}
+}
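Note: a minimal sketch of how the new disk publisher could be exercised directly for local debugging, outside the factory.go wiring described in the file comments above. The wrapperspb.String message is only a stand-in payload (any proto.Message works), the implementations import path assumes the flyteadmin module layout shown in this diff, and GetTempFolder expects TMPDIR to point at a writable directory.

package main

import (
	"context"

	"google.golang.org/protobuf/types/known/wrapperspb"

	"github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/implementations"
)

func main() {
	ctx := context.Background()

	// Creates a random sub folder under $TMPDIR and logs its path.
	pub := implementations.NewDebugOnlyDiskEventPublisher()

	// Each call writes a numbered protobuf file such as evt_000_debug-key.pb into that folder.
	// "debug-key" and the payload are hypothetical stand-ins for a real notification message.
	if err := pub.Publish(ctx, "debug-key", wrapperspb.String("hello")); err != nil {
		panic(err)
	}
}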