diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 16a8781579..2f9703db3f 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -6,7 +6,7 @@ require ( github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.10.0 - github.com/flyteorg/flyteidl v0.19.2 + github.com/flyteorg/flyteidl v0.19.5 github.com/flyteorg/flyteplugins v0.5.56 github.com/flyteorg/flytestdlib v0.3.17 github.com/ghodss/yaml v1.0.0 diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index 6fe0e0f199..4111c869d8 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -229,8 +229,9 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/flyteorg/flyteidl v0.19.2 h1:jXuRrLJEzSo33N9pw7bMEd6mRYSL7LCz/vnazz5XcOg= github.com/flyteorg/flyteidl v0.19.2/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= +github.com/flyteorg/flyteidl v0.19.5 h1:qNhNK6mhCTuOms7zJmBtog6bLQJhBj+iScf1IlHdqeg= +github.com/flyteorg/flyteidl v0.19.5/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U= github.com/flyteorg/flyteplugins v0.5.56 h1:LF/dwMFJDSMEmOp8hd9rU4Et4oyn0K+LgMzcHOu/xrw= github.com/flyteorg/flyteplugins v0.5.56/go.mod h1:Jp5WheQMI08luZmgcmcgyjtzakKH0tPws/t35DzpKUA= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node.go b/flytepropeller/pkg/compiler/transformers/k8s/node.go index 6449be5528..1409545394 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/node.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/node.go @@ -27,6 +27,7 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile } var task *core.TaskTemplate + var resources *core.Resources if n.GetTaskNode() != nil { taskID := n.GetTaskNode().GetReferenceId().String() // TODO: Use task index for quick lookup @@ -41,9 +42,15 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile errs.Collect(errors.NewTaskReferenceNotFoundErr(n.GetId(), taskID)) return nil, !errs.HasErrors() } + + if n.GetTaskNode().Overrides != nil && n.GetTaskNode().Overrides.Resources != nil { + resources = n.GetTaskNode().Overrides.Resources + } else { + resources = getResources(task) + } } - res, err := utils.ToK8sResourceRequirements(getResources(task)) + res, err := utils.ToK8sResourceRequirements(resources) if err != nil { errs.Collect(errors.NewWorkflowBuildError(err)) return nil, false diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go index 0816d7a020..480948726d 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go @@ -104,6 +104,32 @@ func TestBuildNodeSpec(t *testing.T) { assert.Equal(t, expectedCPU.Value(), spec.Resources.Requests.Cpu().Value()) }) + t.Run("node with resource overrides", func(t *testing.T) { + expectedCPU := resource.MustParse("20Mi") + n.Node.Target = &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "ref_2"}, + }, + Overrides: &core.TaskNodeOverrides{ + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + { + Name: core.Resources_CPU, + Value: "20Mi", + }, + }, + }, + }, + }, + } + + spec := mustBuild(t, n, 1, errs.NewScope()) + assert.NotNil(t, spec.Resources) + assert.NotNil(t, spec.Resources.Requests.Cpu()) + assert.Equal(t, expectedCPU.Value(), spec.Resources.Requests.Cpu().Value()) + }) + t.Run("LaunchPlanRef", func(t *testing.T) { n.Node.Target = &core.Node_WorkflowNode{ WorkflowNode: &core.WorkflowNode{