diff --git a/flyteidl/clients/go/assets/admin.swagger.json b/flyteidl/clients/go/assets/admin.swagger.json index 30e653daa3..1adf1fa70c 100644 --- a/flyteidl/clients/go/assets/admin.swagger.json +++ b/flyteidl/clients/go/assets/admin.swagger.json @@ -7221,6 +7221,10 @@ "interface": { "$ref": "#/definitions/coreTypedInterface", "title": "The input and output interface for the launch plan" + }, + "fixed_inputs": { + "$ref": "#/definitions/coreLiteralMap", + "title": "A collection of input literals that are fixed for the launch plan" } }, "description": "A structure that uniquely identifies a launch plan in the system." @@ -8076,6 +8080,10 @@ "extended_resources": { "$ref": "#/definitions/coreExtendedResources", "description": "Overrides for all non-standard resources, not captured by\nv1.ResourceRequirements, to allocate to a task." + }, + "container_image": { + "type": "string", + "description": "Override for the image used by task pods." } }, "description": "Optional task node overrides that will be applied at task execution time." diff --git a/flyteidl/gen/pb-es/flyteidl/core/workflow_pb.ts b/flyteidl/gen/pb-es/flyteidl/core/workflow_pb.ts index 644f792e63..414e6eb319 100644 --- a/flyteidl/gen/pb-es/flyteidl/core/workflow_pb.ts +++ b/flyteidl/gen/pb-es/flyteidl/core/workflow_pb.ts @@ -1121,6 +1121,13 @@ export class TaskNodeOverrides extends Message { */ extendedResources?: ExtendedResources; + /** + * Override for the image used by task pods. + * + * @generated from field: string container_image = 3; + */ + containerImage = ""; + constructor(data?: PartialMessage) { super(); proto3.util.initPartial(data, this); @@ -1131,6 +1138,7 @@ export class TaskNodeOverrides extends Message { static readonly fields: FieldList = proto3.util.newFieldList(() => [ { no: 1, name: "resources", kind: "message", T: Resources }, { no: 2, name: "extended_resources", kind: "message", T: ExtendedResources }, + { no: 3, name: "container_image", kind: "scalar", T: 9 /* ScalarType.STRING */ }, ]); static fromBinary(bytes: Uint8Array, options?: Partial): TaskNodeOverrides { diff --git a/flyteidl/gen/pb-go/flyteidl/core/workflow.pb.go b/flyteidl/gen/pb-go/flyteidl/core/workflow.pb.go index f9318f539f..8b6cc6ab1d 100644 --- a/flyteidl/gen/pb-go/flyteidl/core/workflow.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/core/workflow.pb.go @@ -1498,6 +1498,8 @@ type TaskNodeOverrides struct { // Overrides for all non-standard resources, not captured by // v1.ResourceRequirements, to allocate to a task. ExtendedResources *ExtendedResources `protobuf:"bytes,2,opt,name=extended_resources,json=extendedResources,proto3" json:"extended_resources,omitempty"` + // Override for the image used by task pods. + ContainerImage string `protobuf:"bytes,3,opt,name=container_image,json=containerImage,proto3" json:"container_image,omitempty"` } func (x *TaskNodeOverrides) Reset() { @@ -1546,6 +1548,13 @@ func (x *TaskNodeOverrides) GetExtendedResources() *ExtendedResources { return nil } +func (x *TaskNodeOverrides) GetContainerImage() string { + if x != nil { + return x.ContainerImage + } + return "" +} + // A structure that uniquely identifies a launch plan in the system. type LaunchPlanTemplate struct { state protoimpl.MessageState @@ -1839,7 +1848,7 @@ var file_flyteidl_core_workflow_proto_rawDesc = []byte{ 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x44, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x73, 0x52, 0x10, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x44, 0x65, 0x66, - 0x61, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x9c, 0x01, 0x0a, 0x11, 0x54, 0x61, 0x73, 0x6b, 0x4e, 0x6f, + 0x61, 0x75, 0x6c, 0x74, 0x73, 0x22, 0xc5, 0x01, 0x0a, 0x11, 0x54, 0x61, 0x73, 0x6b, 0x4e, 0x6f, 0x64, 0x65, 0x4f, 0x76, 0x65, 0x72, 0x72, 0x69, 0x64, 0x65, 0x73, 0x12, 0x36, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x52, @@ -1849,30 +1858,33 @@ var file_flyteidl_core_workflow_proto_rawDesc = []byte{ 0x20, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65, 0x64, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x52, 0x11, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65, 0x64, 0x52, 0x65, 0x73, 0x6f, 0x75, - 0x72, 0x63, 0x65, 0x73, 0x22, 0xba, 0x01, 0x0a, 0x12, 0x4c, 0x61, 0x75, 0x6e, 0x63, 0x68, 0x50, - 0x6c, 0x61, 0x6e, 0x54, 0x65, 0x6d, 0x70, 0x6c, 0x61, 0x74, 0x65, 0x12, 0x29, 0x0a, 0x02, 0x69, - 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, - 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, - 0x65, 0x72, 0x52, 0x02, 0x69, 0x64, 0x12, 0x3b, 0x0a, 0x09, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, - 0x61, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x66, 0x6c, 0x79, 0x74, - 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x64, 0x49, - 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x52, 0x09, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, - 0x61, 0x63, 0x65, 0x12, 0x3c, 0x0a, 0x0c, 0x66, 0x69, 0x78, 0x65, 0x64, 0x5f, 0x69, 0x6e, 0x70, - 0x75, 0x74, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x66, 0x6c, 0x79, 0x74, - 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x4c, 0x69, 0x74, 0x65, 0x72, 0x61, - 0x6c, 0x4d, 0x61, 0x70, 0x52, 0x0b, 0x66, 0x69, 0x78, 0x65, 0x64, 0x49, 0x6e, 0x70, 0x75, 0x74, - 0x73, 0x42, 0xb3, 0x01, 0x0a, 0x11, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, - 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x42, 0x0d, 0x57, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, - 0x77, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, - 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, - 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, - 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, - 0x63, 0x6f, 0x72, 0x65, 0xa2, 0x02, 0x03, 0x46, 0x43, 0x58, 0xaa, 0x02, 0x0d, 0x46, 0x6c, 0x79, - 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x43, 0x6f, 0x72, 0x65, 0xca, 0x02, 0x0d, 0x46, 0x6c, 0x79, - 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x43, 0x6f, 0x72, 0x65, 0xe2, 0x02, 0x19, 0x46, 0x6c, 0x79, - 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x43, 0x6f, 0x72, 0x65, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, - 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x0e, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, - 0x6c, 0x3a, 0x3a, 0x43, 0x6f, 0x72, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x72, 0x63, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, + 0x72, 0x5f, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x63, + 0x6f, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x22, 0xba, 0x01, + 0x0a, 0x12, 0x4c, 0x61, 0x75, 0x6e, 0x63, 0x68, 0x50, 0x6c, 0x61, 0x6e, 0x54, 0x65, 0x6d, 0x70, + 0x6c, 0x61, 0x74, 0x65, 0x12, 0x29, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x19, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, + 0x2e, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x52, 0x02, 0x69, 0x64, 0x12, + 0x3b, 0x0a, 0x09, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, + 0x72, 0x65, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x64, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, + 0x65, 0x52, 0x09, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x3c, 0x0a, 0x0c, + 0x66, 0x69, 0x78, 0x65, 0x64, 0x5f, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, + 0x72, 0x65, 0x2e, 0x4c, 0x69, 0x74, 0x65, 0x72, 0x61, 0x6c, 0x4d, 0x61, 0x70, 0x52, 0x0b, 0x66, + 0x69, 0x78, 0x65, 0x64, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x73, 0x42, 0xb3, 0x01, 0x0a, 0x11, 0x63, + 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, + 0x42, 0x0d, 0x57, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, + 0x01, 0x5a, 0x3a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, + 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, + 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, + 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0xa2, 0x02, 0x03, + 0x46, 0x43, 0x58, 0xaa, 0x02, 0x0d, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x43, + 0x6f, 0x72, 0x65, 0xca, 0x02, 0x0d, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x43, + 0x6f, 0x72, 0x65, 0xe2, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x43, + 0x6f, 0x72, 0x65, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, + 0x02, 0x0e, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x43, 0x6f, 0x72, 0x65, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json index 73dbd3b8dc..1adf1fa70c 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json @@ -8080,6 +8080,10 @@ "extended_resources": { "$ref": "#/definitions/coreExtendedResources", "description": "Overrides for all non-standard resources, not captured by\nv1.ResourceRequirements, to allocate to a task." + }, + "container_image": { + "type": "string", + "description": "Override for the image used by task pods." } }, "description": "Optional task node overrides that will be applied at task execution time." diff --git a/flyteidl/gen/pb-js/flyteidl.d.ts b/flyteidl/gen/pb-js/flyteidl.d.ts index 87de89064b..d8a2711961 100644 --- a/flyteidl/gen/pb-js/flyteidl.d.ts +++ b/flyteidl/gen/pb-js/flyteidl.d.ts @@ -4935,6 +4935,9 @@ export namespace flyteidl { /** TaskNodeOverrides extendedResources */ extendedResources?: (flyteidl.core.IExtendedResources|null); + + /** TaskNodeOverrides containerImage */ + containerImage?: (string|null); } /** Represents a TaskNodeOverrides. */ @@ -4952,6 +4955,9 @@ export namespace flyteidl { /** TaskNodeOverrides extendedResources. */ public extendedResources?: (flyteidl.core.IExtendedResources|null); + /** TaskNodeOverrides containerImage. */ + public containerImage: string; + /** * Creates a new TaskNodeOverrides instance using the specified properties. * @param [properties] Properties to set diff --git a/flyteidl/gen/pb-js/flyteidl.js b/flyteidl/gen/pb-js/flyteidl.js index c7fac48ad5..791f0c9eb6 100644 --- a/flyteidl/gen/pb-js/flyteidl.js +++ b/flyteidl/gen/pb-js/flyteidl.js @@ -11969,6 +11969,7 @@ * @interface ITaskNodeOverrides * @property {flyteidl.core.IResources|null} [resources] TaskNodeOverrides resources * @property {flyteidl.core.IExtendedResources|null} [extendedResources] TaskNodeOverrides extendedResources + * @property {string|null} [containerImage] TaskNodeOverrides containerImage */ /** @@ -12002,6 +12003,14 @@ */ TaskNodeOverrides.prototype.extendedResources = null; + /** + * TaskNodeOverrides containerImage. + * @member {string} containerImage + * @memberof flyteidl.core.TaskNodeOverrides + * @instance + */ + TaskNodeOverrides.prototype.containerImage = ""; + /** * Creates a new TaskNodeOverrides instance using the specified properties. * @function create @@ -12030,6 +12039,8 @@ $root.flyteidl.core.Resources.encode(message.resources, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); if (message.extendedResources != null && message.hasOwnProperty("extendedResources")) $root.flyteidl.core.ExtendedResources.encode(message.extendedResources, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.containerImage != null && message.hasOwnProperty("containerImage")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.containerImage); return writer; }; @@ -12057,6 +12068,9 @@ case 2: message.extendedResources = $root.flyteidl.core.ExtendedResources.decode(reader, reader.uint32()); break; + case 3: + message.containerImage = reader.string(); + break; default: reader.skipType(tag & 7); break; @@ -12086,6 +12100,9 @@ if (error) return "extendedResources." + error; } + if (message.containerImage != null && message.hasOwnProperty("containerImage")) + if (!$util.isString(message.containerImage)) + return "containerImage: string expected"; return null; }; diff --git a/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.py b/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.py index ab629c2b9f..4c07024320 100644 --- a/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.py @@ -22,7 +22,7 @@ from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lyteidl/core/workflow.proto\x12\rflyteidl.core\x1a\x1d\x66lyteidl/core/condition.proto\x1a\x1d\x66lyteidl/core/execution.proto\x1a\x1e\x66lyteidl/core/identifier.proto\x1a\x1d\x66lyteidl/core/interface.proto\x1a\x1c\x66lyteidl/core/literals.proto\x1a\x19\x66lyteidl/core/tasks.proto\x1a\x19\x66lyteidl/core/types.proto\x1a\x1c\x66lyteidl/core/security.proto\x1a\x1egoogle/protobuf/duration.proto\"{\n\x07IfBlock\x12>\n\tcondition\x18\x01 \x01(\x0b\x32 .flyteidl.core.BooleanExpressionR\tcondition\x12\x30\n\tthen_node\x18\x02 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x08thenNode\"\xd4\x01\n\x0bIfElseBlock\x12*\n\x04\x63\x61se\x18\x01 \x01(\x0b\x32\x16.flyteidl.core.IfBlockR\x04\x63\x61se\x12,\n\x05other\x18\x02 \x03(\x0b\x32\x16.flyteidl.core.IfBlockR\x05other\x12\x32\n\telse_node\x18\x03 \x01(\x0b\x32\x13.flyteidl.core.NodeH\x00R\x08\x65lseNode\x12,\n\x05\x65rror\x18\x04 \x01(\x0b\x32\x14.flyteidl.core.ErrorH\x00R\x05\x65rrorB\t\n\x07\x64\x65\x66\x61ult\"A\n\nBranchNode\x12\x33\n\x07if_else\x18\x01 \x01(\x0b\x32\x1a.flyteidl.core.IfElseBlockR\x06ifElse\"\x97\x01\n\x08TaskNode\x12>\n\x0creference_id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\x0breferenceId\x12>\n\toverrides\x18\x02 \x01(\x0b\x32 .flyteidl.core.TaskNodeOverridesR\toverridesB\x0b\n\treference\"\xa6\x01\n\x0cWorkflowNode\x12\x42\n\x0elaunchplan_ref\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\rlaunchplanRef\x12\x45\n\x10sub_workflow_ref\x18\x02 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\x0esubWorkflowRefB\x0b\n\treference\"/\n\x10\x41pproveCondition\x12\x1b\n\tsignal_id\x18\x01 \x01(\tR\x08signalId\"\x90\x01\n\x0fSignalCondition\x12\x1b\n\tsignal_id\x18\x01 \x01(\tR\x08signalId\x12.\n\x04type\x18\x02 \x01(\x0b\x32\x1a.flyteidl.core.LiteralTypeR\x04type\x12\x30\n\x14output_variable_name\x18\x03 \x01(\tR\x12outputVariableName\"G\n\x0eSleepCondition\x12\x35\n\x08\x64uration\x18\x01 \x01(\x0b\x32\x19.google.protobuf.DurationR\x08\x64uration\"\xc5\x01\n\x08GateNode\x12;\n\x07\x61pprove\x18\x01 \x01(\x0b\x32\x1f.flyteidl.core.ApproveConditionH\x00R\x07\x61pprove\x12\x38\n\x06signal\x18\x02 \x01(\x0b\x32\x1e.flyteidl.core.SignalConditionH\x00R\x06signal\x12\x35\n\x05sleep\x18\x03 \x01(\x0b\x32\x1d.flyteidl.core.SleepConditionH\x00R\x05sleepB\x0b\n\tcondition\"\xbf\x01\n\tArrayNode\x12\'\n\x04node\x18\x01 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x04node\x12 \n\x0bparallelism\x18\x02 \x01(\rR\x0bparallelism\x12%\n\rmin_successes\x18\x03 \x01(\rH\x00R\x0cminSuccesses\x12,\n\x11min_success_ratio\x18\x04 \x01(\x02H\x00R\x0fminSuccessRatioB\x12\n\x10success_criteria\"\x8c\x03\n\x0cNodeMetadata\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x33\n\x07timeout\x18\x04 \x01(\x0b\x32\x19.google.protobuf.DurationR\x07timeout\x12\x36\n\x07retries\x18\x05 \x01(\x0b\x32\x1c.flyteidl.core.RetryStrategyR\x07retries\x12&\n\rinterruptible\x18\x06 \x01(\x08H\x00R\rinterruptible\x12\x1e\n\tcacheable\x18\x07 \x01(\x08H\x01R\tcacheable\x12%\n\rcache_version\x18\x08 \x01(\tH\x02R\x0c\x63\x61\x63heVersion\x12/\n\x12\x63\x61\x63he_serializable\x18\t \x01(\x08H\x03R\x11\x63\x61\x63heSerializableB\x15\n\x13interruptible_valueB\x11\n\x0f\x63\x61\x63heable_valueB\x15\n\x13\x63\x61\x63he_version_valueB\x1a\n\x18\x63\x61\x63he_serializable_value\"/\n\x05\x41lias\x12\x10\n\x03var\x18\x01 \x01(\tR\x03var\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\"\x9f\x04\n\x04Node\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x37\n\x08metadata\x18\x02 \x01(\x0b\x32\x1b.flyteidl.core.NodeMetadataR\x08metadata\x12.\n\x06inputs\x18\x03 \x03(\x0b\x32\x16.flyteidl.core.BindingR\x06inputs\x12*\n\x11upstream_node_ids\x18\x04 \x03(\tR\x0fupstreamNodeIds\x12;\n\x0eoutput_aliases\x18\x05 \x03(\x0b\x32\x14.flyteidl.core.AliasR\routputAliases\x12\x36\n\ttask_node\x18\x06 \x01(\x0b\x32\x17.flyteidl.core.TaskNodeH\x00R\x08taskNode\x12\x42\n\rworkflow_node\x18\x07 \x01(\x0b\x32\x1b.flyteidl.core.WorkflowNodeH\x00R\x0cworkflowNode\x12<\n\x0b\x62ranch_node\x18\x08 \x01(\x0b\x32\x19.flyteidl.core.BranchNodeH\x00R\nbranchNode\x12\x36\n\tgate_node\x18\t \x01(\x0b\x32\x17.flyteidl.core.GateNodeH\x00R\x08gateNode\x12\x39\n\narray_node\x18\n \x01(\x0b\x32\x18.flyteidl.core.ArrayNodeH\x00R\tarrayNodeB\x08\n\x06target\"\xfc\x02\n\x10WorkflowMetadata\x12M\n\x12quality_of_service\x18\x01 \x01(\x0b\x32\x1f.flyteidl.core.QualityOfServiceR\x10qualityOfService\x12N\n\non_failure\x18\x02 \x01(\x0e\x32/.flyteidl.core.WorkflowMetadata.OnFailurePolicyR\tonFailure\x12=\n\x04tags\x18\x03 \x03(\x0b\x32).flyteidl.core.WorkflowMetadata.TagsEntryR\x04tags\x1a\x37\n\tTagsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"Q\n\x0fOnFailurePolicy\x12\x14\n\x10\x46\x41IL_IMMEDIATELY\x10\x00\x12(\n$FAIL_AFTER_EXECUTABLE_NODES_COMPLETE\x10\x01\"@\n\x18WorkflowMetadataDefaults\x12$\n\rinterruptible\x18\x01 \x01(\x08R\rinterruptible\"\xa2\x03\n\x10WorkflowTemplate\x12)\n\x02id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierR\x02id\x12;\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.flyteidl.core.WorkflowMetadataR\x08metadata\x12;\n\tinterface\x18\x03 \x01(\x0b\x32\x1d.flyteidl.core.TypedInterfaceR\tinterface\x12)\n\x05nodes\x18\x04 \x03(\x0b\x32\x13.flyteidl.core.NodeR\x05nodes\x12\x30\n\x07outputs\x18\x05 \x03(\x0b\x32\x16.flyteidl.core.BindingR\x07outputs\x12\x36\n\x0c\x66\x61ilure_node\x18\x06 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x0b\x66\x61ilureNode\x12T\n\x11metadata_defaults\x18\x07 \x01(\x0b\x32\'.flyteidl.core.WorkflowMetadataDefaultsR\x10metadataDefaults\"\x9c\x01\n\x11TaskNodeOverrides\x12\x36\n\tresources\x18\x01 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x12\x65xtended_resources\x18\x02 \x01(\x0b\x32 .flyteidl.core.ExtendedResourcesR\x11\x65xtendedResources\"\xba\x01\n\x12LaunchPlanTemplate\x12)\n\x02id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierR\x02id\x12;\n\tinterface\x18\x02 \x01(\x0b\x32\x1d.flyteidl.core.TypedInterfaceR\tinterface\x12<\n\x0c\x66ixed_inputs\x18\x03 \x01(\x0b\x32\x19.flyteidl.core.LiteralMapR\x0b\x66ixedInputsB\xb3\x01\n\x11\x63om.flyteidl.coreB\rWorkflowProtoP\x01Z:github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core\xa2\x02\x03\x46\x43X\xaa\x02\rFlyteidl.Core\xca\x02\rFlyteidl\\Core\xe2\x02\x19\x46lyteidl\\Core\\GPBMetadata\xea\x02\x0e\x46lyteidl::Coreb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lyteidl/core/workflow.proto\x12\rflyteidl.core\x1a\x1d\x66lyteidl/core/condition.proto\x1a\x1d\x66lyteidl/core/execution.proto\x1a\x1e\x66lyteidl/core/identifier.proto\x1a\x1d\x66lyteidl/core/interface.proto\x1a\x1c\x66lyteidl/core/literals.proto\x1a\x19\x66lyteidl/core/tasks.proto\x1a\x19\x66lyteidl/core/types.proto\x1a\x1c\x66lyteidl/core/security.proto\x1a\x1egoogle/protobuf/duration.proto\"{\n\x07IfBlock\x12>\n\tcondition\x18\x01 \x01(\x0b\x32 .flyteidl.core.BooleanExpressionR\tcondition\x12\x30\n\tthen_node\x18\x02 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x08thenNode\"\xd4\x01\n\x0bIfElseBlock\x12*\n\x04\x63\x61se\x18\x01 \x01(\x0b\x32\x16.flyteidl.core.IfBlockR\x04\x63\x61se\x12,\n\x05other\x18\x02 \x03(\x0b\x32\x16.flyteidl.core.IfBlockR\x05other\x12\x32\n\telse_node\x18\x03 \x01(\x0b\x32\x13.flyteidl.core.NodeH\x00R\x08\x65lseNode\x12,\n\x05\x65rror\x18\x04 \x01(\x0b\x32\x14.flyteidl.core.ErrorH\x00R\x05\x65rrorB\t\n\x07\x64\x65\x66\x61ult\"A\n\nBranchNode\x12\x33\n\x07if_else\x18\x01 \x01(\x0b\x32\x1a.flyteidl.core.IfElseBlockR\x06ifElse\"\x97\x01\n\x08TaskNode\x12>\n\x0creference_id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\x0breferenceId\x12>\n\toverrides\x18\x02 \x01(\x0b\x32 .flyteidl.core.TaskNodeOverridesR\toverridesB\x0b\n\treference\"\xa6\x01\n\x0cWorkflowNode\x12\x42\n\x0elaunchplan_ref\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\rlaunchplanRef\x12\x45\n\x10sub_workflow_ref\x18\x02 \x01(\x0b\x32\x19.flyteidl.core.IdentifierH\x00R\x0esubWorkflowRefB\x0b\n\treference\"/\n\x10\x41pproveCondition\x12\x1b\n\tsignal_id\x18\x01 \x01(\tR\x08signalId\"\x90\x01\n\x0fSignalCondition\x12\x1b\n\tsignal_id\x18\x01 \x01(\tR\x08signalId\x12.\n\x04type\x18\x02 \x01(\x0b\x32\x1a.flyteidl.core.LiteralTypeR\x04type\x12\x30\n\x14output_variable_name\x18\x03 \x01(\tR\x12outputVariableName\"G\n\x0eSleepCondition\x12\x35\n\x08\x64uration\x18\x01 \x01(\x0b\x32\x19.google.protobuf.DurationR\x08\x64uration\"\xc5\x01\n\x08GateNode\x12;\n\x07\x61pprove\x18\x01 \x01(\x0b\x32\x1f.flyteidl.core.ApproveConditionH\x00R\x07\x61pprove\x12\x38\n\x06signal\x18\x02 \x01(\x0b\x32\x1e.flyteidl.core.SignalConditionH\x00R\x06signal\x12\x35\n\x05sleep\x18\x03 \x01(\x0b\x32\x1d.flyteidl.core.SleepConditionH\x00R\x05sleepB\x0b\n\tcondition\"\xbf\x01\n\tArrayNode\x12\'\n\x04node\x18\x01 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x04node\x12 \n\x0bparallelism\x18\x02 \x01(\rR\x0bparallelism\x12%\n\rmin_successes\x18\x03 \x01(\rH\x00R\x0cminSuccesses\x12,\n\x11min_success_ratio\x18\x04 \x01(\x02H\x00R\x0fminSuccessRatioB\x12\n\x10success_criteria\"\x8c\x03\n\x0cNodeMetadata\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x33\n\x07timeout\x18\x04 \x01(\x0b\x32\x19.google.protobuf.DurationR\x07timeout\x12\x36\n\x07retries\x18\x05 \x01(\x0b\x32\x1c.flyteidl.core.RetryStrategyR\x07retries\x12&\n\rinterruptible\x18\x06 \x01(\x08H\x00R\rinterruptible\x12\x1e\n\tcacheable\x18\x07 \x01(\x08H\x01R\tcacheable\x12%\n\rcache_version\x18\x08 \x01(\tH\x02R\x0c\x63\x61\x63heVersion\x12/\n\x12\x63\x61\x63he_serializable\x18\t \x01(\x08H\x03R\x11\x63\x61\x63heSerializableB\x15\n\x13interruptible_valueB\x11\n\x0f\x63\x61\x63heable_valueB\x15\n\x13\x63\x61\x63he_version_valueB\x1a\n\x18\x63\x61\x63he_serializable_value\"/\n\x05\x41lias\x12\x10\n\x03var\x18\x01 \x01(\tR\x03var\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\"\x9f\x04\n\x04Node\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x37\n\x08metadata\x18\x02 \x01(\x0b\x32\x1b.flyteidl.core.NodeMetadataR\x08metadata\x12.\n\x06inputs\x18\x03 \x03(\x0b\x32\x16.flyteidl.core.BindingR\x06inputs\x12*\n\x11upstream_node_ids\x18\x04 \x03(\tR\x0fupstreamNodeIds\x12;\n\x0eoutput_aliases\x18\x05 \x03(\x0b\x32\x14.flyteidl.core.AliasR\routputAliases\x12\x36\n\ttask_node\x18\x06 \x01(\x0b\x32\x17.flyteidl.core.TaskNodeH\x00R\x08taskNode\x12\x42\n\rworkflow_node\x18\x07 \x01(\x0b\x32\x1b.flyteidl.core.WorkflowNodeH\x00R\x0cworkflowNode\x12<\n\x0b\x62ranch_node\x18\x08 \x01(\x0b\x32\x19.flyteidl.core.BranchNodeH\x00R\nbranchNode\x12\x36\n\tgate_node\x18\t \x01(\x0b\x32\x17.flyteidl.core.GateNodeH\x00R\x08gateNode\x12\x39\n\narray_node\x18\n \x01(\x0b\x32\x18.flyteidl.core.ArrayNodeH\x00R\tarrayNodeB\x08\n\x06target\"\xfc\x02\n\x10WorkflowMetadata\x12M\n\x12quality_of_service\x18\x01 \x01(\x0b\x32\x1f.flyteidl.core.QualityOfServiceR\x10qualityOfService\x12N\n\non_failure\x18\x02 \x01(\x0e\x32/.flyteidl.core.WorkflowMetadata.OnFailurePolicyR\tonFailure\x12=\n\x04tags\x18\x03 \x03(\x0b\x32).flyteidl.core.WorkflowMetadata.TagsEntryR\x04tags\x1a\x37\n\tTagsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"Q\n\x0fOnFailurePolicy\x12\x14\n\x10\x46\x41IL_IMMEDIATELY\x10\x00\x12(\n$FAIL_AFTER_EXECUTABLE_NODES_COMPLETE\x10\x01\"@\n\x18WorkflowMetadataDefaults\x12$\n\rinterruptible\x18\x01 \x01(\x08R\rinterruptible\"\xa2\x03\n\x10WorkflowTemplate\x12)\n\x02id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierR\x02id\x12;\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.flyteidl.core.WorkflowMetadataR\x08metadata\x12;\n\tinterface\x18\x03 \x01(\x0b\x32\x1d.flyteidl.core.TypedInterfaceR\tinterface\x12)\n\x05nodes\x18\x04 \x03(\x0b\x32\x13.flyteidl.core.NodeR\x05nodes\x12\x30\n\x07outputs\x18\x05 \x03(\x0b\x32\x16.flyteidl.core.BindingR\x07outputs\x12\x36\n\x0c\x66\x61ilure_node\x18\x06 \x01(\x0b\x32\x13.flyteidl.core.NodeR\x0b\x66\x61ilureNode\x12T\n\x11metadata_defaults\x18\x07 \x01(\x0b\x32\'.flyteidl.core.WorkflowMetadataDefaultsR\x10metadataDefaults\"\xc5\x01\n\x11TaskNodeOverrides\x12\x36\n\tresources\x18\x01 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x12\x65xtended_resources\x18\x02 \x01(\x0b\x32 .flyteidl.core.ExtendedResourcesR\x11\x65xtendedResources\x12\'\n\x0f\x63ontainer_image\x18\x03 \x01(\tR\x0e\x63ontainerImage\"\xba\x01\n\x12LaunchPlanTemplate\x12)\n\x02id\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.IdentifierR\x02id\x12;\n\tinterface\x18\x02 \x01(\x0b\x32\x1d.flyteidl.core.TypedInterfaceR\tinterface\x12<\n\x0c\x66ixed_inputs\x18\x03 \x01(\x0b\x32\x19.flyteidl.core.LiteralMapR\x0b\x66ixedInputsB\xb3\x01\n\x11\x63om.flyteidl.coreB\rWorkflowProtoP\x01Z:github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core\xa2\x02\x03\x46\x43X\xaa\x02\rFlyteidl.Core\xca\x02\rFlyteidl\\Core\xe2\x02\x19\x46lyteidl\\Core\\GPBMetadata\xea\x02\x0e\x46lyteidl::Coreb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -70,7 +70,7 @@ _globals['_WORKFLOWTEMPLATE']._serialized_start=3155 _globals['_WORKFLOWTEMPLATE']._serialized_end=3573 _globals['_TASKNODEOVERRIDES']._serialized_start=3576 - _globals['_TASKNODEOVERRIDES']._serialized_end=3732 - _globals['_LAUNCHPLANTEMPLATE']._serialized_start=3735 - _globals['_LAUNCHPLANTEMPLATE']._serialized_end=3921 + _globals['_TASKNODEOVERRIDES']._serialized_end=3773 + _globals['_LAUNCHPLANTEMPLATE']._serialized_start=3776 + _globals['_LAUNCHPLANTEMPLATE']._serialized_end=3962 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.pyi index efee1e9ec2..11efca2392 100644 --- a/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/core/workflow_pb2.pyi @@ -199,12 +199,14 @@ class WorkflowTemplate(_message.Message): def __init__(self, id: _Optional[_Union[_identifier_pb2.Identifier, _Mapping]] = ..., metadata: _Optional[_Union[WorkflowMetadata, _Mapping]] = ..., interface: _Optional[_Union[_interface_pb2.TypedInterface, _Mapping]] = ..., nodes: _Optional[_Iterable[_Union[Node, _Mapping]]] = ..., outputs: _Optional[_Iterable[_Union[_literals_pb2.Binding, _Mapping]]] = ..., failure_node: _Optional[_Union[Node, _Mapping]] = ..., metadata_defaults: _Optional[_Union[WorkflowMetadataDefaults, _Mapping]] = ...) -> None: ... class TaskNodeOverrides(_message.Message): - __slots__ = ["resources", "extended_resources"] + __slots__ = ["resources", "extended_resources", "container_image"] RESOURCES_FIELD_NUMBER: _ClassVar[int] EXTENDED_RESOURCES_FIELD_NUMBER: _ClassVar[int] + CONTAINER_IMAGE_FIELD_NUMBER: _ClassVar[int] resources: _tasks_pb2.Resources extended_resources: _tasks_pb2.ExtendedResources - def __init__(self, resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ..., extended_resources: _Optional[_Union[_tasks_pb2.ExtendedResources, _Mapping]] = ...) -> None: ... + container_image: str + def __init__(self, resources: _Optional[_Union[_tasks_pb2.Resources, _Mapping]] = ..., extended_resources: _Optional[_Union[_tasks_pb2.ExtendedResources, _Mapping]] = ..., container_image: _Optional[str] = ...) -> None: ... class LaunchPlanTemplate(_message.Message): __slots__ = ["id", "interface", "fixed_inputs"] diff --git a/flyteidl/gen/pb_rust/flyteidl.core.rs b/flyteidl/gen/pb_rust/flyteidl.core.rs index 683581190e..6ef4dcd499 100644 --- a/flyteidl/gen/pb_rust/flyteidl.core.rs +++ b/flyteidl/gen/pb_rust/flyteidl.core.rs @@ -2634,6 +2634,9 @@ pub struct TaskNodeOverrides { /// v1.ResourceRequirements, to allocate to a task. #[prost(message, optional, tag="2")] pub extended_resources: ::core::option::Option, + /// Override for the image used by task pods. + #[prost(string, tag="3")] + pub container_image: ::prost::alloc::string::String, } /// A structure that uniquely identifies a launch plan in the system. #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/flyteidl/protos/flyteidl/core/workflow.proto b/flyteidl/protos/flyteidl/core/workflow.proto index 8433fc9e31..dcbe9367f9 100644 --- a/flyteidl/protos/flyteidl/core/workflow.proto +++ b/flyteidl/protos/flyteidl/core/workflow.proto @@ -297,6 +297,9 @@ message TaskNodeOverrides { // Overrides for all non-standard resources, not captured by // v1.ResourceRequirements, to allocate to a task. ExtendedResources extended_resources = 2; + + // Override for the image used by task pods. + string container_image = 3; } // A structure that uniquely identifies a launch plan in the system. diff --git a/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go b/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go index 9ac650baaa..83d7dbcf12 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go @@ -12,6 +12,7 @@ import ( type TaskOverrides interface { GetResources() *v1.ResourceRequirements GetExtendedResources() *core.ExtendedResources + GetContainerImage() string GetConfig() *v1.ConfigMap } diff --git a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_overrides.go b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_overrides.go index d00973c114..cc3c528007 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_overrides.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_overrides.go @@ -82,6 +82,41 @@ func (_m *TaskOverrides) GetExtendedResources() *flyteidlcore.ExtendedResources return r0 } +type TaskOverrides_GetContainerImage struct { + *mock.Call +} + +func (_m TaskOverrides_GetContainerImage) Return(_a0 string) *TaskOverrides_GetContainerImage { + return &TaskOverrides_GetContainerImage{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskOverrides) OnGetContainerImage() *TaskOverrides_GetContainerImage { + c_call := _m.On("GetContainerImage") + return &TaskOverrides_GetContainerImage{Call: c_call} +} + +func (_m *TaskOverrides) OnGetContainerImageMatch(matchers ...interface{}) *TaskOverrides_GetContainerImage { + c_call := _m.On("GetContainerImage", matchers...) + return &TaskOverrides_GetContainerImage{Call: c_call} +} + +// GetContainerImage provides a mock function with given fields: +func (_m *TaskOverrides) GetContainerImage() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(string) + } + } + + return r0 +} + + type TaskOverrides_GetResources struct { *mock.Call } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go index d09bcdb363..ff9ae2ba9f 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/plugin_exec_context.go @@ -27,6 +27,10 @@ func (to *pluginTaskOverrides) GetExtendedResources() *core.ExtendedResources { return to.TaskOverrides.GetExtendedResources() } +func (to *pluginTaskOverrides) GetContainerImage() string { + return to.TaskOverrides.GetContainerImage() +} + type pluginTaskExecutionMetadata struct { pluginsCore.TaskExecutionMetadata interruptible *bool diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 69111dc030..7f010f6b7c 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -317,7 +317,7 @@ func BuildRawPod(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v } // ApplyFlytePodConfiguration updates the PodSpec and ObjectMeta with various Flyte configuration. This includes -// applying default k8s configuration, resource requests, injecting copilot containers, and merging with the +// applying default k8s configuration, applying overrides (resources etc.), injecting copilot containers, and merging with the // configuration PodTemplate (if exists). func ApplyFlytePodConfiguration(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, primaryContainerName string) (*v1.PodSpec, *metav1.ObjectMeta, error) { taskTemplate, err := tCtx.TaskReader().Read(ctx) @@ -400,9 +400,23 @@ func ApplyFlytePodConfiguration(ctx context.Context, tCtx pluginsCore.TaskExecut ApplyGPUNodeSelectors(podSpec, extendedResources.GetGpuAccelerator()) } + // Override container image if necessary + if len(tCtx.TaskExecutionMetadata().GetOverrides().GetContainerImage()) > 0 { + ApplyContainerImageOverride(podSpec, tCtx.TaskExecutionMetadata().GetOverrides().GetContainerImage(), primaryContainerName) + } + return podSpec, objectMeta, nil } +func ApplyContainerImageOverride(podSpec *v1.PodSpec, containerImage string, primaryContainerName string) { + for i, c := range podSpec.Containers { + if c.Name == primaryContainerName { + podSpec.Containers[i].Image = containerImage + return + } + } +} + // ToK8sPodSpec builds a PodSpec and ObjectMeta based on the definition passed by the TaskExecutionContext. This // involves parsing the raw PodSpec definition and applying all Flyte configuration options. func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, *metav1.ObjectMeta, string, error) { diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index b12813b387..08b6c5b9d0 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -26,7 +26,7 @@ import ( "github.com/flyteorg/flyte/flytestdlib/storage" ) -func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionMetadata { +func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionMetadata { taskExecutionMetadata := &pluginsCoreMock.TaskExecutionMetadata{} taskExecutionMetadata.On("GetNamespace").Return("test-namespace") taskExecutionMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) @@ -52,6 +52,7 @@ func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements, extendedReso to := &pluginsCoreMock.TaskOverrides{} to.On("GetResources").Return(resources) to.On("GetExtendedResources").Return(extendedResources) + to.On("GetContainerImage").Return(containerImage) taskExecutionMetadata.On("GetOverrides").Return(to) taskExecutionMetadata.On("IsInterruptible").Return(true) taskExecutionMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) @@ -79,7 +80,7 @@ func dummyInputReader() io.InputReader { return inputReader } -func dummyExecContext(taskTemplate *core.TaskTemplate, r *v1.ResourceRequirements, rm *core.ExtendedResources) pluginsCore.TaskExecutionContext { +func dummyExecContext(taskTemplate *core.TaskTemplate, r *v1.ResourceRequirements, rm *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext { ow := &pluginsIOMock.OutputWriter{} ow.OnGetOutputPrefixPath().Return("") ow.OnGetRawOutputPrefix().Return("") @@ -87,7 +88,7 @@ func dummyExecContext(taskTemplate *core.TaskTemplate, r *v1.ResourceRequirement ow.OnGetPreviousCheckpointsPrefix().Return("/prev") tCtx := &pluginsCoreMock.TaskExecutionContext{} - tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(r, rm)) + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(r, rm, containerImage)) tCtx.OnInputReader().Return(dummyInputReader()) tCtx.OnOutputWriter().Return(ow) @@ -700,7 +701,7 @@ func updatePod(t *testing.T) { v1.ResourceCPU: resource.MustParse("1024m"), v1.ResourceEphemeralStorage: resource.MustParse("100M"), }, - }, nil) + }, nil, "") pod := &v1.Pod{ Spec: v1.PodSpec{ @@ -753,7 +754,7 @@ func updatePod(t *testing.T) { } func TestUpdatePodWithDefaultAffinityAndInterruptibleNodeSelectorRequirement(t *testing.T) { - taskExecutionMetadata := dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil) + taskExecutionMetadata := dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil, "") assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ DefaultAffinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ @@ -817,7 +818,7 @@ func toK8sPodInterruptible(t *testing.T) { v1.ResourceCPU: resource.MustParse("1024m"), v1.ResourceEphemeralStorage: resource.MustParse("100M"), }, - }, nil) + }, nil, "") p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) @@ -884,7 +885,7 @@ func TestToK8sPod(t *testing.T) { v1.ResourceCPU: resource.MustParse("1024m"), v1.ResourceEphemeralStorage: resource.MustParse("100M"), }, - }, nil) + }, nil, "") p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) @@ -901,7 +902,7 @@ func TestToK8sPod(t *testing.T) { v1.ResourceCPU: resource.MustParse("1024m"), v1.ResourceEphemeralStorage: resource.MustParse("100M"), }, - }, nil) + }, nil, "") p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) @@ -919,7 +920,7 @@ func TestToK8sPod(t *testing.T) { v1.ResourceCPU: resource.MustParse("1024m"), v1.ResourceEphemeralStorage: resource.MustParse("100M"), }, - }, nil) + }, nil, "") assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ DefaultNodeSelector: map[string]string{ @@ -946,7 +947,7 @@ func TestToK8sPod(t *testing.T) { }, })) - x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil) + x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil, "") p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.NotNil(t, p.SecurityContext) @@ -958,7 +959,7 @@ func TestToK8sPod(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ EnableHostNetworkingPod: &enabled, })) - x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil) + x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil, "") p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.True(t, p.HostNetwork) @@ -969,7 +970,7 @@ func TestToK8sPod(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ EnableHostNetworkingPod: &enabled, })) - x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil) + x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil, "") p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.False(t, p.HostNetwork) @@ -977,7 +978,7 @@ func TestToK8sPod(t *testing.T) { t.Run("skipSettingHostNetwork", func(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) - x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil) + x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil, "") p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.False(t, p.HostNetwork) @@ -1011,7 +1012,7 @@ func TestToK8sPod(t *testing.T) { }, })) - x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil) + x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil, "") p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.NotNil(t, p.DNSConfig) @@ -1032,7 +1033,7 @@ func TestToK8sPod(t *testing.T) { "foo": "bar", }, })) - x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil) + x := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{}, nil, "") p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) for _, c := range p.Containers { @@ -1047,6 +1048,18 @@ func TestToK8sPod(t *testing.T) { }) } +func TestToK8sPodContainerImage(t *testing.T) { + t.Run("Override container image", func(t *testing.T) { + taskContext := dummyExecContext(dummyTaskTemplate(), &v1.ResourceRequirements{ + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1024m"), + }}, nil, "foo:latest") + p, _, _, err := ToK8sPodSpec(context.TODO(), taskContext) + assert.NoError(t, err) + assert.Equal(t, "foo:latest", p.Containers[0].Image) + }) +} + func TestToK8sPodExtendedResources(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ GpuDeviceNodeLabel: "gpu-node-label", @@ -1152,7 +1165,7 @@ func TestToK8sPodExtendedResources(t *testing.T) { t.Run(f.name, func(t *testing.T) { taskTemplate := dummyTaskTemplate() taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyExecContext(taskTemplate, f.resources, f.extendedResourcesOverride) + taskContext := dummyExecContext(taskTemplate, f.resources, f.extendedResourcesOverride, "") p, _, _, err := ToK8sPodSpec(context.TODO(), taskContext) assert.NoError(t, err) @@ -1790,7 +1803,7 @@ func TestGetPodTemplate(t *testing.T) { taskReader.On("Read", mock.Anything).Return(task, nil) tCtx := &pluginsCoreMock.TaskExecutionContext{} - tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil)) + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil, "")) tCtx.OnTaskReader().Return(taskReader) // initialize PodTemplateStore @@ -1816,7 +1829,7 @@ func TestGetPodTemplate(t *testing.T) { taskReader.On("Read", mock.Anything).Return(task, nil) tCtx := &pluginsCoreMock.TaskExecutionContext{} - tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil)) + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil, "")) tCtx.OnTaskReader().Return(taskReader) // initialize PodTemplateStore @@ -1843,7 +1856,7 @@ func TestGetPodTemplate(t *testing.T) { taskReader.On("Read", mock.Anything).Return(task, nil) tCtx := &pluginsCoreMock.TaskExecutionContext{} - tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil)) + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil, "")) tCtx.OnTaskReader().Return(taskReader) // initialize PodTemplateStore @@ -1871,7 +1884,7 @@ func TestGetPodTemplate(t *testing.T) { taskReader.On("Read", mock.Anything).Return(task, nil) tCtx := &pluginsCoreMock.TaskExecutionContext{} - tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil)) + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil, "")) tCtx.OnTaskReader().Return(taskReader) // initialize PodTemplateStore @@ -1913,7 +1926,7 @@ func TestMergeWithBasePodTemplate(t *testing.T) { taskReader.On("Read", mock.Anything).Return(task, nil) tCtx := &pluginsCoreMock.TaskExecutionContext{} - tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil)) + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil, "")) tCtx.OnTaskReader().Return(taskReader) resultPodSpec, resultObjectMeta, err := MergeWithBasePodTemplate(context.TODO(), tCtx, &podSpec, &objectMeta, "foo") @@ -1967,7 +1980,7 @@ func TestMergeWithBasePodTemplate(t *testing.T) { taskReader.On("Read", mock.Anything).Return(task, nil) tCtx := &pluginsCoreMock.TaskExecutionContext{} - tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil)) + tCtx.OnTaskExecutionMetadata().Return(dummyTaskExecutionMetadata(&v1.ResourceRequirements{}, nil, "")) tCtx.OnTaskReader().Return(taskReader) resultPodSpec, resultObjectMeta, err := MergeWithBasePodTemplate(context.TODO(), tCtx, &podSpec, &objectMeta, "foo") diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go index 2bd1d5eefe..0eaed65467 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go @@ -94,6 +94,7 @@ func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.Ta }, }) overrides.OnGetExtendedResources().Return(nil) + overrides.OnGetContainerImage().Return("") tMeta := &mocks.TaskExecutionMetadata{} tMeta.OnGetTaskExecutionID().Return(tID) diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go index e7d43a2256..86ca034e7b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go @@ -194,6 +194,7 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc overrides := &mocks.TaskOverrides{} overrides.OnGetResources().Return(resources) overrides.OnGetExtendedResources().Return(extendedResources) + overrides.OnGetContainerImage().Return("") taskExecutionMetadata.OnGetOverrides().Return(overrides) taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) return taskCtx diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 29fe47a446..3c823510ac 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -153,6 +153,7 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso overrides := &mocks.TaskOverrides{} overrides.OnGetResources().Return(resources) overrides.OnGetExtendedResources().Return(extendedResources) + overrides.OnGetContainerImage().Return("") taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 0700644578..aea096e4ba 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -123,7 +123,7 @@ func dummyPytorchTaskTemplate(id string, args ...interface{}) *core.TaskTemplate } } -func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext { +func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -159,6 +159,7 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1. overrides := &mocks.TaskOverrides{} overrides.OnGetResources().Return(resources) overrides.OnGetExtendedResources().Return(extendedResources) + overrides.OnGetContainerImage().Return(containerImage) taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) @@ -279,7 +280,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl ptObj := dummyPytorchCustomObj(workers) taskTemplate := dummyPytorchTaskTemplate("job1", ptObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) if err != nil { panic(err) } @@ -307,7 +308,7 @@ func TestBuildResourcePytorchElastic(t *testing.T) { ptObj := dummyElasticPytorchCustomObj(2, plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}) taskTemplate := dummyPytorchTaskTemplate("job2", ptObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.NoError(t, err) assert.NotNil(t, resource) @@ -350,7 +351,7 @@ func TestBuildResourcePytorch(t *testing.T) { ptObj := dummyPytorchCustomObj(100) taskTemplate := dummyPytorchTaskTemplate("job3", ptObj) - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.NoError(t, err) assert.NotNil(t, res) @@ -386,6 +387,74 @@ func TestBuildResourcePytorch(t *testing.T) { } } +func TestBuildResourcePytorchContainerImage(t *testing.T) { + assert.NoError(t, flytek8sConfig.SetK8sPluginConfig(&flytek8sConfig.K8sPluginConfig{})) + + fixtures := []struct { + name string + resources *corev1.ResourceRequirements + containerImageOverride string + }{ + { + "without overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "", + }, + { + "with overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "container-image-override", + }, + } + + testConfigs := []struct { + name string + plugin *plugins.DistributedPyTorchTrainingTask + }{ + { + "pytorch", + dummyPytorchCustomObj(100), + }, + { + "elastic pytorch", + dummyElasticPytorchCustomObj(2, plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}), + }, + } + + for _, tCfg := range testConfigs { + for _, f := range fixtures { + t.Run(tCfg.name+" "+f.name, func(t *testing.T) { + taskTemplate := dummyPytorchTaskTemplate("job", tCfg.plugin) + taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride) + pytorchResourceHandler := pytorchOperatorResourceHandler{} + r, err := pytorchResourceHandler.BuildResource(context.TODO(), taskContext) + assert.NoError(t, err) + assert.NotNil(t, r) + pytorchJob, ok := r.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) + + for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs { + var expectedContainerImage string + if len(f.containerImageOverride) > 0 { + expectedContainerImage = f.containerImageOverride + } else { + expectedContainerImage = testImage + } + assert.Equal(t, expectedContainerImage, replicaSpec.Template.Spec.Containers[0].Image) + } + }) + } + } +} + func TestBuildResourcePytorchExtendedResources(t *testing.T) { assert.NoError(t, flytek8sConfig.SetK8sPluginConfig(&flytek8sConfig.K8sPluginConfig{ GpuDeviceNodeLabel: "gpu-node-label", @@ -506,7 +575,7 @@ func TestBuildResourcePytorchExtendedResources(t *testing.T) { t.Run(tCfg.name+" "+f.name, func(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job", tCfg.plugin) taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride) + taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, "") pytorchResourceHandler := pytorchOperatorResourceHandler{} r, err := pytorchResourceHandler.BuildResource(context.TODO(), taskContext) assert.NoError(t, err) @@ -539,7 +608,7 @@ func TestGetTaskPhase(t *testing.T) { return dummyPytorchJobResource(pytorchResourceHandler, 2, conditionType) } - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil) + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "") taskPhase, err := pytorchResourceHandler.GetTaskPhase(ctx, taskCtx, dummyPytorchJobResourceCreator(commonOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -582,7 +651,7 @@ func TestGetLogs(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil) + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "") jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) @@ -602,7 +671,7 @@ func TestGetLogsElastic(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil) + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "") jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) @@ -633,7 +702,7 @@ func TestReplicaCounts(t *testing.T) { ptObj := dummyPytorchCustomObj(test.workerReplicaCount) taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) if test.expectError { assert.Error(t, err) assert.Nil(t, res) @@ -715,7 +784,7 @@ func TestBuildResourcePytorchV1(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.NoError(t, err) assert.NotNil(t, res) @@ -759,7 +828,7 @@ func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.NoError(t, err) assert.NotNil(t, res) @@ -819,7 +888,7 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.NoError(t, err) assert.NotNil(t, res) @@ -890,7 +959,7 @@ func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.NoError(t, err) assert.NotNil(t, res) @@ -912,7 +981,7 @@ func TestBuildResourcePytorchV1WithElastic(t *testing.T) { taskTemplate.TaskTypeVersion = 1 pytorchResourceHandler := pytorchOperatorResourceHandler{} - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.NoError(t, err) assert.NotNil(t, resource) @@ -948,7 +1017,7 @@ func TestBuildResourcePytorchV1WithZeroWorker(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - _, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + _, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.Error(t, err) } @@ -965,7 +1034,7 @@ func TestGetReplicaCount(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} tfObj := dummyPytorchCustomObj(1) taskTemplate := dummyPytorchTaskTemplate("the job", tfObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.NoError(t, err) assert.NotNil(t, resource) PytorchJob, ok := resource.(*kubeflowv1.PyTorchJob) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 2402bd4c5a..80e95871d1 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -154,6 +154,7 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *core overrides := &mocks.TaskOverrides{} overrides.OnGetResources().Return(resources) overrides.OnGetExtendedResources().Return(extendedResources) + overrides.OnGetContainerImage().Return("") taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go b/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go index e69f764549..9c098bd708 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go @@ -79,7 +79,7 @@ func dummyContainerTaskTemplateWithPodSpec(command []string, args []string) *cor return taskTemplate } -func dummyContainerTaskMetadata(resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, returnsServiceAccount bool) pluginsCore.TaskExecutionMetadata { +func dummyContainerTaskMetadata(resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, returnsServiceAccount bool, containerImage string) pluginsCore.TaskExecutionMetadata { taskMetadata := &pluginsCoreMock.TaskExecutionMetadata{} taskMetadata.On("GetNamespace").Return("test-namespace") taskMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) @@ -118,6 +118,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements, extendedReso to := &pluginsCoreMock.TaskOverrides{} to.On("GetResources").Return(resources) to.On("GetExtendedResources").Return(extendedResources) + to.OnGetContainerImage().Return(containerImage) taskMetadata.On("GetOverrides").Return(to) taskMetadata.On("IsInterruptible").Return(true) taskMetadata.On("GetEnvironmentVariables").Return(nil) @@ -176,19 +177,19 @@ func TestContainerTaskExecutor_BuildResource(t *testing.T) { { name: "BuildResource", taskTemplate: dummyContainerTaskTemplate(command, args), - taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, true), + taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, true, ""), expectServiceAccount: serviceAccount, }, { name: "BuildResource_PodTemplate", taskTemplate: dummyContainerTaskTemplateWithPodSpec(command, args), - taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, true), + taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, true, ""), expectServiceAccount: podTemplateServiceAccount, }, { name: "BuildResource_SecurityContext", taskTemplate: dummyContainerTaskTemplate(command, args), - taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, false), + taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, false, ""), expectServiceAccount: securityContextServiceAccount, }, } @@ -321,7 +322,7 @@ func TestContainerTaskExecutor_BuildResource_ExtendedResources(t *testing.T) { t.Run(f.name, func(t *testing.T) { taskTemplate := dummyContainerTaskTemplate([]string{"command"}, []string{"{{.Input}}"}) taskTemplate.ExtendedResources = f.extendedResourcesBase - taskMetadata := dummyContainerTaskMetadata(f.resources, f.extendedResourcesOverride, true) + taskMetadata := dummyContainerTaskMetadata(f.resources, f.extendedResourcesOverride, true, "") taskContext := dummyContainerTaskContext(taskTemplate, taskMetadata) r, err := DefaultPodPlugin.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -343,11 +344,54 @@ func TestContainerTaskExecutor_BuildResource_ExtendedResources(t *testing.T) { } } +func TestContainerTaskExecutor_BuildResource_ContainerImage(t *testing.T) { + assert.NoError(t, flytek8sConfig.SetK8sPluginConfig(&flytek8sConfig.K8sPluginConfig{})) + + fixtures := []struct { + name string + resources *v1.ResourceRequirements + containerImageOverride string + }{ + { + "without overrides", + &v1.ResourceRequirements{ + Limits: v1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "", + }, + { + "with overrides", + &v1.ResourceRequirements{ + Limits: v1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "test-image", + }, + } + + for _, f := range fixtures { + t.Run(f.name, func(t *testing.T) { + taskTemplate := dummyContainerTaskTemplate([]string{"command"}, []string{"{{.Input}}"}) + taskMetadata := dummyContainerTaskMetadata(f.resources, nil, true, f.containerImageOverride) + taskContext := dummyContainerTaskContext(taskTemplate, taskMetadata) + r, err := DefaultPodPlugin.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + _, ok := r.(*v1.Pod) + assert.True(t, ok) + assert.Equal(t, f.containerImageOverride, r.(*v1.Pod).Spec.Containers[0].Image) + }) + } +} + func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { command := []string{"command"} args := []string{"{{.Input}}"} taskTemplate := dummyContainerTaskTemplate(command, args) - taskMetadata := dummyContainerTaskMetadata(containerResourceRequirements, nil, true) + taskMetadata := dummyContainerTaskMetadata(containerResourceRequirements, nil, true, "") taskCtx := dummyContainerTaskContext(taskTemplate, taskMetadata) j := &v1.Pod{ @@ -437,7 +481,7 @@ func TestContainerTaskExecutor_GetTaskStatus_InvalidImageName(t *testing.T) { command := []string{"command"} args := []string{"{{.Input}}"} taskTemplate := dummyContainerTaskTemplate(command, args) - taskMetadata := dummyContainerTaskMetadata(containerResourceRequirements, nil, true) + taskMetadata := dummyContainerTaskMetadata(containerResourceRequirements, nil, true, "") taskCtx := dummyContainerTaskContext(taskTemplate, taskMetadata) ctx := context.TODO() diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go b/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go index a9848f3276..06166d15fb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go @@ -90,6 +90,7 @@ func dummySidecarTaskMetadata(resources *v1.ResourceRequirements, extendedResour to := &pluginsCoreMock.TaskOverrides{} to.On("GetResources").Return(resources) to.On("GetExtendedResources").Return(extendedResources) + to.On("GetContainerImage").Return("") taskMetadata.On("GetOverrides").Return(to) taskMetadata.On("GetEnvironmentVariables").Return(nil) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 02ed83db14..69f65c19ef 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -106,7 +106,7 @@ func dummyRayTaskTemplate(id string, rayJob *plugins.RayJob) *core.TaskTemplate } } -func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext { +func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -141,6 +141,7 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso overrides := &mocks.TaskOverrides{} overrides.OnGetResources().Return(resources) overrides.OnGetExtendedResources().Return(extendedResources) + overrides.OnGetContainerImage().Return(containerImage) taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) @@ -175,7 +176,7 @@ func TestBuildResourceRay(t *testing.T) { err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) assert.Nil(t, err) - RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil)) + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.Nil(t, err) assert.NotNil(t, RayResource) @@ -210,6 +211,63 @@ func TestBuildResourceRay(t *testing.T) { assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) } +func TestBuildResourceRayContainerImage(t *testing.T) { + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) + + fixtures := []struct { + name string + resources *corev1.ResourceRequirements + containerImageOverride string + }{ + { + "without overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "", + }, + { + "with overrides", + &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + "container-image-override", + }, + } + + for _, f := range fixtures { + t.Run(f.name, func(t *testing.T) { + taskTemplate := dummyRayTaskTemplate("id", dummyRayCustomObj()) + taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride) + rayJobResourceHandler := rayJobResourceHandler{} + r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + rayJob, ok := r.(*rayv1alpha1.RayJob) + assert.True(t, ok) + + var expectedContainerImage string + if len(f.containerImageOverride) > 0 { + expectedContainerImage = f.containerImageOverride + } else { + expectedContainerImage = testImage + } + + // Head node + headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec + assert.Equal(t, expectedContainerImage, headNodeSpec.Containers[0].Image) + + // Worker node + workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec + assert.Equal(t, expectedContainerImage, workerNodeSpec.Containers[0].Image) + }) + } +} + func TestBuildResourceRayExtendedResources(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ GpuDeviceNodeLabel: "gpu-node-label", @@ -315,7 +373,7 @@ func TestBuildResourceRayExtendedResources(t *testing.T) { t.Run(p.name, func(t *testing.T) { taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) taskTemplate.ExtendedResources = p.extendedResourcesBase - taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride) + taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "") rayJobResourceHandler := rayJobResourceHandler{} r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -374,7 +432,7 @@ func TestDefaultStartParameters(t *testing.T) { err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) assert.Nil(t, err) - RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil)) + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "")) assert.Nil(t, err) assert.NotNil(t, RayResource) @@ -582,7 +640,7 @@ func TestInjectLogsSidecar(t *testing.T) { assert.NoError(t, SetConfig(&Config{ LogsSidecar: p.logsSidecarCfg, })) - taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil) + taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "") rayJobResourceHandler := rayJobResourceHandler{} r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 264f9514e9..561901226a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -385,6 +385,7 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) overrides.On("GetResources").Return(&corev1.ResourceRequirements{}) // No support for GPUs, and consequently, ExtendedResources on Spark plugin. overrides.On("GetExtendedResources").Return(nil) + overrides.OnGetContainerImage().Return("") taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.On("GetTaskExecutionID").Return(tID) diff --git a/flyteplugins/tests/end_to_end.go b/flyteplugins/tests/end_to_end.go index 037ae877d9..1a55b52c0c 100644 --- a/flyteplugins/tests/end_to_end.go +++ b/flyteplugins/tests/end_to_end.go @@ -147,6 +147,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i Limits: map[v1.ResourceName]resource.Quantity{}, }) overrides.OnGetExtendedResources().Return(&idlCore.ExtendedResources{}) + overrides.OnGetContainerImage().Return("") tMeta := &coreMocks.TaskExecutionMetadata{} tMeta.OnGetTaskExecutionID().Return(tID) diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go index 6590aaa04a..664ebb6767 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -442,6 +442,7 @@ type ExecutableNode interface { GetActiveDeadline() *time.Duration IsInterruptible() *bool GetName() string + GetContainerImage() string } // ExecutableWorkflowStatus is an interface for the Workflow p. This is the mutable portion for a Workflow diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go index ae1663bba1..e6a03b7deb 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go @@ -154,6 +154,38 @@ func (_m *ExecutableNode) GetConfig() *v1.ConfigMap { return r0 } +type ExecutableNode_GetContainerImage struct { + *mock.Call +} + +func (_m ExecutableNode_GetContainerImage) Return(_a0 string) *ExecutableNode_GetContainerImage { + return &ExecutableNode_GetContainerImage{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableNode) OnGetContainerImage() *ExecutableNode_GetContainerImage { + c_call := _m.On("GetContainerImage") + return &ExecutableNode_GetContainerImage{Call: c_call} +} + +func (_m *ExecutableNode) OnGetContainerImageMatch(matchers ...interface{}) *ExecutableNode_GetContainerImage { + c_call := _m.On("GetContainerImage", matchers...) + return &ExecutableNode_GetContainerImage{Call: c_call} +} + +// GetContainerImage provides a mock function with given fields: +func (_m *ExecutableNode) GetContainerImage() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + type ExecutableNode_GetExecutionDeadline struct { *mock.Call } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go index 8a22cae974..6554357031 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go @@ -175,6 +175,8 @@ type NodeSpec struct { // The value set to True means task is OK with getting interrupted // +optional Interruptible *bool `json:"interruptible,omitempty"` + + ContainerImage string `json:"containerImage,omitempty"` } func (in *NodeSpec) GetName() string { @@ -273,3 +275,7 @@ func (in *NodeSpec) IsEndNode() bool { func (in *NodeSpec) GetInputBindings() []*Binding { return in.InputBindings } + +func (in *NodeSpec) GetContainerImage() string { + return in.ContainerImage +} diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node.go b/flytepropeller/pkg/compiler/transformers/k8s/node.go index 7120f3ad90..7b5df9b3b5 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/node.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/node.go @@ -30,6 +30,7 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile var task *core.TaskTemplate var resources *core.Resources var extendedResources *v1alpha1.ExtendedResources + var containerImage string if n.GetTaskNode() != nil { taskID := n.GetTaskNode().GetReferenceId().String() // TODO: Use task index for quick lookup @@ -55,6 +56,10 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile ExtendedResources: overrides.GetExtendedResources(), } } + + if len(overrides.GetContainerImage()) > 0 { + containerImage = overrides.GetContainerImage() + } } } @@ -96,6 +101,7 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile InputBindings: toBindingValueArray(n.GetInputs()), ActiveDeadline: activeDeadline, Interruptible: interruptible, + ContainerImage: containerImage, } switch v := n.GetTarget().(type) { diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go index f2d6c79494..ff1f263b6e 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go @@ -140,6 +140,24 @@ func TestBuildNodeSpec(t *testing.T) { assert.Equal(t, expectedGpuDevice, spec.GetExtendedResources().GetGpuAccelerator().GetDevice()) }) + t.Run("node with container image override", func(t *testing.T) { + expectedContainerImage := "test-image:latest" + n.Node.Target = &core.Node_TaskNode{ + TaskNode: &core.TaskNode{ + Reference: &core.TaskNode_ReferenceId{ + ReferenceId: &core.Identifier{Name: "ref_2"}, + }, + Overrides: &core.TaskNodeOverrides{ + ContainerImage: expectedContainerImage, + }, + }, + } + + spec := mustBuild(t, n, 1, errs.NewScope()) + assert.NotNil(t, spec.GetContainerImage()) + assert.Equal(t, expectedContainerImage, spec.GetContainerImage()) + }) + t.Run("LaunchPlanRef", func(t *testing.T) { n.Node.Target = &core.Node_WorkflowNode{ WorkflowNode: &core.WorkflowNode{