Skip to content

Commit

Permalink
Add identifier context values (flyteorg#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Dec 20, 2019
1 parent b974afc commit 84f181a
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 16 deletions.
47 changes: 32 additions & 15 deletions pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"strconv"
"time"

"github.com/lyft/flytestdlib/contextutils"

"github.com/lyft/flyteadmin/pkg/auth"

"github.com/golang/protobuf/ptypes"
Expand Down Expand Up @@ -83,6 +85,11 @@ type ExecutionManager struct {
urlData dataInterfaces.RemoteURLInterface
}

func getExecutionContext(ctx context.Context, id *core.WorkflowExecutionIdentifier) context.Context {
ctx = contextutils.WithExecutionID(ctx, id.Name)
return contextutils.WithProjectDomain(ctx, id.Project, id.Domain)
}

// Returns the unique string which identifies the authenticated end user (if any).
func getUser(ctx context.Context) string {
principalContextUser := ctx.Value(auth.PrincipalContextKey)
Expand Down Expand Up @@ -182,21 +189,22 @@ func (m *ExecutionManager) offloadInputs(ctx context.Context, literalMap *core.L
}

func (m *ExecutionManager) launchExecutionAndPrepareModel(
ctx context.Context, request admin.ExecutionCreateRequest, requestedAt time.Time) (*models.Execution, error) {
ctx context.Context, request admin.ExecutionCreateRequest, requestedAt time.Time) (
context.Context, *models.Execution, error) {
err := validation.ValidateExecutionRequest(ctx, request, m.db, m.config.ApplicationConfiguration())
if err != nil {
logger.Debugf(ctx, "Failed to validate ExecutionCreateRequest %+v with err %v", request, err)
return nil, err
return nil, nil, err
}
launchPlanModel, err := util.GetLaunchPlanModel(ctx, m.db, *request.Spec.LaunchPlan)
if err != nil {
logger.Debugf(ctx, "Failed to get launch plan model for ExecutionCreateRequest %+v with err %v", request, err)
return nil, err
return nil, nil, err
}
launchPlan, err := transformers.FromLaunchPlanModel(launchPlanModel)
if err != nil {
logger.Debugf(ctx, "Failed to transform launch plan model %+v with err %v", launchPlanModel, err)
return nil, err
return nil, nil, err
}
executionInputs, err := validation.CheckAndFetchInputsForExecution(
request.Inputs,
Expand All @@ -208,19 +216,20 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel(
logger.Debugf(ctx, "Failed to CheckAndFetchInputsForExecution with request.Inputs: %+v"+
"fixed inputs: %+v and expected inputs: %+v with err %v",
request.Inputs, launchPlan.Spec.FixedInputs, launchPlan.Closure.ExpectedInputs, err)
return nil, err
return nil, nil, err
}
workflow, err := util.GetWorkflow(ctx, m.db, m.storageClient, *launchPlan.Spec.WorkflowId)
if err != nil {
logger.Debugf(ctx, "Failed to get workflow with id %+v with err %v", launchPlan.Spec.WorkflowId, err)
return nil, err
return nil, nil, err
}
name := util.GetExecutionName(request)
workflowExecutionID := core.WorkflowExecutionIdentifier{
Project: request.Project,
Domain: request.Domain,
Name: name,
}
ctx = getExecutionContext(ctx, &workflowExecutionID)

// Get the node execution (if any) that launched this execution
var parentNodeExecutionID uint
Expand All @@ -229,7 +238,7 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel(
if err != nil {
logger.Errorf(ctx, "Failed to get node execution [%+v] that launched this execution [%+v] with error %v",
request.Spec.Metadata.ParentNodeExecution, workflowExecutionID, err)
return nil, err
return nil, nil, err
}

parentNodeExecutionID = parentNodeExecutionModel.ID
Expand All @@ -245,11 +254,11 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel(

inputsURI, err := m.offloadInputs(ctx, executionInputs, &workflowExecutionID, shared.Inputs)
if err != nil {
return nil, err
return nil, nil, err
}
userInputsURI, err := m.offloadInputs(ctx, request.Inputs, &workflowExecutionID, shared.UserInputs)
if err != nil {
return nil, err
return nil, nil, err
}

// TODO: Reduce CRD size and use offloaded input URI to blob store instead.
Expand All @@ -262,15 +271,15 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel(
}
err = m.addLabelsAndAnnotations(request.Spec, &executeWorkflowInputs)
if err != nil {
return nil, err
return nil, nil, err
}

execInfo, err := m.workflowExecutor.ExecuteWorkflow(ctx, executeWorkflowInputs)
if err != nil {
m.systemMetrics.PropellerFailures.Inc()
logger.Infof(ctx, "Failed to execute workflow %+v with execution id %+v and inputs %+v with err %v",
request, workflowExecutionID, executionInputs, err)
return nil, err
return nil, nil, err
}
executionCreatedAt := time.Now()
acceptanceDelay := executionCreatedAt.Sub(requestedAt)
Expand Down Expand Up @@ -308,9 +317,9 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel(
if err != nil {
logger.Infof(ctx, "Failed to create execution model in transformer for id: [%+v] with err: %v",
workflowExecutionID, err)
return nil, err
return nil, nil, err
}
return executionModel, nil
return ctx, executionModel, nil
}

// Inserts an execution model into the database store and emits platform metrics.
Expand Down Expand Up @@ -341,7 +350,9 @@ func (m *ExecutionManager) CreateExecution(
if request.Inputs == nil || len(request.Inputs.Literals) == 0 {
request.Inputs = request.GetSpec().GetInputs()
}
executionModel, err := m.launchExecutionAndPrepareModel(ctx, request, requestedAt)
var executionModel *models.Execution
var err error
ctx, executionModel, err = m.launchExecutionAndPrepareModel(ctx, request, requestedAt)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -387,7 +398,8 @@ func (m *ExecutionManager) RelaunchExecution(
inputs = spec.Inputs
}
executionSpec.Metadata.Mode = admin.ExecutionMetadata_RELAUNCH
executionModel, err := m.launchExecutionAndPrepareModel(ctx, admin.ExecutionCreateRequest{
var executionModel *models.Execution
ctx, executionModel, err = m.launchExecutionAndPrepareModel(ctx, admin.ExecutionCreateRequest{
Project: request.Id.Project,
Domain: request.Id.Domain,
Name: request.Name,
Expand Down Expand Up @@ -543,6 +555,7 @@ func (m *ExecutionManager) CreateWorkflowEvent(ctx context.Context, request admi
logger.Debugf(ctx, "received invalid CreateWorkflowEventRequest [%s]: %v", request.RequestId, err)
return nil, err
}
ctx = getExecutionContext(ctx, request.Event.ExecutionId)
logger.Debugf(ctx, "Received workflow execution event for [%+v] transitioning to phase [%v]",
request.Event.ExecutionId, request.Event.Phase)

Expand Down Expand Up @@ -617,6 +630,7 @@ func (m *ExecutionManager) GetExecution(
logger.Debugf(ctx, "GetExecution request [%+v] failed validation with err: %v", request, err)
return nil, err
}
ctx = getExecutionContext(ctx, request.Id)
executionModel, err := util.GetExecutionModel(ctx, m.db, *request.Id)
if err != nil {
logger.Debugf(ctx, "Failed to get execution model for request [%+v] with err: %v", request, err)
Expand Down Expand Up @@ -667,6 +681,7 @@ func (m *ExecutionManager) GetExecution(

func (m *ExecutionManager) GetExecutionData(
ctx context.Context, request admin.WorkflowExecutionGetDataRequest) (*admin.WorkflowExecutionGetDataResponse, error) {
ctx = getExecutionContext(ctx, request.Id)
executionModel, err := util.GetExecutionModel(ctx, m.db, *request.Id)
if err != nil {
logger.Debugf(ctx, "Failed to get execution model for request [%+v] with err: %v", request, err)
Expand Down Expand Up @@ -718,6 +733,7 @@ func (m *ExecutionManager) ListExecutions(
logger.Debugf(ctx, "ListExecutions request [%+v] failed validation with err: %v", request, err)
return nil, err
}
ctx = contextutils.WithProjectDomain(ctx, request.Id.Project, request.Id.Domain)
filters, err := util.GetDbFilters(util.FilterSpec{
Project: request.Id.Project,
Domain: request.Id.Domain,
Expand Down Expand Up @@ -842,6 +858,7 @@ func (m *ExecutionManager) TerminateExecution(
logger.Debugf(ctx, "received terminate execution request: %v with invalid identifier: %v", request, err)
return nil, err
}
ctx = getExecutionContext(ctx, request.Id)
// Save the abort reason (best effort)
executionModel, err := m.db.ExecutionRepo().Get(ctx, repositoryInterfaces.GetResourceInput{
Project: request.Id.Project,
Expand Down
20 changes: 19 additions & 1 deletion pkg/manager/impl/launch_plan_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"context"
"strconv"

"github.com/lyft/flytestdlib/contextutils"

"github.com/lyft/flyteadmin/pkg/async/schedule/aws"

"github.com/lyft/flytestdlib/promutils"
Expand Down Expand Up @@ -44,6 +46,16 @@ type LaunchPlanManager struct {
metrics launchPlanMetrics
}

func getLaunchPlanContext(ctx context.Context, identifier *core.Identifier) context.Context {
ctx = contextutils.WithProjectDomain(ctx, identifier.Project, identifier.Domain)
return contextutils.WithLaunchPlanID(ctx, identifier.Name)
}

func (m *LaunchPlanManager) getNamedEntityContext(ctx context.Context, identifier *admin.NamedEntityIdentifier) context.Context {
ctx = contextutils.WithProjectDomain(ctx, identifier.Project, identifier.Domain)
return contextutils.WithLaunchPlanID(ctx, identifier.Name)
}

func (m *LaunchPlanManager) CreateLaunchPlan(
ctx context.Context,
request admin.LaunchPlanCreateRequest) (*admin.LaunchPlanCreateResponse, error) {
Expand Down Expand Up @@ -71,6 +83,7 @@ func (m *LaunchPlanManager) CreateLaunchPlan(
logger.Debugf(ctx, "could not create launch plan: %+v, request failed validation with err: %v", request.Id, err)
return nil, err
}
ctx = getLaunchPlanContext(ctx, request.Id)
launchPlan := transformers.CreateLaunchPlan(request, workflowInterface.Outputs)
launchPlanDigest, err := util.GetLaunchPlanDigest(ctx, &launchPlan)
if err != nil {
Expand Down Expand Up @@ -328,6 +341,7 @@ func (m *LaunchPlanManager) UpdateLaunchPlan(ctx context.Context, request admin.
if err := validation.ValidateIdentifier(request.Id, common.LaunchPlan); err != nil {
logger.Debugf(ctx, "can't update launch plan [%+v] state, invalid identifier: %v", request.Id, err)
}
ctx = getLaunchPlanContext(ctx, request.Id)
switch request.State {
case admin.LaunchPlanState_INACTIVE:
return m.disableLaunchPlan(ctx, request)
Expand All @@ -346,6 +360,7 @@ func (m *LaunchPlanManager) GetLaunchPlan(ctx context.Context, request admin.Obj
logger.Debugf(ctx, "can't get launch plan [%+v] with invalid identifier: %v", request.Id, err)
return nil, err
}
ctx = getLaunchPlanContext(ctx, request.Id)
return util.GetLaunchPlan(ctx, m.db, *request.Id)
}

Expand All @@ -355,6 +370,7 @@ func (m *LaunchPlanManager) GetActiveLaunchPlan(ctx context.Context, request adm
logger.Debugf(ctx, "can't get active launch plan [%+v] with invalid request: %v", request.Id, err)
return nil, err
}
ctx = m.getNamedEntityContext(ctx, request.Id)

filters, err := util.GetActiveLaunchPlanVersionFilters(request.Id.Project, request.Id.Domain, request.Id.Name)
if err != nil {
Expand Down Expand Up @@ -388,6 +404,7 @@ func (m *LaunchPlanManager) ListLaunchPlans(ctx context.Context, request admin.R
logger.Debugf(ctx, "")
return nil, err
}
ctx = m.getNamedEntityContext(ctx, request.Id)

filters, err := util.GetDbFilters(util.FilterSpec{
Project: request.Id.Project,
Expand Down Expand Up @@ -447,6 +464,7 @@ func (m *LaunchPlanManager) ListActiveLaunchPlans(ctx context.Context, request a
logger.Debugf(ctx, "")
return nil, err
}
ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain)

filters, err := util.ListActiveLaunchPlanVersionsFilters(request.Project, request.Domain)
if err != nil {
Expand Down Expand Up @@ -496,7 +514,7 @@ func (m *LaunchPlanManager) ListActiveLaunchPlans(ctx context.Context, request a
// At least project name and domain must be specified along with limit.
func (m *LaunchPlanManager) ListLaunchPlanIds(ctx context.Context, request admin.NamedEntityIdentifierListRequest) (
*admin.NamedEntityIdentifierList, error) {

ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain)
filters, err := util.GetDbFilters(util.FilterSpec{
Project: request.Project,
Domain: request.Domain,
Expand Down
5 changes: 5 additions & 0 deletions pkg/manager/impl/named_entity_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"strconv"

"github.com/lyft/flytestdlib/contextutils"

"github.com/lyft/flyteadmin/pkg/common"
"github.com/lyft/flyteadmin/pkg/errors"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -36,6 +38,7 @@ func (m *NamedEntityManager) UpdateNamedEntity(ctx context.Context, request admi
logger.Debugf(ctx, "invalid request [%+v]: %v", request, err)
return nil, err
}
ctx = contextutils.WithProjectDomain(ctx, request.Id.Project, request.Id.Domain)

// Ensure entity exists before trying to update it
_, err := util.GetNamedEntity(ctx, m.db, request.ResourceType, *request.Id)
Expand All @@ -58,6 +61,7 @@ func (m *NamedEntityManager) GetNamedEntity(ctx context.Context, request admin.N
logger.Debugf(ctx, "invalid request [%+v]: %v", request, err)
return nil, err
}
ctx = contextutils.WithProjectDomain(ctx, request.Id.Project, request.Id.Domain)
return util.GetNamedEntity(ctx, m.db, request.ResourceType, *request.Id)
}

Expand All @@ -67,6 +71,7 @@ func (m *NamedEntityManager) ListNamedEntities(ctx context.Context, request admi
logger.Debugf(ctx, "invalid request [%+v]: %v", request, err)
return nil, err
}
ctx = contextutils.WithProjectDomain(ctx, request.Project, request.Domain)

filters, err := util.GetDbFilters(util.FilterSpec{
Project: request.Project,
Expand Down
16 changes: 16 additions & 0 deletions pkg/manager/impl/node_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"strconv"

"github.com/lyft/flytestdlib/contextutils"

"github.com/lyft/flyteadmin/pkg/manager/impl/shared"
"github.com/lyft/flytestdlib/promutils"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -58,6 +60,12 @@ var isParent = common.NewMapFilter(map[string]interface{}{
shared.ParentTaskExecutionID: nil,
})

func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecutionIdentifier) context.Context {
ctx = contextutils.WithProjectDomain(ctx, identifier.ExecutionId.Project, identifier.ExecutionId.Domain)
ctx = contextutils.WithExecutionID(ctx, identifier.ExecutionId.Name)
return contextutils.WithNodeID(ctx, identifier.NodeId)
}

func (m *NodeExecutionManager) createNodeExecutionWithEvent(
ctx context.Context, request *admin.NodeExecutionEventRequest) error {

Expand Down Expand Up @@ -149,6 +157,10 @@ func (m *NodeExecutionManager) updateNodeExecutionWithEvent(

func (m *NodeExecutionManager) CreateNodeEvent(ctx context.Context, request admin.NodeExecutionEventRequest) (
*admin.NodeExecutionEventResponse, error) {
if err := validation.ValidateNodeExecutionIdentifier(request.Event.Id); err != nil {
logger.Debugf(ctx, "CreateNodeEvent called with invalid identifier [%+v]: %v", request.Event.Id, err)
}
ctx = getNodeExecutionContext(ctx, request.Event.Id)
executionID := request.Event.Id.ExecutionId
logger.Debugf(ctx, "Received node execution event for Node Exec Id [%+v] transitioning to phase [%v], w/ Metadata [%v]",
request.Event.Id, request.Event.Phase, request.Event.ParentTaskMetadata)
Expand Down Expand Up @@ -208,6 +220,7 @@ func (m *NodeExecutionManager) GetNodeExecution(
if err := validation.ValidateNodeExecutionIdentifier(request.Id); err != nil {
logger.Debugf(ctx, "get node execution called with invalid identifier [%+v]: %v", request.Id, err)
}
ctx = getNodeExecutionContext(ctx, request.Id)
nodeExecutionModel, err := util.GetNodeExecutionModel(ctx, m.db, request.Id)
if err != nil {
logger.Debugf(ctx, "Failed to get node execution with id [%+v] with err %v",
Expand Down Expand Up @@ -282,6 +295,7 @@ func (m *NodeExecutionManager) ListNodeExecutions(
if err := validation.ValidateNodeExecutionListRequest(request); err != nil {
return nil, err
}
ctx = getExecutionContext(ctx, request.WorkflowExecutionId)

identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, *request.WorkflowExecutionId)
if err != nil {
Expand All @@ -299,6 +313,7 @@ func (m *NodeExecutionManager) ListNodeExecutionsForTask(
if err := validation.ValidateNodeExecutionForTaskListRequest(request); err != nil {
return nil, err
}
ctx = getTaskExecutionContext(ctx, request.TaskExecutionId)
identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(
ctx, *request.TaskExecutionId.NodeExecutionId.ExecutionId)
if err != nil {
Expand All @@ -323,6 +338,7 @@ func (m *NodeExecutionManager) GetNodeExecutionData(
if err := validation.ValidateNodeExecutionIdentifier(request.Id); err != nil {
logger.Debugf(ctx, "can't get node execution data with invalid identifier [%+v]: %v", request.Id, err)
}
ctx = getNodeExecutionContext(ctx, request.Id)
nodeExecutionModel, err := util.GetNodeExecutionModel(ctx, m.db, request.Id)
if err != nil {
logger.Debugf(ctx, "Failed to get node execution with id [%+v] with err %v",
Expand Down
3 changes: 3 additions & 0 deletions pkg/manager/impl/project_domain_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package impl
import (
"context"

"github.com/lyft/flytestdlib/contextutils"

"github.com/lyft/flyteadmin/pkg/manager/impl/validation"
"github.com/lyft/flyteadmin/pkg/repositories/transformers"

Expand All @@ -23,6 +25,7 @@ func (m *ProjectDomainManager) UpdateProjectDomain(
if err := validation.ValidateProjectDomainAttributesUpdateRequest(request); err != nil {
return nil, err
}
ctx = contextutils.WithProjectDomain(ctx, request.Attributes.Project, request.Attributes.Domain)

model, err := transformers.ToProjectDomainModel(*request.Attributes)
if err != nil {
Expand Down
Loading

0 comments on commit 84f181a

Please sign in to comment.