diff --git a/pkg/repositories/transformers/node_execution.go b/pkg/repositories/transformers/node_execution.go index 2029f69f7..0fe9685c3 100644 --- a/pkg/repositories/transformers/node_execution.go +++ b/pkg/repositories/transformers/node_execution.go @@ -133,12 +133,30 @@ func CreateNodeExecutionModel(ctx context.Context, input ToNodeExecutionModelInp return nil, err } } + if common.IsNodeExecutionTerminal(input.Request.Event.Phase) { err := addTerminalState(ctx, input.Request, nodeExecution, &closure, input.InlineEventDataPolicy, input.StorageClient) if err != nil { return nil, err } } + + // Update TaskNodeMetadata, which includes caching information today. + if input.Request.Event.GetTaskNodeMetadata() != nil { + targetMetadata := &admin.NodeExecutionClosure_TaskNodeMetadata{ + TaskNodeMetadata: &admin.TaskNodeMetadata{ + CheckpointUri: input.Request.Event.GetTaskNodeMetadata().CheckpointUri, + }, + } + if input.Request.Event.GetTaskNodeMetadata().CatalogKey != nil { + st := input.Request.Event.GetTaskNodeMetadata().GetCacheStatus().String() + targetMetadata.TaskNodeMetadata.CacheStatus = input.Request.Event.GetTaskNodeMetadata().GetCacheStatus() + targetMetadata.TaskNodeMetadata.CatalogKey = input.Request.Event.GetTaskNodeMetadata().GetCatalogKey() + nodeExecution.CacheStatus = &st + } + closure.TargetMetadata = targetMetadata + } + marshaledClosure, err := proto.Marshal(&closure) if err != nil { return nil, errors.NewFlyteAdminErrorf(