From ae8aed53bbe584cc3fd5314f34844420b1bf8eec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Mon, 10 Apr 2023 10:39:14 +0200 Subject: [PATCH] Add elastic config message type for torchrun training Signed-off-by: Fabio Graetz --- gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc | 741 ++++++--- gen/pb-cpp/flyteidl/plugins/pytorch.pb.h | 422 ++++-- gen/pb-go/flyteidl/plugins/pytorch.pb.go | 158 +- .../flyteidl/plugins/pytorch.pb.validate.go | 91 +- gen/pb-java/flyteidl/plugins/Pytorch.java | 1323 ++++++++++++----- gen/pb_python/flyteidl/plugins/pytorch_pb2.py | 8 +- .../flyteidl/plugins/pytorch_pb2.pyi | 32 +- gen/pb_rust/flyteidl.plugins.rs | 30 +- protos/flyteidl/plugins/pytorch.proto | 18 +- 9 files changed, 2063 insertions(+), 760 deletions(-) diff --git a/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc b/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc index 4f4096218..20bc82696 100644 --- a/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc +++ b/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc @@ -16,14 +16,33 @@ // @@protoc_insertion_point(includes) #include +extern PROTOBUF_INTERNAL_EXPORT_flyteidl_2fplugins_2fpytorch_2eproto ::google::protobuf::internal::SCCInfo<0> scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto; namespace flyteidl { namespace plugins { +class ElasticConfigDefaultTypeInternal { + public: + ::google::protobuf::internal::ExplicitlyConstructed _instance; +} _ElasticConfig_default_instance_; class DistributedPyTorchTrainingTaskDefaultTypeInternal { public: ::google::protobuf::internal::ExplicitlyConstructed _instance; } _DistributedPyTorchTrainingTask_default_instance_; } // namespace plugins } // namespace flyteidl +static void InitDefaultsElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::flyteidl::plugins::_ElasticConfig_default_instance_; + new (ptr) ::flyteidl::plugins::ElasticConfig(); + ::google::protobuf::internal::OnShutdownDestroyMessage(ptr); + } + ::flyteidl::plugins::ElasticConfig::InitAsDefaultInstance(); +} + +::google::protobuf::internal::SCCInfo<0> scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto = + {{ATOMIC_VAR_INIT(::google::protobuf::internal::SCCInfoBase::kUninitialized), 0, InitDefaultsElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto}, {}}; + static void InitDefaultsDistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto() { GOOGLE_PROTOBUF_VERIFY_VERSION; @@ -35,57 +54,69 @@ static void InitDefaultsDistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpyto ::flyteidl::plugins::DistributedPyTorchTrainingTask::InitAsDefaultInstance(); } -::google::protobuf::internal::SCCInfo<0> scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto = - {{ATOMIC_VAR_INIT(::google::protobuf::internal::SCCInfoBase::kUninitialized), 0, InitDefaultsDistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto}, {}}; +::google::protobuf::internal::SCCInfo<1> scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto = + {{ATOMIC_VAR_INIT(::google::protobuf::internal::SCCInfoBase::kUninitialized), 1, InitDefaultsDistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto}, { + &scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto.base,}}; void InitDefaults_flyteidl_2fplugins_2fpytorch_2eproto() { + ::google::protobuf::internal::InitSCC(&scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto.base); ::google::protobuf::internal::InitSCC(&scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto.base); } -::google::protobuf::Metadata file_level_metadata_flyteidl_2fplugins_2fpytorch_2eproto[1]; +::google::protobuf::Metadata file_level_metadata_flyteidl_2fplugins_2fpytorch_2eproto[2]; constexpr ::google::protobuf::EnumDescriptor const** file_level_enum_descriptors_flyteidl_2fplugins_2fpytorch_2eproto = nullptr; constexpr ::google::protobuf::ServiceDescriptor const** file_level_service_descriptors_flyteidl_2fplugins_2fpytorch_2eproto = nullptr; const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2fpytorch_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, rdzv_backend_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, min_replicas_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, max_replicas_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, nproc_per_node_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, max_restarts_), ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, _internal_metadata_), ~0u, // no _extensions_ ~0u, // no _oneof_case_ ~0u, // no _weak_field_map_ PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, workers_), - PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, rdzvbackend_), - PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, minreplicas_), - PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, maxreplicas_), - PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, nprocpernode_), - PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, maxrestarts_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, elastic_config_), }; static const ::google::protobuf::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { - { 0, -1, sizeof(::flyteidl::plugins::DistributedPyTorchTrainingTask)}, + { 0, -1, sizeof(::flyteidl::plugins::ElasticConfig)}, + { 10, -1, sizeof(::flyteidl::plugins::DistributedPyTorchTrainingTask)}, }; static ::google::protobuf::Message const * const file_default_instances[] = { + reinterpret_cast(&::flyteidl::plugins::_ElasticConfig_default_instance_), reinterpret_cast(&::flyteidl::plugins::_DistributedPyTorchTrainingTask_default_instance_), }; ::google::protobuf::internal::AssignDescriptorsTable assign_descriptors_table_flyteidl_2fplugins_2fpytorch_2eproto = { {}, AddDescriptors_flyteidl_2fplugins_2fpytorch_2eproto, "flyteidl/plugins/pytorch.proto", schemas, file_default_instances, TableStruct_flyteidl_2fplugins_2fpytorch_2eproto::offsets, - file_level_metadata_flyteidl_2fplugins_2fpytorch_2eproto, 1, file_level_enum_descriptors_flyteidl_2fplugins_2fpytorch_2eproto, file_level_service_descriptors_flyteidl_2fplugins_2fpytorch_2eproto, + file_level_metadata_flyteidl_2fplugins_2fpytorch_2eproto, 2, file_level_enum_descriptors_flyteidl_2fplugins_2fpytorch_2eproto, file_level_service_descriptors_flyteidl_2fplugins_2fpytorch_2eproto, }; const char descriptor_table_protodef_flyteidl_2fplugins_2fpytorch_2eproto[] = "\n\036flyteidl/plugins/pytorch.proto\022\020flytei" - "dl.plugins\"\233\001\n\036DistributedPyTorchTrainin" - "gTask\022\017\n\007workers\030\001 \001(\005\022\023\n\013RDZVBackend\030\002 " - "\001(\t\022\023\n\013minReplicas\030\003 \001(\005\022\023\n\013maxReplicas\030" - "\004 \001(\005\022\024\n\014nProcPerNode\030\005 \001(\005\022\023\n\013maxRestar" - "ts\030\006 \001(\005B9Z7github.com/flyteorg/flyteidl" - "/gen/pb-go/flyteidl/pluginsb\006proto3" + "dl.plugins\"\177\n\rElasticConfig\022\024\n\014rdzv_back" + "end\030\001 \001(\t\022\024\n\014min_replicas\030\002 \001(\005\022\024\n\014max_r" + "eplicas\030\003 \001(\005\022\026\n\016nproc_per_node\030\004 \001(\005\022\024\n" + "\014max_restarts\030\005 \001(\005\"j\n\036DistributedPyTorc" + "hTrainingTask\022\017\n\007workers\030\001 \001(\005\0227\n\016elasti" + "c_config\030\002 \001(\0132\037.flyteidl.plugins.Elasti" + "cConfigB9Z7github.com/flyteorg/flyteidl/" + "gen/pb-go/flyteidl/pluginsb\006proto3" ; ::google::protobuf::internal::DescriptorTable descriptor_table_flyteidl_2fplugins_2fpytorch_2eproto = { false, InitDefaults_flyteidl_2fplugins_2fpytorch_2eproto, descriptor_table_protodef_flyteidl_2fplugins_2fpytorch_2eproto, - "flyteidl/plugins/pytorch.proto", &assign_descriptors_table_flyteidl_2fplugins_2fpytorch_2eproto, 275, + "flyteidl/plugins/pytorch.proto", &assign_descriptors_table_flyteidl_2fplugins_2fpytorch_2eproto, 354, }; void AddDescriptors_flyteidl_2fplugins_2fpytorch_2eproto() { @@ -102,84 +133,83 @@ namespace plugins { // =================================================================== -void DistributedPyTorchTrainingTask::InitAsDefaultInstance() { +void ElasticConfig::InitAsDefaultInstance() { } -class DistributedPyTorchTrainingTask::HasBitSetters { +class ElasticConfig::HasBitSetters { public: }; #if !defined(_MSC_VER) || _MSC_VER >= 1900 -const int DistributedPyTorchTrainingTask::kWorkersFieldNumber; -const int DistributedPyTorchTrainingTask::kRDZVBackendFieldNumber; -const int DistributedPyTorchTrainingTask::kMinReplicasFieldNumber; -const int DistributedPyTorchTrainingTask::kMaxReplicasFieldNumber; -const int DistributedPyTorchTrainingTask::kNProcPerNodeFieldNumber; -const int DistributedPyTorchTrainingTask::kMaxRestartsFieldNumber; +const int ElasticConfig::kRdzvBackendFieldNumber; +const int ElasticConfig::kMinReplicasFieldNumber; +const int ElasticConfig::kMaxReplicasFieldNumber; +const int ElasticConfig::kNprocPerNodeFieldNumber; +const int ElasticConfig::kMaxRestartsFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 -DistributedPyTorchTrainingTask::DistributedPyTorchTrainingTask() +ElasticConfig::ElasticConfig() : ::google::protobuf::Message(), _internal_metadata_(nullptr) { SharedCtor(); - // @@protoc_insertion_point(constructor:flyteidl.plugins.DistributedPyTorchTrainingTask) + // @@protoc_insertion_point(constructor:flyteidl.plugins.ElasticConfig) } -DistributedPyTorchTrainingTask::DistributedPyTorchTrainingTask(const DistributedPyTorchTrainingTask& from) +ElasticConfig::ElasticConfig(const ElasticConfig& from) : ::google::protobuf::Message(), _internal_metadata_(nullptr) { _internal_metadata_.MergeFrom(from._internal_metadata_); - rdzvbackend_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); - if (from.rdzvbackend().size() > 0) { - rdzvbackend_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.rdzvbackend_); + rdzv_backend_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + if (from.rdzv_backend().size() > 0) { + rdzv_backend_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.rdzv_backend_); } - ::memcpy(&workers_, &from.workers_, - static_cast(reinterpret_cast(&maxrestarts_) - - reinterpret_cast(&workers_)) + sizeof(maxrestarts_)); - // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.DistributedPyTorchTrainingTask) + ::memcpy(&min_replicas_, &from.min_replicas_, + static_cast(reinterpret_cast(&max_restarts_) - + reinterpret_cast(&min_replicas_)) + sizeof(max_restarts_)); + // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.ElasticConfig) } -void DistributedPyTorchTrainingTask::SharedCtor() { +void ElasticConfig::SharedCtor() { ::google::protobuf::internal::InitSCC( - &scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto.base); - rdzvbackend_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); - ::memset(&workers_, 0, static_cast( - reinterpret_cast(&maxrestarts_) - - reinterpret_cast(&workers_)) + sizeof(maxrestarts_)); + &scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto.base); + rdzv_backend_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + ::memset(&min_replicas_, 0, static_cast( + reinterpret_cast(&max_restarts_) - + reinterpret_cast(&min_replicas_)) + sizeof(max_restarts_)); } -DistributedPyTorchTrainingTask::~DistributedPyTorchTrainingTask() { - // @@protoc_insertion_point(destructor:flyteidl.plugins.DistributedPyTorchTrainingTask) +ElasticConfig::~ElasticConfig() { + // @@protoc_insertion_point(destructor:flyteidl.plugins.ElasticConfig) SharedDtor(); } -void DistributedPyTorchTrainingTask::SharedDtor() { - rdzvbackend_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +void ElasticConfig::SharedDtor() { + rdzv_backend_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } -void DistributedPyTorchTrainingTask::SetCachedSize(int size) const { +void ElasticConfig::SetCachedSize(int size) const { _cached_size_.Set(size); } -const DistributedPyTorchTrainingTask& DistributedPyTorchTrainingTask::default_instance() { - ::google::protobuf::internal::InitSCC(&::scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto.base); +const ElasticConfig& ElasticConfig::default_instance() { + ::google::protobuf::internal::InitSCC(&::scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto.base); return *internal_default_instance(); } -void DistributedPyTorchTrainingTask::Clear() { -// @@protoc_insertion_point(message_clear_start:flyteidl.plugins.DistributedPyTorchTrainingTask) +void ElasticConfig::Clear() { +// @@protoc_insertion_point(message_clear_start:flyteidl.plugins.ElasticConfig) ::google::protobuf::uint32 cached_has_bits = 0; // Prevent compiler warnings about cached_has_bits being unused (void) cached_has_bits; - rdzvbackend_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); - ::memset(&workers_, 0, static_cast( - reinterpret_cast(&maxrestarts_) - - reinterpret_cast(&workers_)) + sizeof(maxrestarts_)); + rdzv_backend_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + ::memset(&min_replicas_, 0, static_cast( + reinterpret_cast(&max_restarts_) - + reinterpret_cast(&min_replicas_)) + sizeof(max_restarts_)); _internal_metadata_.Clear(); } #if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER -const char* DistributedPyTorchTrainingTask::_InternalParse(const char* begin, const char* end, void* object, +const char* ElasticConfig::_InternalParse(const char* begin, const char* end, void* object, ::google::protobuf::internal::ParseContext* ctx) { - auto msg = static_cast(object); + auto msg = static_cast(object); ::google::protobuf::int32 size; (void)size; int depth; (void)depth; ::google::protobuf::uint32 tag; @@ -189,20 +219,13 @@ const char* DistributedPyTorchTrainingTask::_InternalParse(const char* begin, co ptr = ::google::protobuf::io::Parse32(ptr, &tag); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); switch (tag >> 3) { - // int32 workers = 1; + // string rdzv_backend = 1; case 1: { - if (static_cast<::google::protobuf::uint8>(tag) != 8) goto handle_unusual; - msg->set_workers(::google::protobuf::internal::ReadVarint(&ptr)); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - break; - } - // string RDZVBackend = 2; - case 2: { - if (static_cast<::google::protobuf::uint8>(tag) != 18) goto handle_unusual; + if (static_cast<::google::protobuf::uint8>(tag) != 10) goto handle_unusual; ptr = ::google::protobuf::io::ReadSize(ptr, &size); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - ctx->extra_parse_data().SetFieldName("flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend"); - object = msg->mutable_rdzvbackend(); + ctx->extra_parse_data().SetFieldName("flyteidl.plugins.ElasticConfig.rdzv_backend"); + object = msg->mutable_rdzv_backend(); if (size > end - ptr + ::google::protobuf::internal::ParseContext::kSlopBytes) { parser_till_end = ::google::protobuf::internal::GreedyStringParserUTF8; goto string_till_end; @@ -212,31 +235,31 @@ const char* DistributedPyTorchTrainingTask::_InternalParse(const char* begin, co ptr += size; break; } - // int32 minReplicas = 3; + // int32 min_replicas = 2; + case 2: { + if (static_cast<::google::protobuf::uint8>(tag) != 16) goto handle_unusual; + msg->set_min_replicas(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + // int32 max_replicas = 3; case 3: { if (static_cast<::google::protobuf::uint8>(tag) != 24) goto handle_unusual; - msg->set_minreplicas(::google::protobuf::internal::ReadVarint(&ptr)); + msg->set_max_replicas(::google::protobuf::internal::ReadVarint(&ptr)); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); break; } - // int32 maxReplicas = 4; + // int32 nproc_per_node = 4; case 4: { if (static_cast<::google::protobuf::uint8>(tag) != 32) goto handle_unusual; - msg->set_maxreplicas(::google::protobuf::internal::ReadVarint(&ptr)); + msg->set_nproc_per_node(::google::protobuf::internal::ReadVarint(&ptr)); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); break; } - // int32 nProcPerNode = 5; + // int32 max_restarts = 5; case 5: { if (static_cast<::google::protobuf::uint8>(tag) != 40) goto handle_unusual; - msg->set_nprocpernode(::google::protobuf::internal::ReadVarint(&ptr)); - GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - break; - } - // int32 maxRestarts = 6; - case 6: { - if (static_cast<::google::protobuf::uint8>(tag) != 48) goto handle_unusual; - msg->set_maxrestarts(::google::protobuf::internal::ReadVarint(&ptr)); + msg->set_max_restarts(::google::protobuf::internal::ReadVarint(&ptr)); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); break; } @@ -264,90 +287,77 @@ const char* DistributedPyTorchTrainingTask::_InternalParse(const char* begin, co {parser_till_end, object}, size); } #else // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER -bool DistributedPyTorchTrainingTask::MergePartialFromCodedStream( +bool ElasticConfig::MergePartialFromCodedStream( ::google::protobuf::io::CodedInputStream* input) { #define DO_(EXPRESSION) if (!PROTOBUF_PREDICT_TRUE(EXPRESSION)) goto failure ::google::protobuf::uint32 tag; - // @@protoc_insertion_point(parse_start:flyteidl.plugins.DistributedPyTorchTrainingTask) + // @@protoc_insertion_point(parse_start:flyteidl.plugins.ElasticConfig) for (;;) { ::std::pair<::google::protobuf::uint32, bool> p = input->ReadTagWithCutoffNoLastTag(127u); tag = p.first; if (!p.second) goto handle_unusual; switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { - // int32 workers = 1; + // string rdzv_backend = 1; case 1: { - if (static_cast< ::google::protobuf::uint8>(tag) == (8 & 0xFF)) { - - DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< - ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( - input, &workers_))); + if (static_cast< ::google::protobuf::uint8>(tag) == (10 & 0xFF)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadString( + input, this->mutable_rdzv_backend())); + DO_(::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->rdzv_backend().data(), static_cast(this->rdzv_backend().length()), + ::google::protobuf::internal::WireFormatLite::PARSE, + "flyteidl.plugins.ElasticConfig.rdzv_backend")); } else { goto handle_unusual; } break; } - // string RDZVBackend = 2; + // int32 min_replicas = 2; case 2: { - if (static_cast< ::google::protobuf::uint8>(tag) == (18 & 0xFF)) { - DO_(::google::protobuf::internal::WireFormatLite::ReadString( - input, this->mutable_rdzvbackend())); - DO_(::google::protobuf::internal::WireFormatLite::VerifyUtf8String( - this->rdzvbackend().data(), static_cast(this->rdzvbackend().length()), - ::google::protobuf::internal::WireFormatLite::PARSE, - "flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend")); + if (static_cast< ::google::protobuf::uint8>(tag) == (16 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &min_replicas_))); } else { goto handle_unusual; } break; } - // int32 minReplicas = 3; + // int32 max_replicas = 3; case 3: { if (static_cast< ::google::protobuf::uint8>(tag) == (24 & 0xFF)) { DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( - input, &minreplicas_))); + input, &max_replicas_))); } else { goto handle_unusual; } break; } - // int32 maxReplicas = 4; + // int32 nproc_per_node = 4; case 4: { if (static_cast< ::google::protobuf::uint8>(tag) == (32 & 0xFF)) { DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( - input, &maxreplicas_))); + input, &nproc_per_node_))); } else { goto handle_unusual; } break; } - // int32 nProcPerNode = 5; + // int32 max_restarts = 5; case 5: { if (static_cast< ::google::protobuf::uint8>(tag) == (40 & 0xFF)) { DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( - input, &nprocpernode_))); - } else { - goto handle_unusual; - } - break; - } - - // int32 maxRestarts = 6; - case 6: { - if (static_cast< ::google::protobuf::uint8>(tag) == (48 & 0xFF)) { - - DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< - ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( - input, &maxrestarts_))); + input, &max_restarts_))); } else { goto handle_unusual; } @@ -366,115 +376,105 @@ bool DistributedPyTorchTrainingTask::MergePartialFromCodedStream( } } success: - // @@protoc_insertion_point(parse_success:flyteidl.plugins.DistributedPyTorchTrainingTask) + // @@protoc_insertion_point(parse_success:flyteidl.plugins.ElasticConfig) return true; failure: - // @@protoc_insertion_point(parse_failure:flyteidl.plugins.DistributedPyTorchTrainingTask) + // @@protoc_insertion_point(parse_failure:flyteidl.plugins.ElasticConfig) return false; #undef DO_ } #endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER -void DistributedPyTorchTrainingTask::SerializeWithCachedSizes( +void ElasticConfig::SerializeWithCachedSizes( ::google::protobuf::io::CodedOutputStream* output) const { - // @@protoc_insertion_point(serialize_start:flyteidl.plugins.DistributedPyTorchTrainingTask) + // @@protoc_insertion_point(serialize_start:flyteidl.plugins.ElasticConfig) ::google::protobuf::uint32 cached_has_bits = 0; (void) cached_has_bits; - // int32 workers = 1; - if (this->workers() != 0) { - ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->workers(), output); - } - - // string RDZVBackend = 2; - if (this->rdzvbackend().size() > 0) { + // string rdzv_backend = 1; + if (this->rdzv_backend().size() > 0) { ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( - this->rdzvbackend().data(), static_cast(this->rdzvbackend().length()), + this->rdzv_backend().data(), static_cast(this->rdzv_backend().length()), ::google::protobuf::internal::WireFormatLite::SERIALIZE, - "flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend"); + "flyteidl.plugins.ElasticConfig.rdzv_backend"); ::google::protobuf::internal::WireFormatLite::WriteStringMaybeAliased( - 2, this->rdzvbackend(), output); + 1, this->rdzv_backend(), output); } - // int32 minReplicas = 3; - if (this->minreplicas() != 0) { - ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->minreplicas(), output); + // int32 min_replicas = 2; + if (this->min_replicas() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(2, this->min_replicas(), output); } - // int32 maxReplicas = 4; - if (this->maxreplicas() != 0) { - ::google::protobuf::internal::WireFormatLite::WriteInt32(4, this->maxreplicas(), output); + // int32 max_replicas = 3; + if (this->max_replicas() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->max_replicas(), output); } - // int32 nProcPerNode = 5; - if (this->nprocpernode() != 0) { - ::google::protobuf::internal::WireFormatLite::WriteInt32(5, this->nprocpernode(), output); + // int32 nproc_per_node = 4; + if (this->nproc_per_node() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(4, this->nproc_per_node(), output); } - // int32 maxRestarts = 6; - if (this->maxrestarts() != 0) { - ::google::protobuf::internal::WireFormatLite::WriteInt32(6, this->maxrestarts(), output); + // int32 max_restarts = 5; + if (this->max_restarts() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(5, this->max_restarts(), output); } if (_internal_metadata_.have_unknown_fields()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( _internal_metadata_.unknown_fields(), output); } - // @@protoc_insertion_point(serialize_end:flyteidl.plugins.DistributedPyTorchTrainingTask) + // @@protoc_insertion_point(serialize_end:flyteidl.plugins.ElasticConfig) } -::google::protobuf::uint8* DistributedPyTorchTrainingTask::InternalSerializeWithCachedSizesToArray( +::google::protobuf::uint8* ElasticConfig::InternalSerializeWithCachedSizesToArray( ::google::protobuf::uint8* target) const { - // @@protoc_insertion_point(serialize_to_array_start:flyteidl.plugins.DistributedPyTorchTrainingTask) + // @@protoc_insertion_point(serialize_to_array_start:flyteidl.plugins.ElasticConfig) ::google::protobuf::uint32 cached_has_bits = 0; (void) cached_has_bits; - // int32 workers = 1; - if (this->workers() != 0) { - target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->workers(), target); - } - - // string RDZVBackend = 2; - if (this->rdzvbackend().size() > 0) { + // string rdzv_backend = 1; + if (this->rdzv_backend().size() > 0) { ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( - this->rdzvbackend().data(), static_cast(this->rdzvbackend().length()), + this->rdzv_backend().data(), static_cast(this->rdzv_backend().length()), ::google::protobuf::internal::WireFormatLite::SERIALIZE, - "flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend"); + "flyteidl.plugins.ElasticConfig.rdzv_backend"); target = ::google::protobuf::internal::WireFormatLite::WriteStringToArray( - 2, this->rdzvbackend(), target); + 1, this->rdzv_backend(), target); } - // int32 minReplicas = 3; - if (this->minreplicas() != 0) { - target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->minreplicas(), target); + // int32 min_replicas = 2; + if (this->min_replicas() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(2, this->min_replicas(), target); } - // int32 maxReplicas = 4; - if (this->maxreplicas() != 0) { - target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(4, this->maxreplicas(), target); + // int32 max_replicas = 3; + if (this->max_replicas() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->max_replicas(), target); } - // int32 nProcPerNode = 5; - if (this->nprocpernode() != 0) { - target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(5, this->nprocpernode(), target); + // int32 nproc_per_node = 4; + if (this->nproc_per_node() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(4, this->nproc_per_node(), target); } - // int32 maxRestarts = 6; - if (this->maxrestarts() != 0) { - target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(6, this->maxrestarts(), target); + // int32 max_restarts = 5; + if (this->max_restarts() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(5, this->max_restarts(), target); } if (_internal_metadata_.have_unknown_fields()) { target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields(), target); } - // @@protoc_insertion_point(serialize_to_array_end:flyteidl.plugins.DistributedPyTorchTrainingTask) + // @@protoc_insertion_point(serialize_to_array_end:flyteidl.plugins.ElasticConfig) return target; } -size_t DistributedPyTorchTrainingTask::ByteSizeLong() const { -// @@protoc_insertion_point(message_byte_size_start:flyteidl.plugins.DistributedPyTorchTrainingTask) +size_t ElasticConfig::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:flyteidl.plugins.ElasticConfig) size_t total_size = 0; if (_internal_metadata_.have_unknown_fields()) { @@ -486,46 +486,386 @@ size_t DistributedPyTorchTrainingTask::ByteSizeLong() const { // Prevent compiler warnings about cached_has_bits being unused (void) cached_has_bits; - // string RDZVBackend = 2; - if (this->rdzvbackend().size() > 0) { + // string rdzv_backend = 1; + if (this->rdzv_backend().size() > 0) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::StringSize( - this->rdzvbackend()); + this->rdzv_backend()); } - // int32 workers = 1; - if (this->workers() != 0) { + // int32 min_replicas = 2; + if (this->min_replicas() != 0) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::Int32Size( - this->workers()); + this->min_replicas()); } - // int32 minReplicas = 3; - if (this->minreplicas() != 0) { + // int32 max_replicas = 3; + if (this->max_replicas() != 0) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::Int32Size( - this->minreplicas()); + this->max_replicas()); } - // int32 maxReplicas = 4; - if (this->maxreplicas() != 0) { + // int32 nproc_per_node = 4; + if (this->nproc_per_node() != 0) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::Int32Size( - this->maxreplicas()); + this->nproc_per_node()); } - // int32 nProcPerNode = 5; - if (this->nprocpernode() != 0) { + // int32 max_restarts = 5; + if (this->max_restarts() != 0) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::Int32Size( - this->nprocpernode()); + this->max_restarts()); + } + + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void ElasticConfig::MergeFrom(const ::google::protobuf::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:flyteidl.plugins.ElasticConfig) + GOOGLE_DCHECK_NE(&from, this); + const ElasticConfig* source = + ::google::protobuf::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:flyteidl.plugins.ElasticConfig) + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:flyteidl.plugins.ElasticConfig) + MergeFrom(*source); + } +} + +void ElasticConfig::MergeFrom(const ElasticConfig& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:flyteidl.plugins.ElasticConfig) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom(from._internal_metadata_); + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.rdzv_backend().size() > 0) { + + rdzv_backend_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.rdzv_backend_); + } + if (from.min_replicas() != 0) { + set_min_replicas(from.min_replicas()); + } + if (from.max_replicas() != 0) { + set_max_replicas(from.max_replicas()); + } + if (from.nproc_per_node() != 0) { + set_nproc_per_node(from.nproc_per_node()); + } + if (from.max_restarts() != 0) { + set_max_restarts(from.max_restarts()); + } +} + +void ElasticConfig::CopyFrom(const ::google::protobuf::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:flyteidl.plugins.ElasticConfig) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ElasticConfig::CopyFrom(const ElasticConfig& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:flyteidl.plugins.ElasticConfig) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ElasticConfig::IsInitialized() const { + return true; +} + +void ElasticConfig::Swap(ElasticConfig* other) { + if (other == this) return; + InternalSwap(other); +} +void ElasticConfig::InternalSwap(ElasticConfig* other) { + using std::swap; + _internal_metadata_.Swap(&other->_internal_metadata_); + rdzv_backend_.Swap(&other->rdzv_backend_, &::google::protobuf::internal::GetEmptyStringAlreadyInited(), + GetArenaNoVirtual()); + swap(min_replicas_, other->min_replicas_); + swap(max_replicas_, other->max_replicas_); + swap(nproc_per_node_, other->nproc_per_node_); + swap(max_restarts_, other->max_restarts_); +} + +::google::protobuf::Metadata ElasticConfig::GetMetadata() const { + ::google::protobuf::internal::AssignDescriptors(&::assign_descriptors_table_flyteidl_2fplugins_2fpytorch_2eproto); + return ::file_level_metadata_flyteidl_2fplugins_2fpytorch_2eproto[kIndexInFileMessages]; +} + + +// =================================================================== + +void DistributedPyTorchTrainingTask::InitAsDefaultInstance() { + ::flyteidl::plugins::_DistributedPyTorchTrainingTask_default_instance_._instance.get_mutable()->elastic_config_ = const_cast< ::flyteidl::plugins::ElasticConfig*>( + ::flyteidl::plugins::ElasticConfig::internal_default_instance()); +} +class DistributedPyTorchTrainingTask::HasBitSetters { + public: + static const ::flyteidl::plugins::ElasticConfig& elastic_config(const DistributedPyTorchTrainingTask* msg); +}; + +const ::flyteidl::plugins::ElasticConfig& +DistributedPyTorchTrainingTask::HasBitSetters::elastic_config(const DistributedPyTorchTrainingTask* msg) { + return *msg->elastic_config_; +} +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +const int DistributedPyTorchTrainingTask::kWorkersFieldNumber; +const int DistributedPyTorchTrainingTask::kElasticConfigFieldNumber; +#endif // !defined(_MSC_VER) || _MSC_VER >= 1900 + +DistributedPyTorchTrainingTask::DistributedPyTorchTrainingTask() + : ::google::protobuf::Message(), _internal_metadata_(nullptr) { + SharedCtor(); + // @@protoc_insertion_point(constructor:flyteidl.plugins.DistributedPyTorchTrainingTask) +} +DistributedPyTorchTrainingTask::DistributedPyTorchTrainingTask(const DistributedPyTorchTrainingTask& from) + : ::google::protobuf::Message(), + _internal_metadata_(nullptr) { + _internal_metadata_.MergeFrom(from._internal_metadata_); + if (from.has_elastic_config()) { + elastic_config_ = new ::flyteidl::plugins::ElasticConfig(*from.elastic_config_); + } else { + elastic_config_ = nullptr; } + workers_ = from.workers_; + // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.DistributedPyTorchTrainingTask) +} - // int32 maxRestarts = 6; - if (this->maxrestarts() != 0) { +void DistributedPyTorchTrainingTask::SharedCtor() { + ::google::protobuf::internal::InitSCC( + &scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto.base); + ::memset(&elastic_config_, 0, static_cast( + reinterpret_cast(&workers_) - + reinterpret_cast(&elastic_config_)) + sizeof(workers_)); +} + +DistributedPyTorchTrainingTask::~DistributedPyTorchTrainingTask() { + // @@protoc_insertion_point(destructor:flyteidl.plugins.DistributedPyTorchTrainingTask) + SharedDtor(); +} + +void DistributedPyTorchTrainingTask::SharedDtor() { + if (this != internal_default_instance()) delete elastic_config_; +} + +void DistributedPyTorchTrainingTask::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const DistributedPyTorchTrainingTask& DistributedPyTorchTrainingTask::default_instance() { + ::google::protobuf::internal::InitSCC(&::scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto.base); + return *internal_default_instance(); +} + + +void DistributedPyTorchTrainingTask::Clear() { +// @@protoc_insertion_point(message_clear_start:flyteidl.plugins.DistributedPyTorchTrainingTask) + ::google::protobuf::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + if (GetArenaNoVirtual() == nullptr && elastic_config_ != nullptr) { + delete elastic_config_; + } + elastic_config_ = nullptr; + workers_ = 0; + _internal_metadata_.Clear(); +} + +#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +const char* DistributedPyTorchTrainingTask::_InternalParse(const char* begin, const char* end, void* object, + ::google::protobuf::internal::ParseContext* ctx) { + auto msg = static_cast(object); + ::google::protobuf::int32 size; (void)size; + int depth; (void)depth; + ::google::protobuf::uint32 tag; + ::google::protobuf::internal::ParseFunc parser_till_end; (void)parser_till_end; + auto ptr = begin; + while (ptr < end) { + ptr = ::google::protobuf::io::Parse32(ptr, &tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + switch (tag >> 3) { + // int32 workers = 1; + case 1: { + if (static_cast<::google::protobuf::uint8>(tag) != 8) goto handle_unusual; + msg->set_workers(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + case 2: { + if (static_cast<::google::protobuf::uint8>(tag) != 18) goto handle_unusual; + ptr = ::google::protobuf::io::ReadSize(ptr, &size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + parser_till_end = ::flyteidl::plugins::ElasticConfig::_InternalParse; + object = msg->mutable_elastic_config(); + if (size > end - ptr) goto len_delim_till_end; + ptr += size; + GOOGLE_PROTOBUF_PARSER_ASSERT(ctx->ParseExactRange( + {parser_till_end, object}, ptr - size, ptr)); + break; + } + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->EndGroup(tag); + return ptr; + } + auto res = UnknownFieldParse(tag, {_InternalParse, msg}, + ptr, end, msg->_internal_metadata_.mutable_unknown_fields(), ctx); + ptr = res.first; + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); + if (res.second) return ptr; + } + } // switch + } // while + return ptr; +len_delim_till_end: + return ctx->StoreAndTailCall(ptr, end, {_InternalParse, msg}, + {parser_till_end, object}, size); +} +#else // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +bool DistributedPyTorchTrainingTask::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!PROTOBUF_PREDICT_TRUE(EXPRESSION)) goto failure + ::google::protobuf::uint32 tag; + // @@protoc_insertion_point(parse_start:flyteidl.plugins.DistributedPyTorchTrainingTask) + for (;;) { + ::std::pair<::google::protobuf::uint32, bool> p = input->ReadTagWithCutoffNoLastTag(127u); + tag = p.first; + if (!p.second) goto handle_unusual; + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // int32 workers = 1; + case 1: { + if (static_cast< ::google::protobuf::uint8>(tag) == (8 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &workers_))); + } else { + goto handle_unusual; + } + break; + } + + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + case 2: { + if (static_cast< ::google::protobuf::uint8>(tag) == (18 & 0xFF)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessage( + input, mutable_elastic_config())); + } else { + goto handle_unusual; + } + break; + } + + default: { + handle_unusual: + if (tag == 0) { + goto success; + } + DO_(::google::protobuf::internal::WireFormat::SkipField( + input, tag, _internal_metadata_.mutable_unknown_fields())); + break; + } + } + } +success: + // @@protoc_insertion_point(parse_success:flyteidl.plugins.DistributedPyTorchTrainingTask) + return true; +failure: + // @@protoc_insertion_point(parse_failure:flyteidl.plugins.DistributedPyTorchTrainingTask) + return false; +#undef DO_ +} +#endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + +void DistributedPyTorchTrainingTask::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // @@protoc_insertion_point(serialize_start:flyteidl.plugins.DistributedPyTorchTrainingTask) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // int32 workers = 1; + if (this->workers() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->workers(), output); + } + + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + if (this->has_elastic_config()) { + ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( + 2, HasBitSetters::elastic_config(this), output); + } + + if (_internal_metadata_.have_unknown_fields()) { + ::google::protobuf::internal::WireFormat::SerializeUnknownFields( + _internal_metadata_.unknown_fields(), output); + } + // @@protoc_insertion_point(serialize_end:flyteidl.plugins.DistributedPyTorchTrainingTask) +} + +::google::protobuf::uint8* DistributedPyTorchTrainingTask::InternalSerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // @@protoc_insertion_point(serialize_to_array_start:flyteidl.plugins.DistributedPyTorchTrainingTask) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // int32 workers = 1; + if (this->workers() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->workers(), target); + } + + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + if (this->has_elastic_config()) { + target = ::google::protobuf::internal::WireFormatLite:: + InternalWriteMessageToArray( + 2, HasBitSetters::elastic_config(this), target); + } + + if (_internal_metadata_.have_unknown_fields()) { + target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields(), target); + } + // @@protoc_insertion_point(serialize_to_array_end:flyteidl.plugins.DistributedPyTorchTrainingTask) + return target; +} + +size_t DistributedPyTorchTrainingTask::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:flyteidl.plugins.DistributedPyTorchTrainingTask) + size_t total_size = 0; + + if (_internal_metadata_.have_unknown_fields()) { + total_size += + ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize( + _internal_metadata_.unknown_fields()); + } + ::google::protobuf::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + if (this->has_elastic_config()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSize( + *elastic_config_); + } + + // int32 workers = 1; + if (this->workers() != 0) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::Int32Size( - this->maxrestarts()); + this->workers()); } int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); @@ -555,25 +895,12 @@ void DistributedPyTorchTrainingTask::MergeFrom(const DistributedPyTorchTrainingT ::google::protobuf::uint32 cached_has_bits = 0; (void) cached_has_bits; - if (from.rdzvbackend().size() > 0) { - - rdzvbackend_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.rdzvbackend_); + if (from.has_elastic_config()) { + mutable_elastic_config()->::flyteidl::plugins::ElasticConfig::MergeFrom(from.elastic_config()); } if (from.workers() != 0) { set_workers(from.workers()); } - if (from.minreplicas() != 0) { - set_minreplicas(from.minreplicas()); - } - if (from.maxreplicas() != 0) { - set_maxreplicas(from.maxreplicas()); - } - if (from.nprocpernode() != 0) { - set_nprocpernode(from.nprocpernode()); - } - if (from.maxrestarts() != 0) { - set_maxrestarts(from.maxrestarts()); - } } void DistributedPyTorchTrainingTask::CopyFrom(const ::google::protobuf::Message& from) { @@ -601,13 +928,8 @@ void DistributedPyTorchTrainingTask::Swap(DistributedPyTorchTrainingTask* other) void DistributedPyTorchTrainingTask::InternalSwap(DistributedPyTorchTrainingTask* other) { using std::swap; _internal_metadata_.Swap(&other->_internal_metadata_); - rdzvbackend_.Swap(&other->rdzvbackend_, &::google::protobuf::internal::GetEmptyStringAlreadyInited(), - GetArenaNoVirtual()); + swap(elastic_config_, other->elastic_config_); swap(workers_, other->workers_); - swap(minreplicas_, other->minreplicas_); - swap(maxreplicas_, other->maxreplicas_); - swap(nprocpernode_, other->nprocpernode_); - swap(maxrestarts_, other->maxrestarts_); } ::google::protobuf::Metadata DistributedPyTorchTrainingTask::GetMetadata() const { @@ -621,6 +943,9 @@ ::google::protobuf::Metadata DistributedPyTorchTrainingTask::GetMetadata() const } // namespace flyteidl namespace google { namespace protobuf { +template<> PROTOBUF_NOINLINE ::flyteidl::plugins::ElasticConfig* Arena::CreateMaybeMessage< ::flyteidl::plugins::ElasticConfig >(Arena* arena) { + return Arena::CreateInternal< ::flyteidl::plugins::ElasticConfig >(arena); +} template<> PROTOBUF_NOINLINE ::flyteidl::plugins::DistributedPyTorchTrainingTask* Arena::CreateMaybeMessage< ::flyteidl::plugins::DistributedPyTorchTrainingTask >(Arena* arena) { return Arena::CreateInternal< ::flyteidl::plugins::DistributedPyTorchTrainingTask >(arena); } diff --git a/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h b/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h index 7a0d15a76..b0d9b4a27 100644 --- a/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h +++ b/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h @@ -41,7 +41,7 @@ struct TableStruct_flyteidl_2fplugins_2fpytorch_2eproto { PROTOBUF_SECTION_VARIABLE(protodesc_cold); static const ::google::protobuf::internal::AuxillaryParseTableField aux[] PROTOBUF_SECTION_VARIABLE(protodesc_cold); - static const ::google::protobuf::internal::ParseTable schema[1] + static const ::google::protobuf::internal::ParseTable schema[2] PROTOBUF_SECTION_VARIABLE(protodesc_cold); static const ::google::protobuf::internal::FieldMetadata field_metadata[]; static const ::google::protobuf::internal::SerializationTable serialization_table[]; @@ -53,11 +53,15 @@ namespace plugins { class DistributedPyTorchTrainingTask; class DistributedPyTorchTrainingTaskDefaultTypeInternal; extern DistributedPyTorchTrainingTaskDefaultTypeInternal _DistributedPyTorchTrainingTask_default_instance_; +class ElasticConfig; +class ElasticConfigDefaultTypeInternal; +extern ElasticConfigDefaultTypeInternal _ElasticConfig_default_instance_; } // namespace plugins } // namespace flyteidl namespace google { namespace protobuf { template<> ::flyteidl::plugins::DistributedPyTorchTrainingTask* Arena::CreateMaybeMessage<::flyteidl::plugins::DistributedPyTorchTrainingTask>(Arena*); +template<> ::flyteidl::plugins::ElasticConfig* Arena::CreateMaybeMessage<::flyteidl::plugins::ElasticConfig>(Arena*); } // namespace protobuf } // namespace google namespace flyteidl { @@ -65,6 +69,154 @@ namespace plugins { // =================================================================== +class ElasticConfig final : + public ::google::protobuf::Message /* @@protoc_insertion_point(class_definition:flyteidl.plugins.ElasticConfig) */ { + public: + ElasticConfig(); + virtual ~ElasticConfig(); + + ElasticConfig(const ElasticConfig& from); + + inline ElasticConfig& operator=(const ElasticConfig& from) { + CopyFrom(from); + return *this; + } + #if LANG_CXX11 + ElasticConfig(ElasticConfig&& from) noexcept + : ElasticConfig() { + *this = ::std::move(from); + } + + inline ElasticConfig& operator=(ElasticConfig&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + #endif + static const ::google::protobuf::Descriptor* descriptor() { + return default_instance().GetDescriptor(); + } + static const ElasticConfig& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ElasticConfig* internal_default_instance() { + return reinterpret_cast( + &_ElasticConfig_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + void Swap(ElasticConfig* other); + friend void swap(ElasticConfig& a, ElasticConfig& b) { + a.Swap(&b); + } + + // implements Message ---------------------------------------------- + + inline ElasticConfig* New() const final { + return CreateMaybeMessage(nullptr); + } + + ElasticConfig* New(::google::protobuf::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::google::protobuf::Message& from) final; + void MergeFrom(const ::google::protobuf::Message& from) final; + void CopyFrom(const ElasticConfig& from); + void MergeFrom(const ElasticConfig& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + #if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + static const char* _InternalParse(const char* begin, const char* end, void* object, ::google::protobuf::internal::ParseContext* ctx); + ::google::protobuf::internal::ParseFunc _ParseFunc() const final { return _InternalParse; } + #else + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) final; + #endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const final; + ::google::protobuf::uint8* InternalSerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ElasticConfig* other); + private: + inline ::google::protobuf::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::google::protobuf::Metadata GetMetadata() const final; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // string rdzv_backend = 1; + void clear_rdzv_backend(); + static const int kRdzvBackendFieldNumber = 1; + const ::std::string& rdzv_backend() const; + void set_rdzv_backend(const ::std::string& value); + #if LANG_CXX11 + void set_rdzv_backend(::std::string&& value); + #endif + void set_rdzv_backend(const char* value); + void set_rdzv_backend(const char* value, size_t size); + ::std::string* mutable_rdzv_backend(); + ::std::string* release_rdzv_backend(); + void set_allocated_rdzv_backend(::std::string* rdzv_backend); + + // int32 min_replicas = 2; + void clear_min_replicas(); + static const int kMinReplicasFieldNumber = 2; + ::google::protobuf::int32 min_replicas() const; + void set_min_replicas(::google::protobuf::int32 value); + + // int32 max_replicas = 3; + void clear_max_replicas(); + static const int kMaxReplicasFieldNumber = 3; + ::google::protobuf::int32 max_replicas() const; + void set_max_replicas(::google::protobuf::int32 value); + + // int32 nproc_per_node = 4; + void clear_nproc_per_node(); + static const int kNprocPerNodeFieldNumber = 4; + ::google::protobuf::int32 nproc_per_node() const; + void set_nproc_per_node(::google::protobuf::int32 value); + + // int32 max_restarts = 5; + void clear_max_restarts(); + static const int kMaxRestartsFieldNumber = 5; + ::google::protobuf::int32 max_restarts() const; + void set_max_restarts(::google::protobuf::int32 value); + + // @@protoc_insertion_point(class_scope:flyteidl.plugins.ElasticConfig) + private: + class HasBitSetters; + + ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_; + ::google::protobuf::internal::ArenaStringPtr rdzv_backend_; + ::google::protobuf::int32 min_replicas_; + ::google::protobuf::int32 max_replicas_; + ::google::protobuf::int32 nproc_per_node_; + ::google::protobuf::int32 max_restarts_; + mutable ::google::protobuf::internal::CachedSize _cached_size_; + friend struct ::TableStruct_flyteidl_2fplugins_2fpytorch_2eproto; +}; +// ------------------------------------------------------------------- + class DistributedPyTorchTrainingTask final : public ::google::protobuf::Message /* @@protoc_insertion_point(class_definition:flyteidl.plugins.DistributedPyTorchTrainingTask) */ { public: @@ -103,7 +255,7 @@ class DistributedPyTorchTrainingTask final : &_DistributedPyTorchTrainingTask_default_instance_); } static constexpr int kIndexInFileMessages = - 0; + 1; void Swap(DistributedPyTorchTrainingTask* other); friend void swap(DistributedPyTorchTrainingTask& a, DistributedPyTorchTrainingTask& b) { @@ -160,19 +312,14 @@ class DistributedPyTorchTrainingTask final : // accessors ------------------------------------------------------- - // string RDZVBackend = 2; - void clear_rdzvbackend(); - static const int kRDZVBackendFieldNumber = 2; - const ::std::string& rdzvbackend() const; - void set_rdzvbackend(const ::std::string& value); - #if LANG_CXX11 - void set_rdzvbackend(::std::string&& value); - #endif - void set_rdzvbackend(const char* value); - void set_rdzvbackend(const char* value, size_t size); - ::std::string* mutable_rdzvbackend(); - ::std::string* release_rdzvbackend(); - void set_allocated_rdzvbackend(::std::string* rdzvbackend); + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + bool has_elastic_config() const; + void clear_elastic_config(); + static const int kElasticConfigFieldNumber = 2; + const ::flyteidl::plugins::ElasticConfig& elastic_config() const; + ::flyteidl::plugins::ElasticConfig* release_elastic_config(); + ::flyteidl::plugins::ElasticConfig* mutable_elastic_config(); + void set_allocated_elastic_config(::flyteidl::plugins::ElasticConfig* elastic_config); // int32 workers = 1; void clear_workers(); @@ -180,41 +327,13 @@ class DistributedPyTorchTrainingTask final : ::google::protobuf::int32 workers() const; void set_workers(::google::protobuf::int32 value); - // int32 minReplicas = 3; - void clear_minreplicas(); - static const int kMinReplicasFieldNumber = 3; - ::google::protobuf::int32 minreplicas() const; - void set_minreplicas(::google::protobuf::int32 value); - - // int32 maxReplicas = 4; - void clear_maxreplicas(); - static const int kMaxReplicasFieldNumber = 4; - ::google::protobuf::int32 maxreplicas() const; - void set_maxreplicas(::google::protobuf::int32 value); - - // int32 nProcPerNode = 5; - void clear_nprocpernode(); - static const int kNProcPerNodeFieldNumber = 5; - ::google::protobuf::int32 nprocpernode() const; - void set_nprocpernode(::google::protobuf::int32 value); - - // int32 maxRestarts = 6; - void clear_maxrestarts(); - static const int kMaxRestartsFieldNumber = 6; - ::google::protobuf::int32 maxrestarts() const; - void set_maxrestarts(::google::protobuf::int32 value); - // @@protoc_insertion_point(class_scope:flyteidl.plugins.DistributedPyTorchTrainingTask) private: class HasBitSetters; ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_; - ::google::protobuf::internal::ArenaStringPtr rdzvbackend_; + ::flyteidl::plugins::ElasticConfig* elastic_config_; ::google::protobuf::int32 workers_; - ::google::protobuf::int32 minreplicas_; - ::google::protobuf::int32 maxreplicas_; - ::google::protobuf::int32 nprocpernode_; - ::google::protobuf::int32 maxrestarts_; mutable ::google::protobuf::internal::CachedSize _cached_size_; friend struct ::TableStruct_flyteidl_2fplugins_2fpytorch_2eproto; }; @@ -227,134 +346,191 @@ class DistributedPyTorchTrainingTask final : #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif // __GNUC__ -// DistributedPyTorchTrainingTask +// ElasticConfig -// int32 workers = 1; -inline void DistributedPyTorchTrainingTask::clear_workers() { - workers_ = 0; +// string rdzv_backend = 1; +inline void ElasticConfig::clear_rdzv_backend() { + rdzv_backend_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } -inline ::google::protobuf::int32 DistributedPyTorchTrainingTask::workers() const { - // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.workers) - return workers_; +inline const ::std::string& ElasticConfig::rdzv_backend() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.ElasticConfig.rdzv_backend) + return rdzv_backend_.GetNoArena(); } -inline void DistributedPyTorchTrainingTask::set_workers(::google::protobuf::int32 value) { +inline void ElasticConfig::set_rdzv_backend(const ::std::string& value) { - workers_ = value; - // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.workers) -} - -// string RDZVBackend = 2; -inline void DistributedPyTorchTrainingTask::clear_rdzvbackend() { - rdzvbackend_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); -} -inline const ::std::string& DistributedPyTorchTrainingTask::rdzvbackend() const { - // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) - return rdzvbackend_.GetNoArena(); -} -inline void DistributedPyTorchTrainingTask::set_rdzvbackend(const ::std::string& value) { - - rdzvbackend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); - // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) + rdzv_backend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:flyteidl.plugins.ElasticConfig.rdzv_backend) } #if LANG_CXX11 -inline void DistributedPyTorchTrainingTask::set_rdzvbackend(::std::string&& value) { +inline void ElasticConfig::set_rdzv_backend(::std::string&& value) { - rdzvbackend_.SetNoArena( + rdzv_backend_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); - // @@protoc_insertion_point(field_set_rvalue:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) + // @@protoc_insertion_point(field_set_rvalue:flyteidl.plugins.ElasticConfig.rdzv_backend) } #endif -inline void DistributedPyTorchTrainingTask::set_rdzvbackend(const char* value) { +inline void ElasticConfig::set_rdzv_backend(const char* value) { GOOGLE_DCHECK(value != nullptr); - rdzvbackend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); - // @@protoc_insertion_point(field_set_char:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) + rdzv_backend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:flyteidl.plugins.ElasticConfig.rdzv_backend) } -inline void DistributedPyTorchTrainingTask::set_rdzvbackend(const char* value, size_t size) { +inline void ElasticConfig::set_rdzv_backend(const char* value, size_t size) { - rdzvbackend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + rdzv_backend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); - // @@protoc_insertion_point(field_set_pointer:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) + // @@protoc_insertion_point(field_set_pointer:flyteidl.plugins.ElasticConfig.rdzv_backend) } -inline ::std::string* DistributedPyTorchTrainingTask::mutable_rdzvbackend() { +inline ::std::string* ElasticConfig::mutable_rdzv_backend() { - // @@protoc_insertion_point(field_mutable:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) - return rdzvbackend_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + // @@protoc_insertion_point(field_mutable:flyteidl.plugins.ElasticConfig.rdzv_backend) + return rdzv_backend_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } -inline ::std::string* DistributedPyTorchTrainingTask::release_rdzvbackend() { - // @@protoc_insertion_point(field_release:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) +inline ::std::string* ElasticConfig::release_rdzv_backend() { + // @@protoc_insertion_point(field_release:flyteidl.plugins.ElasticConfig.rdzv_backend) - return rdzvbackend_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + return rdzv_backend_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } -inline void DistributedPyTorchTrainingTask::set_allocated_rdzvbackend(::std::string* rdzvbackend) { - if (rdzvbackend != nullptr) { +inline void ElasticConfig::set_allocated_rdzv_backend(::std::string* rdzv_backend) { + if (rdzv_backend != nullptr) { } else { } - rdzvbackend_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), rdzvbackend); - // @@protoc_insertion_point(field_set_allocated:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) + rdzv_backend_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), rdzv_backend); + // @@protoc_insertion_point(field_set_allocated:flyteidl.plugins.ElasticConfig.rdzv_backend) +} + +// int32 min_replicas = 2; +inline void ElasticConfig::clear_min_replicas() { + min_replicas_ = 0; +} +inline ::google::protobuf::int32 ElasticConfig::min_replicas() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.ElasticConfig.min_replicas) + return min_replicas_; +} +inline void ElasticConfig::set_min_replicas(::google::protobuf::int32 value) { + + min_replicas_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.ElasticConfig.min_replicas) } -// int32 minReplicas = 3; -inline void DistributedPyTorchTrainingTask::clear_minreplicas() { - minreplicas_ = 0; +// int32 max_replicas = 3; +inline void ElasticConfig::clear_max_replicas() { + max_replicas_ = 0; } -inline ::google::protobuf::int32 DistributedPyTorchTrainingTask::minreplicas() const { - // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.minReplicas) - return minreplicas_; +inline ::google::protobuf::int32 ElasticConfig::max_replicas() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.ElasticConfig.max_replicas) + return max_replicas_; } -inline void DistributedPyTorchTrainingTask::set_minreplicas(::google::protobuf::int32 value) { +inline void ElasticConfig::set_max_replicas(::google::protobuf::int32 value) { - minreplicas_ = value; - // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.minReplicas) + max_replicas_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.ElasticConfig.max_replicas) } -// int32 maxReplicas = 4; -inline void DistributedPyTorchTrainingTask::clear_maxreplicas() { - maxreplicas_ = 0; +// int32 nproc_per_node = 4; +inline void ElasticConfig::clear_nproc_per_node() { + nproc_per_node_ = 0; } -inline ::google::protobuf::int32 DistributedPyTorchTrainingTask::maxreplicas() const { - // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.maxReplicas) - return maxreplicas_; +inline ::google::protobuf::int32 ElasticConfig::nproc_per_node() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.ElasticConfig.nproc_per_node) + return nproc_per_node_; } -inline void DistributedPyTorchTrainingTask::set_maxreplicas(::google::protobuf::int32 value) { +inline void ElasticConfig::set_nproc_per_node(::google::protobuf::int32 value) { - maxreplicas_ = value; - // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.maxReplicas) + nproc_per_node_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.ElasticConfig.nproc_per_node) } -// int32 nProcPerNode = 5; -inline void DistributedPyTorchTrainingTask::clear_nprocpernode() { - nprocpernode_ = 0; +// int32 max_restarts = 5; +inline void ElasticConfig::clear_max_restarts() { + max_restarts_ = 0; +} +inline ::google::protobuf::int32 ElasticConfig::max_restarts() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.ElasticConfig.max_restarts) + return max_restarts_; } -inline ::google::protobuf::int32 DistributedPyTorchTrainingTask::nprocpernode() const { - // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.nProcPerNode) - return nprocpernode_; +inline void ElasticConfig::set_max_restarts(::google::protobuf::int32 value) { + + max_restarts_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.ElasticConfig.max_restarts) +} + +// ------------------------------------------------------------------- + +// DistributedPyTorchTrainingTask + +// int32 workers = 1; +inline void DistributedPyTorchTrainingTask::clear_workers() { + workers_ = 0; +} +inline ::google::protobuf::int32 DistributedPyTorchTrainingTask::workers() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.workers) + return workers_; } -inline void DistributedPyTorchTrainingTask::set_nprocpernode(::google::protobuf::int32 value) { +inline void DistributedPyTorchTrainingTask::set_workers(::google::protobuf::int32 value) { - nprocpernode_ = value; - // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.nProcPerNode) + workers_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.workers) } -// int32 maxRestarts = 6; -inline void DistributedPyTorchTrainingTask::clear_maxrestarts() { - maxrestarts_ = 0; +// .flyteidl.plugins.ElasticConfig elastic_config = 2; +inline bool DistributedPyTorchTrainingTask::has_elastic_config() const { + return this != internal_default_instance() && elastic_config_ != nullptr; } -inline ::google::protobuf::int32 DistributedPyTorchTrainingTask::maxrestarts() const { - // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.maxRestarts) - return maxrestarts_; +inline void DistributedPyTorchTrainingTask::clear_elastic_config() { + if (GetArenaNoVirtual() == nullptr && elastic_config_ != nullptr) { + delete elastic_config_; + } + elastic_config_ = nullptr; } -inline void DistributedPyTorchTrainingTask::set_maxrestarts(::google::protobuf::int32 value) { +inline const ::flyteidl::plugins::ElasticConfig& DistributedPyTorchTrainingTask::elastic_config() const { + const ::flyteidl::plugins::ElasticConfig* p = elastic_config_; + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.elastic_config) + return p != nullptr ? *p : *reinterpret_cast( + &::flyteidl::plugins::_ElasticConfig_default_instance_); +} +inline ::flyteidl::plugins::ElasticConfig* DistributedPyTorchTrainingTask::release_elastic_config() { + // @@protoc_insertion_point(field_release:flyteidl.plugins.DistributedPyTorchTrainingTask.elastic_config) + + ::flyteidl::plugins::ElasticConfig* temp = elastic_config_; + elastic_config_ = nullptr; + return temp; +} +inline ::flyteidl::plugins::ElasticConfig* DistributedPyTorchTrainingTask::mutable_elastic_config() { - maxrestarts_ = value; - // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.maxRestarts) + if (elastic_config_ == nullptr) { + auto* p = CreateMaybeMessage<::flyteidl::plugins::ElasticConfig>(GetArenaNoVirtual()); + elastic_config_ = p; + } + // @@protoc_insertion_point(field_mutable:flyteidl.plugins.DistributedPyTorchTrainingTask.elastic_config) + return elastic_config_; +} +inline void DistributedPyTorchTrainingTask::set_allocated_elastic_config(::flyteidl::plugins::ElasticConfig* elastic_config) { + ::google::protobuf::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete elastic_config_; + } + if (elastic_config) { + ::google::protobuf::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + elastic_config = ::google::protobuf::internal::GetOwnedMessage( + message_arena, elastic_config, submessage_arena); + } + + } else { + + } + elastic_config_ = elastic_config; + // @@protoc_insertion_point(field_set_allocated:flyteidl.plugins.DistributedPyTorchTrainingTask.elastic_config) } #ifdef __GNUC__ #pragma GCC diagnostic pop #endif // __GNUC__ +// ------------------------------------------------------------------- + // @@protoc_insertion_point(namespace_scope) diff --git a/gen/pb-go/flyteidl/plugins/pytorch.pb.go b/gen/pb-go/flyteidl/plugins/pytorch.pb.go index c25b64024..f75649fa0 100644 --- a/gen/pb-go/flyteidl/plugins/pytorch.pb.go +++ b/gen/pb-go/flyteidl/plugins/pytorch.pb.go @@ -20,110 +20,156 @@ var _ = math.Inf // proto package needs to be updated. const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package -// Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator -type DistributedPyTorchTrainingTask struct { - // number of worker replicas spawned in the cluster for this job - Workers int32 `protobuf:"varint,1,opt,name=workers,proto3" json:"workers,omitempty"` - // config for an elastic pytorch job - // https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go - RDZVBackend string `protobuf:"bytes,2,opt,name=RDZVBackend,proto3" json:"RDZVBackend,omitempty"` - MinReplicas int32 `protobuf:"varint,3,opt,name=minReplicas,proto3" json:"minReplicas,omitempty"` - MaxReplicas int32 `protobuf:"varint,4,opt,name=maxReplicas,proto3" json:"maxReplicas,omitempty"` - NProcPerNode int32 `protobuf:"varint,5,opt,name=nProcPerNode,proto3" json:"nProcPerNode,omitempty"` - MaxRestarts int32 `protobuf:"varint,6,opt,name=maxRestarts,proto3" json:"maxRestarts,omitempty"` +// Custom proto for torch elastic config for distributed training using +// https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go +type ElasticConfig struct { + RdzvBackend string `protobuf:"bytes,1,opt,name=rdzv_backend,json=rdzvBackend,proto3" json:"rdzv_backend,omitempty"` + MinReplicas int32 `protobuf:"varint,2,opt,name=min_replicas,json=minReplicas,proto3" json:"min_replicas,omitempty"` + MaxReplicas int32 `protobuf:"varint,3,opt,name=max_replicas,json=maxReplicas,proto3" json:"max_replicas,omitempty"` + NprocPerNode int32 `protobuf:"varint,4,opt,name=nproc_per_node,json=nprocPerNode,proto3" json:"nproc_per_node,omitempty"` + MaxRestarts int32 `protobuf:"varint,5,opt,name=max_restarts,json=maxRestarts,proto3" json:"max_restarts,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } -func (m *DistributedPyTorchTrainingTask) Reset() { *m = DistributedPyTorchTrainingTask{} } -func (m *DistributedPyTorchTrainingTask) String() string { return proto.CompactTextString(m) } -func (*DistributedPyTorchTrainingTask) ProtoMessage() {} -func (*DistributedPyTorchTrainingTask) Descriptor() ([]byte, []int) { +func (m *ElasticConfig) Reset() { *m = ElasticConfig{} } +func (m *ElasticConfig) String() string { return proto.CompactTextString(m) } +func (*ElasticConfig) ProtoMessage() {} +func (*ElasticConfig) Descriptor() ([]byte, []int) { return fileDescriptor_4df8a9374b28b766, []int{0} } -func (m *DistributedPyTorchTrainingTask) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_DistributedPyTorchTrainingTask.Unmarshal(m, b) +func (m *ElasticConfig) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ElasticConfig.Unmarshal(m, b) } -func (m *DistributedPyTorchTrainingTask) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_DistributedPyTorchTrainingTask.Marshal(b, m, deterministic) +func (m *ElasticConfig) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ElasticConfig.Marshal(b, m, deterministic) } -func (m *DistributedPyTorchTrainingTask) XXX_Merge(src proto.Message) { - xxx_messageInfo_DistributedPyTorchTrainingTask.Merge(m, src) +func (m *ElasticConfig) XXX_Merge(src proto.Message) { + xxx_messageInfo_ElasticConfig.Merge(m, src) } -func (m *DistributedPyTorchTrainingTask) XXX_Size() int { - return xxx_messageInfo_DistributedPyTorchTrainingTask.Size(m) +func (m *ElasticConfig) XXX_Size() int { + return xxx_messageInfo_ElasticConfig.Size(m) } -func (m *DistributedPyTorchTrainingTask) XXX_DiscardUnknown() { - xxx_messageInfo_DistributedPyTorchTrainingTask.DiscardUnknown(m) +func (m *ElasticConfig) XXX_DiscardUnknown() { + xxx_messageInfo_ElasticConfig.DiscardUnknown(m) } -var xxx_messageInfo_DistributedPyTorchTrainingTask proto.InternalMessageInfo - -func (m *DistributedPyTorchTrainingTask) GetWorkers() int32 { - if m != nil { - return m.Workers - } - return 0 -} +var xxx_messageInfo_ElasticConfig proto.InternalMessageInfo -func (m *DistributedPyTorchTrainingTask) GetRDZVBackend() string { +func (m *ElasticConfig) GetRdzvBackend() string { if m != nil { - return m.RDZVBackend + return m.RdzvBackend } return "" } -func (m *DistributedPyTorchTrainingTask) GetMinReplicas() int32 { +func (m *ElasticConfig) GetMinReplicas() int32 { if m != nil { return m.MinReplicas } return 0 } -func (m *DistributedPyTorchTrainingTask) GetMaxReplicas() int32 { +func (m *ElasticConfig) GetMaxReplicas() int32 { if m != nil { return m.MaxReplicas } return 0 } -func (m *DistributedPyTorchTrainingTask) GetNProcPerNode() int32 { +func (m *ElasticConfig) GetNprocPerNode() int32 { if m != nil { - return m.NProcPerNode + return m.NprocPerNode } return 0 } -func (m *DistributedPyTorchTrainingTask) GetMaxRestarts() int32 { +func (m *ElasticConfig) GetMaxRestarts() int32 { if m != nil { return m.MaxRestarts } return 0 } +// Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator +type DistributedPyTorchTrainingTask struct { + // number of worker replicas spawned in the cluster for this job + Workers int32 `protobuf:"varint,1,opt,name=workers,proto3" json:"workers,omitempty"` + // config for an elastic pytorch job + // + ElasticConfig *ElasticConfig `protobuf:"bytes,2,opt,name=elastic_config,json=elasticConfig,proto3" json:"elastic_config,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *DistributedPyTorchTrainingTask) Reset() { *m = DistributedPyTorchTrainingTask{} } +func (m *DistributedPyTorchTrainingTask) String() string { return proto.CompactTextString(m) } +func (*DistributedPyTorchTrainingTask) ProtoMessage() {} +func (*DistributedPyTorchTrainingTask) Descriptor() ([]byte, []int) { + return fileDescriptor_4df8a9374b28b766, []int{1} +} + +func (m *DistributedPyTorchTrainingTask) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_DistributedPyTorchTrainingTask.Unmarshal(m, b) +} +func (m *DistributedPyTorchTrainingTask) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_DistributedPyTorchTrainingTask.Marshal(b, m, deterministic) +} +func (m *DistributedPyTorchTrainingTask) XXX_Merge(src proto.Message) { + xxx_messageInfo_DistributedPyTorchTrainingTask.Merge(m, src) +} +func (m *DistributedPyTorchTrainingTask) XXX_Size() int { + return xxx_messageInfo_DistributedPyTorchTrainingTask.Size(m) +} +func (m *DistributedPyTorchTrainingTask) XXX_DiscardUnknown() { + xxx_messageInfo_DistributedPyTorchTrainingTask.DiscardUnknown(m) +} + +var xxx_messageInfo_DistributedPyTorchTrainingTask proto.InternalMessageInfo + +func (m *DistributedPyTorchTrainingTask) GetWorkers() int32 { + if m != nil { + return m.Workers + } + return 0 +} + +func (m *DistributedPyTorchTrainingTask) GetElasticConfig() *ElasticConfig { + if m != nil { + return m.ElasticConfig + } + return nil +} + func init() { + proto.RegisterType((*ElasticConfig)(nil), "flyteidl.plugins.ElasticConfig") proto.RegisterType((*DistributedPyTorchTrainingTask)(nil), "flyteidl.plugins.DistributedPyTorchTrainingTask") } func init() { proto.RegisterFile("flyteidl/plugins/pytorch.proto", fileDescriptor_4df8a9374b28b766) } var fileDescriptor_4df8a9374b28b766 = []byte{ - // 238 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0x90, 0x31, 0x4f, 0xc3, 0x30, - 0x10, 0x46, 0x15, 0xa0, 0x45, 0x18, 0x06, 0x94, 0xc9, 0x53, 0x15, 0x75, 0xea, 0x42, 0x3c, 0x30, - 0x20, 0xd6, 0xaa, 0x33, 0x8a, 0xa2, 0x88, 0xa1, 0x9b, 0xe3, 0x18, 0xf7, 0x94, 0xd4, 0x67, 0x9d, - 0x1d, 0x41, 0xfe, 0x30, 0xbf, 0x03, 0xd5, 0x34, 0x21, 0x74, 0xf4, 0xbb, 0xe7, 0x6f, 0x78, 0x6c, - 0xf5, 0xd1, 0x0d, 0x41, 0x43, 0xd3, 0x09, 0xd7, 0xf5, 0x06, 0xac, 0x17, 0x6e, 0x08, 0x48, 0xea, - 0x90, 0x3b, 0xc2, 0x80, 0xe9, 0xe3, 0x78, 0xcf, 0xcf, 0xf7, 0xf5, 0x77, 0xc2, 0x56, 0x3b, 0xf0, - 0x81, 0xa0, 0xee, 0x83, 0x6e, 0x8a, 0xa1, 0x3a, 0xe9, 0x15, 0x49, 0xb0, 0x60, 0x4d, 0x25, 0x7d, - 0x9b, 0x72, 0x76, 0xfb, 0x89, 0xd4, 0x6a, 0xf2, 0x3c, 0xc9, 0x92, 0xcd, 0xa2, 0x1c, 0x9f, 0x69, - 0xc6, 0xee, 0xcb, 0xdd, 0xfe, 0x7d, 0x2b, 0x55, 0xab, 0x6d, 0xc3, 0xaf, 0xb2, 0x64, 0x73, 0x57, - 0xce, 0xd1, 0xc9, 0x38, 0x82, 0x2d, 0xb5, 0xeb, 0x40, 0x49, 0xcf, 0xaf, 0xe3, 0xff, 0x39, 0x8a, - 0x86, 0xfc, 0x9a, 0x8c, 0x9b, 0xb3, 0xf1, 0x87, 0xd2, 0x35, 0x7b, 0xb0, 0x05, 0xa1, 0x2a, 0x34, - 0xbd, 0x61, 0xa3, 0xf9, 0x22, 0x2a, 0xff, 0xd8, 0xb4, 0xe2, 0x83, 0xa4, 0xe0, 0xf9, 0x72, 0xb6, - 0xf2, 0x8b, 0xb6, 0xaf, 0xfb, 0x17, 0x03, 0xe1, 0xd0, 0xd7, 0xb9, 0xc2, 0xa3, 0x88, 0x1d, 0x90, - 0x8c, 0x98, 0x82, 0x19, 0x6d, 0x85, 0xab, 0x9f, 0x0c, 0x8a, 0xcb, 0x86, 0xf5, 0x32, 0xc6, 0x7b, - 0xfe, 0x09, 0x00, 0x00, 0xff, 0xff, 0x5b, 0x03, 0x64, 0x09, 0x5e, 0x01, 0x00, 0x00, + // 299 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0x91, 0xbd, 0x4f, 0xc3, 0x30, + 0x10, 0xc5, 0x15, 0xa0, 0x20, 0xdc, 0x0f, 0xa1, 0x4c, 0x99, 0x4a, 0xa9, 0x18, 0xba, 0x90, 0x48, + 0x30, 0x20, 0xd6, 0xf2, 0x31, 0xa2, 0x2a, 0xea, 0xc4, 0x12, 0x39, 0xf6, 0xd5, 0x3d, 0x35, 0xb5, + 0xad, 0xb3, 0x0b, 0x2d, 0x23, 0xff, 0x19, 0xff, 0x19, 0xaa, 0x9b, 0x7e, 0xd0, 0xf1, 0xde, 0xfd, + 0xee, 0xa4, 0xf7, 0x1e, 0xeb, 0x4e, 0xaa, 0x95, 0x07, 0x94, 0x55, 0x66, 0xab, 0x85, 0x42, 0xed, + 0x32, 0xbb, 0xf2, 0x86, 0xc4, 0x34, 0xb5, 0x64, 0xbc, 0x89, 0xaf, 0xb6, 0xfb, 0xb4, 0xde, 0xf7, + 0x7f, 0x23, 0xd6, 0x7e, 0xad, 0xb8, 0xf3, 0x28, 0x9e, 0x8d, 0x9e, 0xa0, 0x8a, 0x6f, 0x58, 0x8b, + 0xe4, 0xf7, 0x67, 0x51, 0x72, 0x31, 0x03, 0x2d, 0x93, 0xa8, 0x17, 0x0d, 0x2e, 0xf3, 0xe6, 0x5a, + 0x1b, 0x6e, 0xa4, 0x35, 0x32, 0x47, 0x5d, 0x10, 0xd8, 0x0a, 0x05, 0x77, 0xc9, 0x49, 0x2f, 0x1a, + 0x34, 0xf2, 0xe6, 0x1c, 0x75, 0x5e, 0x4b, 0x01, 0xe1, 0xcb, 0x3d, 0x72, 0x5a, 0x23, 0x7c, 0xb9, + 0x43, 0x6e, 0x59, 0x47, 0x5b, 0x32, 0xa2, 0xb0, 0x40, 0x85, 0x36, 0x12, 0x92, 0xb3, 0x00, 0xb5, + 0x82, 0x3a, 0x02, 0x7a, 0x37, 0x12, 0xf6, 0x8f, 0x9c, 0xe7, 0xe4, 0x5d, 0xd2, 0x38, 0x78, 0xb4, + 0x91, 0xfa, 0x3f, 0x11, 0xeb, 0xbe, 0xa0, 0xf3, 0x84, 0xe5, 0xc2, 0x83, 0x1c, 0xad, 0xc6, 0x6b, + 0xcb, 0x63, 0xe2, 0xa8, 0x51, 0xab, 0x31, 0x77, 0xb3, 0x38, 0x61, 0x17, 0x5f, 0x86, 0x66, 0x40, + 0x2e, 0xf8, 0x69, 0xe4, 0xdb, 0x31, 0x7e, 0x63, 0x1d, 0xd8, 0xf8, 0x2f, 0x44, 0x08, 0x20, 0xb8, + 0x69, 0xde, 0x5f, 0xa7, 0xc7, 0x59, 0xa5, 0xff, 0x72, 0xca, 0xdb, 0x70, 0x38, 0x0e, 0x9f, 0x3e, + 0x1e, 0x15, 0xfa, 0xe9, 0xa2, 0x4c, 0x85, 0x99, 0x67, 0xe1, 0xd6, 0x90, 0xca, 0x76, 0x85, 0x28, + 0xd0, 0x99, 0x2d, 0xef, 0x94, 0xc9, 0x8e, 0x3b, 0x2a, 0xcf, 0x43, 0x39, 0x0f, 0x7f, 0x01, 0x00, + 0x00, 0xff, 0xff, 0x6f, 0x80, 0x2c, 0x15, 0xbe, 0x01, 0x00, 0x00, } diff --git a/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go b/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go index ae49a4b14..17b90db72 100644 --- a/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go +++ b/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go @@ -36,29 +36,104 @@ var ( // define the regex for a UUID once up-front var _pytorch_uuidPattern = regexp.MustCompile("^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$") -// Validate checks the field values on DistributedPyTorchTrainingTask with the -// rules defined in the proto definition for this message. If any rules are -// violated, an error is returned. -func (m *DistributedPyTorchTrainingTask) Validate() error { +// Validate checks the field values on ElasticConfig with the rules defined in +// the proto definition for this message. If any rules are violated, an error +// is returned. +func (m *ElasticConfig) Validate() error { if m == nil { return nil } - // no validation rules for Workers - - // no validation rules for RDZVBackend + // no validation rules for RdzvBackend // no validation rules for MinReplicas // no validation rules for MaxReplicas - // no validation rules for NProcPerNode + // no validation rules for NprocPerNode // no validation rules for MaxRestarts return nil } +// ElasticConfigValidationError is the validation error returned by +// ElasticConfig.Validate if the designated constraints aren't met. +type ElasticConfigValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e ElasticConfigValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e ElasticConfigValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e ElasticConfigValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e ElasticConfigValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e ElasticConfigValidationError) ErrorName() string { return "ElasticConfigValidationError" } + +// Error satisfies the builtin error interface +func (e ElasticConfigValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sElasticConfig.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = ElasticConfigValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = ElasticConfigValidationError{} + +// Validate checks the field values on DistributedPyTorchTrainingTask with the +// rules defined in the proto definition for this message. If any rules are +// violated, an error is returned. +func (m *DistributedPyTorchTrainingTask) Validate() error { + if m == nil { + return nil + } + + // no validation rules for Workers + + if v, ok := interface{}(m.GetElasticConfig()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return DistributedPyTorchTrainingTaskValidationError{ + field: "ElasticConfig", + reason: "embedded message failed validation", + cause: err, + } + } + } + + return nil +} + // DistributedPyTorchTrainingTaskValidationError is the validation error // returned by DistributedPyTorchTrainingTask.Validate if the designated // constraints aren't met. diff --git a/gen/pb-java/flyteidl/plugins/Pytorch.java b/gen/pb-java/flyteidl/plugins/Pytorch.java index feba13fb8..1df7f243b 100644 --- a/gen/pb-java/flyteidl/plugins/Pytorch.java +++ b/gen/pb-java/flyteidl/plugins/Pytorch.java @@ -14,77 +14,59 @@ public static void registerAllExtensions( registerAllExtensions( (com.google.protobuf.ExtensionRegistryLite) registry); } - public interface DistributedPyTorchTrainingTaskOrBuilder extends - // @@protoc_insertion_point(interface_extends:flyteidl.plugins.DistributedPyTorchTrainingTask) + public interface ElasticConfigOrBuilder extends + // @@protoc_insertion_point(interface_extends:flyteidl.plugins.ElasticConfig) com.google.protobuf.MessageOrBuilder { /** - *
-     * number of worker replicas spawned in the cluster for this job
-     * 
- * - * int32 workers = 1; - */ - int getWorkers(); - - /** - *
-     * config for an elastic pytorch job
-     * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
-     * 
- * - * string RDZVBackend = 2; + * string rdzv_backend = 1; */ - java.lang.String getRDZVBackend(); + java.lang.String getRdzvBackend(); /** - *
-     * config for an elastic pytorch job
-     * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
-     * 
- * - * string RDZVBackend = 2; + * string rdzv_backend = 1; */ com.google.protobuf.ByteString - getRDZVBackendBytes(); + getRdzvBackendBytes(); /** - * int32 minReplicas = 3; + * int32 min_replicas = 2; */ int getMinReplicas(); /** - * int32 maxReplicas = 4; + * int32 max_replicas = 3; */ int getMaxReplicas(); /** - * int32 nProcPerNode = 5; + * int32 nproc_per_node = 4; */ - int getNProcPerNode(); + int getNprocPerNode(); /** - * int32 maxRestarts = 6; + * int32 max_restarts = 5; */ int getMaxRestarts(); } /** *
-   * Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator
+   * Custom proto for torch elastic config for distributed training using 
+   * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
    * 
* - * Protobuf type {@code flyteidl.plugins.DistributedPyTorchTrainingTask} + * Protobuf type {@code flyteidl.plugins.ElasticConfig} */ - public static final class DistributedPyTorchTrainingTask extends + public static final class ElasticConfig extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:flyteidl.plugins.DistributedPyTorchTrainingTask) - DistributedPyTorchTrainingTaskOrBuilder { + // @@protoc_insertion_point(message_implements:flyteidl.plugins.ElasticConfig) + ElasticConfigOrBuilder { private static final long serialVersionUID = 0L; - // Use DistributedPyTorchTrainingTask.newBuilder() to construct. - private DistributedPyTorchTrainingTask(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use ElasticConfig.newBuilder() to construct. + private ElasticConfig(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private DistributedPyTorchTrainingTask() { - rDZVBackend_ = ""; + private ElasticConfig() { + rdzvBackend_ = ""; } @java.lang.Override @@ -92,7 +74,7 @@ private DistributedPyTorchTrainingTask() { getUnknownFields() { return this.unknownFields; } - private DistributedPyTorchTrainingTask( + private ElasticConfig( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -111,33 +93,28 @@ private DistributedPyTorchTrainingTask( case 0: done = true; break; - case 8: { - - workers_ = input.readInt32(); - break; - } - case 18: { + case 10: { java.lang.String s = input.readStringRequireUtf8(); - rDZVBackend_ = s; + rdzvBackend_ = s; break; } - case 24: { + case 16: { minReplicas_ = input.readInt32(); break; } - case 32: { + case 24: { maxReplicas_ = input.readInt32(); break; } - case 40: { + case 32: { - nProcPerNode_ = input.readInt32(); + nprocPerNode_ = input.readInt32(); break; } - case 48: { + case 40: { maxRestarts_ = input.readInt32(); break; @@ -163,105 +140,82 @@ private DistributedPyTorchTrainingTask( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor; + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_ElasticConfig_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_fieldAccessorTable + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_ElasticConfig_fieldAccessorTable .ensureFieldAccessorsInitialized( - flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.class, flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.Builder.class); - } - - public static final int WORKERS_FIELD_NUMBER = 1; - private int workers_; - /** - *
-     * number of worker replicas spawned in the cluster for this job
-     * 
- * - * int32 workers = 1; - */ - public int getWorkers() { - return workers_; + flyteidl.plugins.Pytorch.ElasticConfig.class, flyteidl.plugins.Pytorch.ElasticConfig.Builder.class); } - public static final int RDZVBACKEND_FIELD_NUMBER = 2; - private volatile java.lang.Object rDZVBackend_; + public static final int RDZV_BACKEND_FIELD_NUMBER = 1; + private volatile java.lang.Object rdzvBackend_; /** - *
-     * config for an elastic pytorch job
-     * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
-     * 
- * - * string RDZVBackend = 2; + * string rdzv_backend = 1; */ - public java.lang.String getRDZVBackend() { - java.lang.Object ref = rDZVBackend_; + public java.lang.String getRdzvBackend() { + java.lang.Object ref = rdzvBackend_; if (ref instanceof java.lang.String) { return (java.lang.String) ref; } else { com.google.protobuf.ByteString bs = (com.google.protobuf.ByteString) ref; java.lang.String s = bs.toStringUtf8(); - rDZVBackend_ = s; + rdzvBackend_ = s; return s; } } /** - *
-     * config for an elastic pytorch job
-     * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
-     * 
- * - * string RDZVBackend = 2; + * string rdzv_backend = 1; */ public com.google.protobuf.ByteString - getRDZVBackendBytes() { - java.lang.Object ref = rDZVBackend_; + getRdzvBackendBytes() { + java.lang.Object ref = rdzvBackend_; if (ref instanceof java.lang.String) { com.google.protobuf.ByteString b = com.google.protobuf.ByteString.copyFromUtf8( (java.lang.String) ref); - rDZVBackend_ = b; + rdzvBackend_ = b; return b; } else { return (com.google.protobuf.ByteString) ref; } } - public static final int MINREPLICAS_FIELD_NUMBER = 3; + public static final int MIN_REPLICAS_FIELD_NUMBER = 2; private int minReplicas_; /** - * int32 minReplicas = 3; + * int32 min_replicas = 2; */ public int getMinReplicas() { return minReplicas_; } - public static final int MAXREPLICAS_FIELD_NUMBER = 4; + public static final int MAX_REPLICAS_FIELD_NUMBER = 3; private int maxReplicas_; /** - * int32 maxReplicas = 4; + * int32 max_replicas = 3; */ public int getMaxReplicas() { return maxReplicas_; } - public static final int NPROCPERNODE_FIELD_NUMBER = 5; - private int nProcPerNode_; + public static final int NPROC_PER_NODE_FIELD_NUMBER = 4; + private int nprocPerNode_; /** - * int32 nProcPerNode = 5; + * int32 nproc_per_node = 4; */ - public int getNProcPerNode() { - return nProcPerNode_; + public int getNprocPerNode() { + return nprocPerNode_; } - public static final int MAXRESTARTS_FIELD_NUMBER = 6; + public static final int MAX_RESTARTS_FIELD_NUMBER = 5; private int maxRestarts_; /** - * int32 maxRestarts = 6; + * int32 max_restarts = 5; */ public int getMaxRestarts() { return maxRestarts_; @@ -281,23 +235,20 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { - if (workers_ != 0) { - output.writeInt32(1, workers_); - } - if (!getRDZVBackendBytes().isEmpty()) { - com.google.protobuf.GeneratedMessageV3.writeString(output, 2, rDZVBackend_); + if (!getRdzvBackendBytes().isEmpty()) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, rdzvBackend_); } if (minReplicas_ != 0) { - output.writeInt32(3, minReplicas_); + output.writeInt32(2, minReplicas_); } if (maxReplicas_ != 0) { - output.writeInt32(4, maxReplicas_); + output.writeInt32(3, maxReplicas_); } - if (nProcPerNode_ != 0) { - output.writeInt32(5, nProcPerNode_); + if (nprocPerNode_ != 0) { + output.writeInt32(4, nprocPerNode_); } if (maxRestarts_ != 0) { - output.writeInt32(6, maxRestarts_); + output.writeInt32(5, maxRestarts_); } unknownFields.writeTo(output); } @@ -308,28 +259,24 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; - if (workers_ != 0) { - size += com.google.protobuf.CodedOutputStream - .computeInt32Size(1, workers_); - } - if (!getRDZVBackendBytes().isEmpty()) { - size += com.google.protobuf.GeneratedMessageV3.computeStringSize(2, rDZVBackend_); + if (!getRdzvBackendBytes().isEmpty()) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, rdzvBackend_); } if (minReplicas_ != 0) { size += com.google.protobuf.CodedOutputStream - .computeInt32Size(3, minReplicas_); + .computeInt32Size(2, minReplicas_); } if (maxReplicas_ != 0) { size += com.google.protobuf.CodedOutputStream - .computeInt32Size(4, maxReplicas_); + .computeInt32Size(3, maxReplicas_); } - if (nProcPerNode_ != 0) { + if (nprocPerNode_ != 0) { size += com.google.protobuf.CodedOutputStream - .computeInt32Size(5, nProcPerNode_); + .computeInt32Size(4, nprocPerNode_); } if (maxRestarts_ != 0) { size += com.google.protobuf.CodedOutputStream - .computeInt32Size(6, maxRestarts_); + .computeInt32Size(5, maxRestarts_); } size += unknownFields.getSerializedSize(); memoizedSize = size; @@ -341,21 +288,19 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask)) { + if (!(obj instanceof flyteidl.plugins.Pytorch.ElasticConfig)) { return super.equals(obj); } - flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask other = (flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask) obj; + flyteidl.plugins.Pytorch.ElasticConfig other = (flyteidl.plugins.Pytorch.ElasticConfig) obj; - if (getWorkers() - != other.getWorkers()) return false; - if (!getRDZVBackend() - .equals(other.getRDZVBackend())) return false; + if (!getRdzvBackend() + .equals(other.getRdzvBackend())) return false; if (getMinReplicas() != other.getMinReplicas()) return false; if (getMaxReplicas() != other.getMaxReplicas()) return false; - if (getNProcPerNode() - != other.getNProcPerNode()) return false; + if (getNprocPerNode() + != other.getNprocPerNode()) return false; if (getMaxRestarts() != other.getMaxRestarts()) return false; if (!unknownFields.equals(other.unknownFields)) return false; @@ -369,86 +314,84 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + WORKERS_FIELD_NUMBER; - hash = (53 * hash) + getWorkers(); - hash = (37 * hash) + RDZVBACKEND_FIELD_NUMBER; - hash = (53 * hash) + getRDZVBackend().hashCode(); - hash = (37 * hash) + MINREPLICAS_FIELD_NUMBER; + hash = (37 * hash) + RDZV_BACKEND_FIELD_NUMBER; + hash = (53 * hash) + getRdzvBackend().hashCode(); + hash = (37 * hash) + MIN_REPLICAS_FIELD_NUMBER; hash = (53 * hash) + getMinReplicas(); - hash = (37 * hash) + MAXREPLICAS_FIELD_NUMBER; + hash = (37 * hash) + MAX_REPLICAS_FIELD_NUMBER; hash = (53 * hash) + getMaxReplicas(); - hash = (37 * hash) + NPROCPERNODE_FIELD_NUMBER; - hash = (53 * hash) + getNProcPerNode(); - hash = (37 * hash) + MAXRESTARTS_FIELD_NUMBER; + hash = (37 * hash) + NPROC_PER_NODE_FIELD_NUMBER; + hash = (53 * hash) + getNprocPerNode(); + hash = (37 * hash) + MAX_RESTARTS_FIELD_NUMBER; hash = (53 * hash) + getMaxRestarts(); hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom(byte[] data) + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom(java.io.InputStream input) + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseDelimitedFrom(java.io.InputStream input) + public static flyteidl.plugins.Pytorch.ElasticConfig parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseDelimitedFrom( + public static flyteidl.plugins.Pytorch.ElasticConfig parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -461,7 +404,7 @@ public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask prototype) { + public static Builder newBuilder(flyteidl.plugins.Pytorch.ElasticConfig prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -478,29 +421,30 @@ protected Builder newBuilderForType( } /** *
-     * Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator
+     * Custom proto for torch elastic config for distributed training using 
+     * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
      * 
* - * Protobuf type {@code flyteidl.plugins.DistributedPyTorchTrainingTask} + * Protobuf type {@code flyteidl.plugins.ElasticConfig} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:flyteidl.plugins.DistributedPyTorchTrainingTask) - flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTaskOrBuilder { + // @@protoc_insertion_point(builder_implements:flyteidl.plugins.ElasticConfig) + flyteidl.plugins.Pytorch.ElasticConfigOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor; + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_ElasticConfig_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_fieldAccessorTable + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_ElasticConfig_fieldAccessorTable .ensureFieldAccessorsInitialized( - flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.class, flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.Builder.class); + flyteidl.plugins.Pytorch.ElasticConfig.class, flyteidl.plugins.Pytorch.ElasticConfig.Builder.class); } - // Construct using flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.newBuilder() + // Construct using flyteidl.plugins.Pytorch.ElasticConfig.newBuilder() private Builder() { maybeForceBuilderInitialization(); } @@ -518,15 +462,13 @@ private void maybeForceBuilderInitialization() { @java.lang.Override public Builder clear() { super.clear(); - workers_ = 0; - - rDZVBackend_ = ""; + rdzvBackend_ = ""; minReplicas_ = 0; maxReplicas_ = 0; - nProcPerNode_ = 0; + nprocPerNode_ = 0; maxRestarts_ = 0; @@ -536,17 +478,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor; + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_ElasticConfig_descriptor; } @java.lang.Override - public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask getDefaultInstanceForType() { - return flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.getDefaultInstance(); + public flyteidl.plugins.Pytorch.ElasticConfig getDefaultInstanceForType() { + return flyteidl.plugins.Pytorch.ElasticConfig.getDefaultInstance(); } @java.lang.Override - public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask build() { - flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask result = buildPartial(); + public flyteidl.plugins.Pytorch.ElasticConfig build() { + flyteidl.plugins.Pytorch.ElasticConfig result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -554,13 +496,12 @@ public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask build() { } @java.lang.Override - public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask buildPartial() { - flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask result = new flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask(this); - result.workers_ = workers_; - result.rDZVBackend_ = rDZVBackend_; + public flyteidl.plugins.Pytorch.ElasticConfig buildPartial() { + flyteidl.plugins.Pytorch.ElasticConfig result = new flyteidl.plugins.Pytorch.ElasticConfig(this); + result.rdzvBackend_ = rdzvBackend_; result.minReplicas_ = minReplicas_; result.maxReplicas_ = maxReplicas_; - result.nProcPerNode_ = nProcPerNode_; + result.nprocPerNode_ = nprocPerNode_; result.maxRestarts_ = maxRestarts_; onBuilt(); return result; @@ -600,21 +541,18 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask) { - return mergeFrom((flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask)other); + if (other instanceof flyteidl.plugins.Pytorch.ElasticConfig) { + return mergeFrom((flyteidl.plugins.Pytorch.ElasticConfig)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask other) { - if (other == flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.getDefaultInstance()) return this; - if (other.getWorkers() != 0) { - setWorkers(other.getWorkers()); - } - if (!other.getRDZVBackend().isEmpty()) { - rDZVBackend_ = other.rDZVBackend_; + public Builder mergeFrom(flyteidl.plugins.Pytorch.ElasticConfig other) { + if (other == flyteidl.plugins.Pytorch.ElasticConfig.getDefaultInstance()) return this; + if (!other.getRdzvBackend().isEmpty()) { + rdzvBackend_ = other.rdzvBackend_; onChanged(); } if (other.getMinReplicas() != 0) { @@ -623,8 +561,8 @@ public Builder mergeFrom(flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask if (other.getMaxReplicas() != 0) { setMaxReplicas(other.getMaxReplicas()); } - if (other.getNProcPerNode() != 0) { - setNProcPerNode(other.getNProcPerNode()); + if (other.getNprocPerNode() != 0) { + setNprocPerNode(other.getNprocPerNode()); } if (other.getMaxRestarts() != 0) { setMaxRestarts(other.getMaxRestarts()); @@ -644,11 +582,11 @@ public Builder mergeFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { - flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parsedMessage = null; + flyteidl.plugins.Pytorch.ElasticConfig parsedMessage = null; try { parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask) e.getUnfinishedMessage(); + parsedMessage = (flyteidl.plugins.Pytorch.ElasticConfig) e.getUnfinishedMessage(); throw e.unwrapIOException(); } finally { if (parsedMessage != null) { @@ -658,147 +596,84 @@ public Builder mergeFrom( return this; } - private int workers_ ; - /** - *
-       * number of worker replicas spawned in the cluster for this job
-       * 
- * - * int32 workers = 1; - */ - public int getWorkers() { - return workers_; - } - /** - *
-       * number of worker replicas spawned in the cluster for this job
-       * 
- * - * int32 workers = 1; - */ - public Builder setWorkers(int value) { - - workers_ = value; - onChanged(); - return this; - } - /** - *
-       * number of worker replicas spawned in the cluster for this job
-       * 
- * - * int32 workers = 1; - */ - public Builder clearWorkers() { - - workers_ = 0; - onChanged(); - return this; - } - - private java.lang.Object rDZVBackend_ = ""; + private java.lang.Object rdzvBackend_ = ""; /** - *
-       * config for an elastic pytorch job
-       * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
-       * 
- * - * string RDZVBackend = 2; + * string rdzv_backend = 1; */ - public java.lang.String getRDZVBackend() { - java.lang.Object ref = rDZVBackend_; + public java.lang.String getRdzvBackend() { + java.lang.Object ref = rdzvBackend_; if (!(ref instanceof java.lang.String)) { com.google.protobuf.ByteString bs = (com.google.protobuf.ByteString) ref; java.lang.String s = bs.toStringUtf8(); - rDZVBackend_ = s; + rdzvBackend_ = s; return s; } else { return (java.lang.String) ref; } } /** - *
-       * config for an elastic pytorch job
-       * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
-       * 
- * - * string RDZVBackend = 2; + * string rdzv_backend = 1; */ public com.google.protobuf.ByteString - getRDZVBackendBytes() { - java.lang.Object ref = rDZVBackend_; + getRdzvBackendBytes() { + java.lang.Object ref = rdzvBackend_; if (ref instanceof String) { com.google.protobuf.ByteString b = com.google.protobuf.ByteString.copyFromUtf8( (java.lang.String) ref); - rDZVBackend_ = b; + rdzvBackend_ = b; return b; } else { return (com.google.protobuf.ByteString) ref; } } /** - *
-       * config for an elastic pytorch job
-       * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
-       * 
- * - * string RDZVBackend = 2; + * string rdzv_backend = 1; */ - public Builder setRDZVBackend( + public Builder setRdzvBackend( java.lang.String value) { if (value == null) { throw new NullPointerException(); } - rDZVBackend_ = value; + rdzvBackend_ = value; onChanged(); return this; } /** - *
-       * config for an elastic pytorch job
-       * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
-       * 
- * - * string RDZVBackend = 2; + * string rdzv_backend = 1; */ - public Builder clearRDZVBackend() { + public Builder clearRdzvBackend() { - rDZVBackend_ = getDefaultInstance().getRDZVBackend(); + rdzvBackend_ = getDefaultInstance().getRdzvBackend(); onChanged(); return this; } /** - *
-       * config for an elastic pytorch job
-       * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
-       * 
- * - * string RDZVBackend = 2; + * string rdzv_backend = 1; */ - public Builder setRDZVBackendBytes( + public Builder setRdzvBackendBytes( com.google.protobuf.ByteString value) { if (value == null) { throw new NullPointerException(); } checkByteStringIsUtf8(value); - rDZVBackend_ = value; + rdzvBackend_ = value; onChanged(); return this; } private int minReplicas_ ; /** - * int32 minReplicas = 3; + * int32 min_replicas = 2; */ public int getMinReplicas() { return minReplicas_; } /** - * int32 minReplicas = 3; + * int32 min_replicas = 2; */ public Builder setMinReplicas(int value) { @@ -807,7 +682,7 @@ public Builder setMinReplicas(int value) { return this; } /** - * int32 minReplicas = 3; + * int32 min_replicas = 2; */ public Builder clearMinReplicas() { @@ -818,13 +693,13 @@ public Builder clearMinReplicas() { private int maxReplicas_ ; /** - * int32 maxReplicas = 4; + * int32 max_replicas = 3; */ public int getMaxReplicas() { return maxReplicas_; } /** - * int32 maxReplicas = 4; + * int32 max_replicas = 3; */ public Builder setMaxReplicas(int value) { @@ -833,7 +708,7 @@ public Builder setMaxReplicas(int value) { return this; } /** - * int32 maxReplicas = 4; + * int32 max_replicas = 3; */ public Builder clearMaxReplicas() { @@ -842,41 +717,41 @@ public Builder clearMaxReplicas() { return this; } - private int nProcPerNode_ ; + private int nprocPerNode_ ; /** - * int32 nProcPerNode = 5; + * int32 nproc_per_node = 4; */ - public int getNProcPerNode() { - return nProcPerNode_; + public int getNprocPerNode() { + return nprocPerNode_; } /** - * int32 nProcPerNode = 5; + * int32 nproc_per_node = 4; */ - public Builder setNProcPerNode(int value) { + public Builder setNprocPerNode(int value) { - nProcPerNode_ = value; + nprocPerNode_ = value; onChanged(); return this; } /** - * int32 nProcPerNode = 5; + * int32 nproc_per_node = 4; */ - public Builder clearNProcPerNode() { + public Builder clearNprocPerNode() { - nProcPerNode_ = 0; + nprocPerNode_ = 0; onChanged(); return this; } private int maxRestarts_ ; /** - * int32 maxRestarts = 6; + * int32 max_restarts = 5; */ public int getMaxRestarts() { return maxRestarts_; } /** - * int32 maxRestarts = 6; + * int32 max_restarts = 5; */ public Builder setMaxRestarts(int value) { @@ -885,7 +760,7 @@ public Builder setMaxRestarts(int value) { return this; } /** - * int32 maxRestarts = 6; + * int32 max_restarts = 5; */ public Builder clearMaxRestarts() { @@ -906,86 +781,870 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:flyteidl.plugins.DistributedPyTorchTrainingTask) + // @@protoc_insertion_point(builder_scope:flyteidl.plugins.ElasticConfig) } - // @@protoc_insertion_point(class_scope:flyteidl.plugins.DistributedPyTorchTrainingTask) - private static final flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:flyteidl.plugins.ElasticConfig) + private static final flyteidl.plugins.Pytorch.ElasticConfig DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask(); + DEFAULT_INSTANCE = new flyteidl.plugins.Pytorch.ElasticConfig(); } - public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask getDefaultInstance() { + public static flyteidl.plugins.Pytorch.ElasticConfig getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public DistributedPyTorchTrainingTask parsePartialFrom( + public ElasticConfig parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { - return new DistributedPyTorchTrainingTask(input, extensionRegistry); + return new ElasticConfig(input, extensionRegistry); } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask getDefaultInstanceForType() { + public flyteidl.plugins.Pytorch.ElasticConfig getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_fieldAccessorTable; + public interface DistributedPyTorchTrainingTaskOrBuilder extends + // @@protoc_insertion_point(interface_extends:flyteidl.plugins.DistributedPyTorchTrainingTask) + com.google.protobuf.MessageOrBuilder { - public static com.google.protobuf.Descriptors.FileDescriptor - getDescriptor() { - return descriptor; + /** + *
+     * number of worker replicas spawned in the cluster for this job
+     * 
+ * + * int32 workers = 1; + */ + int getWorkers(); + + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + boolean hasElasticConfig(); + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + flyteidl.plugins.Pytorch.ElasticConfig getElasticConfig(); + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + flyteidl.plugins.Pytorch.ElasticConfigOrBuilder getElasticConfigOrBuilder(); } - private static com.google.protobuf.Descriptors.FileDescriptor - descriptor; - static { - java.lang.String[] descriptorData = { - "\n\036flyteidl/plugins/pytorch.proto\022\020flytei" + - "dl.plugins\"\233\001\n\036DistributedPyTorchTrainin" + - "gTask\022\017\n\007workers\030\001 \001(\005\022\023\n\013RDZVBackend\030\002 " + - "\001(\t\022\023\n\013minReplicas\030\003 \001(\005\022\023\n\013maxReplicas\030" + - "\004 \001(\005\022\024\n\014nProcPerNode\030\005 \001(\005\022\023\n\013maxRestar" + - "ts\030\006 \001(\005B9Z7github.com/flyteorg/flyteidl" + - "/gen/pb-go/flyteidl/pluginsb\006proto3" - }; - com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = - new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { - public com.google.protobuf.ExtensionRegistry assignDescriptors( - com.google.protobuf.Descriptors.FileDescriptor root) { - descriptor = root; - return null; - } - }; - com.google.protobuf.Descriptors.FileDescriptor - .internalBuildGeneratedFileFrom(descriptorData, - new com.google.protobuf.Descriptors.FileDescriptor[] { - }, assigner); - internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor = - getDescriptor().getMessageTypes().get(0); + /** + *
+   * Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator
+   * 
+ * + * Protobuf type {@code flyteidl.plugins.DistributedPyTorchTrainingTask} + */ + public static final class DistributedPyTorchTrainingTask extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:flyteidl.plugins.DistributedPyTorchTrainingTask) + DistributedPyTorchTrainingTaskOrBuilder { + private static final long serialVersionUID = 0L; + // Use DistributedPyTorchTrainingTask.newBuilder() to construct. + private DistributedPyTorchTrainingTask(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private DistributedPyTorchTrainingTask() { + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private DistributedPyTorchTrainingTask( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 8: { + + workers_ = input.readInt32(); + break; + } + case 18: { + flyteidl.plugins.Pytorch.ElasticConfig.Builder subBuilder = null; + if (elasticConfig_ != null) { + subBuilder = elasticConfig_.toBuilder(); + } + elasticConfig_ = input.readMessage(flyteidl.plugins.Pytorch.ElasticConfig.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom(elasticConfig_); + elasticConfig_ = subBuilder.buildPartial(); + } + + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_fieldAccessorTable + .ensureFieldAccessorsInitialized( + flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.class, flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.Builder.class); + } + + public static final int WORKERS_FIELD_NUMBER = 1; + private int workers_; + /** + *
+     * number of worker replicas spawned in the cluster for this job
+     * 
+ * + * int32 workers = 1; + */ + public int getWorkers() { + return workers_; + } + + public static final int ELASTIC_CONFIG_FIELD_NUMBER = 2; + private flyteidl.plugins.Pytorch.ElasticConfig elasticConfig_; + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public boolean hasElasticConfig() { + return elasticConfig_ != null; + } + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public flyteidl.plugins.Pytorch.ElasticConfig getElasticConfig() { + return elasticConfig_ == null ? flyteidl.plugins.Pytorch.ElasticConfig.getDefaultInstance() : elasticConfig_; + } + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public flyteidl.plugins.Pytorch.ElasticConfigOrBuilder getElasticConfigOrBuilder() { + return getElasticConfig(); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (workers_ != 0) { + output.writeInt32(1, workers_); + } + if (elasticConfig_ != null) { + output.writeMessage(2, getElasticConfig()); + } + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (workers_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(1, workers_); + } + if (elasticConfig_ != null) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(2, getElasticConfig()); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask)) { + return super.equals(obj); + } + flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask other = (flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask) obj; + + if (getWorkers() + != other.getWorkers()) return false; + if (hasElasticConfig() != other.hasElasticConfig()) return false; + if (hasElasticConfig()) { + if (!getElasticConfig() + .equals(other.getElasticConfig())) return false; + } + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + WORKERS_FIELD_NUMBER; + hash = (53 * hash) + getWorkers(); + if (hasElasticConfig()) { + hash = (37 * hash) + ELASTIC_CONFIG_FIELD_NUMBER; + hash = (53 * hash) + getElasticConfig().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + *
+     * Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator
+     * 
+ * + * Protobuf type {@code flyteidl.plugins.DistributedPyTorchTrainingTask} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:flyteidl.plugins.DistributedPyTorchTrainingTask) + flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTaskOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_fieldAccessorTable + .ensureFieldAccessorsInitialized( + flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.class, flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.Builder.class); + } + + // Construct using flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + workers_ = 0; + + if (elasticConfigBuilder_ == null) { + elasticConfig_ = null; + } else { + elasticConfig_ = null; + elasticConfigBuilder_ = null; + } + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor; + } + + @java.lang.Override + public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask getDefaultInstanceForType() { + return flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.getDefaultInstance(); + } + + @java.lang.Override + public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask build() { + flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask buildPartial() { + flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask result = new flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask(this); + result.workers_ = workers_; + if (elasticConfigBuilder_ == null) { + result.elasticConfig_ = elasticConfig_; + } else { + result.elasticConfig_ = elasticConfigBuilder_.build(); + } + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask) { + return mergeFrom((flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask other) { + if (other == flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask.getDefaultInstance()) return this; + if (other.getWorkers() != 0) { + setWorkers(other.getWorkers()); + } + if (other.hasElasticConfig()) { + mergeElasticConfig(other.getElasticConfig()); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private int workers_ ; + /** + *
+       * number of worker replicas spawned in the cluster for this job
+       * 
+ * + * int32 workers = 1; + */ + public int getWorkers() { + return workers_; + } + /** + *
+       * number of worker replicas spawned in the cluster for this job
+       * 
+ * + * int32 workers = 1; + */ + public Builder setWorkers(int value) { + + workers_ = value; + onChanged(); + return this; + } + /** + *
+       * number of worker replicas spawned in the cluster for this job
+       * 
+ * + * int32 workers = 1; + */ + public Builder clearWorkers() { + + workers_ = 0; + onChanged(); + return this; + } + + private flyteidl.plugins.Pytorch.ElasticConfig elasticConfig_; + private com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.Pytorch.ElasticConfig, flyteidl.plugins.Pytorch.ElasticConfig.Builder, flyteidl.plugins.Pytorch.ElasticConfigOrBuilder> elasticConfigBuilder_; + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public boolean hasElasticConfig() { + return elasticConfigBuilder_ != null || elasticConfig_ != null; + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public flyteidl.plugins.Pytorch.ElasticConfig getElasticConfig() { + if (elasticConfigBuilder_ == null) { + return elasticConfig_ == null ? flyteidl.plugins.Pytorch.ElasticConfig.getDefaultInstance() : elasticConfig_; + } else { + return elasticConfigBuilder_.getMessage(); + } + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public Builder setElasticConfig(flyteidl.plugins.Pytorch.ElasticConfig value) { + if (elasticConfigBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + elasticConfig_ = value; + onChanged(); + } else { + elasticConfigBuilder_.setMessage(value); + } + + return this; + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public Builder setElasticConfig( + flyteidl.plugins.Pytorch.ElasticConfig.Builder builderForValue) { + if (elasticConfigBuilder_ == null) { + elasticConfig_ = builderForValue.build(); + onChanged(); + } else { + elasticConfigBuilder_.setMessage(builderForValue.build()); + } + + return this; + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public Builder mergeElasticConfig(flyteidl.plugins.Pytorch.ElasticConfig value) { + if (elasticConfigBuilder_ == null) { + if (elasticConfig_ != null) { + elasticConfig_ = + flyteidl.plugins.Pytorch.ElasticConfig.newBuilder(elasticConfig_).mergeFrom(value).buildPartial(); + } else { + elasticConfig_ = value; + } + onChanged(); + } else { + elasticConfigBuilder_.mergeFrom(value); + } + + return this; + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public Builder clearElasticConfig() { + if (elasticConfigBuilder_ == null) { + elasticConfig_ = null; + onChanged(); + } else { + elasticConfig_ = null; + elasticConfigBuilder_ = null; + } + + return this; + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public flyteidl.plugins.Pytorch.ElasticConfig.Builder getElasticConfigBuilder() { + + onChanged(); + return getElasticConfigFieldBuilder().getBuilder(); + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public flyteidl.plugins.Pytorch.ElasticConfigOrBuilder getElasticConfigOrBuilder() { + if (elasticConfigBuilder_ != null) { + return elasticConfigBuilder_.getMessageOrBuilder(); + } else { + return elasticConfig_ == null ? + flyteidl.plugins.Pytorch.ElasticConfig.getDefaultInstance() : elasticConfig_; + } + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + private com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.Pytorch.ElasticConfig, flyteidl.plugins.Pytorch.ElasticConfig.Builder, flyteidl.plugins.Pytorch.ElasticConfigOrBuilder> + getElasticConfigFieldBuilder() { + if (elasticConfigBuilder_ == null) { + elasticConfigBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.Pytorch.ElasticConfig, flyteidl.plugins.Pytorch.ElasticConfig.Builder, flyteidl.plugins.Pytorch.ElasticConfigOrBuilder>( + getElasticConfig(), + getParentForChildren(), + isClean()); + elasticConfig_ = null; + } + return elasticConfigBuilder_; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:flyteidl.plugins.DistributedPyTorchTrainingTask) + } + + // @@protoc_insertion_point(class_scope:flyteidl.plugins.DistributedPyTorchTrainingTask) + private static final flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask(); + } + + public static flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public DistributedPyTorchTrainingTask parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new DistributedPyTorchTrainingTask(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_flyteidl_plugins_ElasticConfig_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_flyteidl_plugins_ElasticConfig_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_fieldAccessorTable; + + public static com.google.protobuf.Descriptors.FileDescriptor + getDescriptor() { + return descriptor; + } + private static com.google.protobuf.Descriptors.FileDescriptor + descriptor; + static { + java.lang.String[] descriptorData = { + "\n\036flyteidl/plugins/pytorch.proto\022\020flytei" + + "dl.plugins\"\177\n\rElasticConfig\022\024\n\014rdzv_back" + + "end\030\001 \001(\t\022\024\n\014min_replicas\030\002 \001(\005\022\024\n\014max_r" + + "eplicas\030\003 \001(\005\022\026\n\016nproc_per_node\030\004 \001(\005\022\024\n" + + "\014max_restarts\030\005 \001(\005\"j\n\036DistributedPyTorc" + + "hTrainingTask\022\017\n\007workers\030\001 \001(\005\0227\n\016elasti" + + "c_config\030\002 \001(\0132\037.flyteidl.plugins.Elasti" + + "cConfigB9Z7github.com/flyteorg/flyteidl/" + + "gen/pb-go/flyteidl/pluginsb\006proto3" + }; + com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = + new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { + public com.google.protobuf.ExtensionRegistry assignDescriptors( + com.google.protobuf.Descriptors.FileDescriptor root) { + descriptor = root; + return null; + } + }; + com.google.protobuf.Descriptors.FileDescriptor + .internalBuildGeneratedFileFrom(descriptorData, + new com.google.protobuf.Descriptors.FileDescriptor[] { + }, assigner); + internal_static_flyteidl_plugins_ElasticConfig_descriptor = + getDescriptor().getMessageTypes().get(0); + internal_static_flyteidl_plugins_ElasticConfig_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_flyteidl_plugins_ElasticConfig_descriptor, + new java.lang.String[] { "RdzvBackend", "MinReplicas", "MaxReplicas", "NprocPerNode", "MaxRestarts", }); + internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor = + getDescriptor().getMessageTypes().get(1); internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor, - new java.lang.String[] { "Workers", "RDZVBackend", "MinReplicas", "MaxReplicas", "NProcPerNode", "MaxRestarts", }); + new java.lang.String[] { "Workers", "ElasticConfig", }); } // @@protoc_insertion_point(outer_class_scope) diff --git a/gen/pb_python/flyteidl/plugins/pytorch_pb2.py b/gen/pb_python/flyteidl/plugins/pytorch_pb2.py index 7bfa0492d..326a1bc11 100644 --- a/gen/pb_python/flyteidl/plugins/pytorch_pb2.py +++ b/gen/pb_python/flyteidl/plugins/pytorch_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1e\x66lyteidl/plugins/pytorch.proto\x12\x10\x66lyteidl.plugins\"\xe6\x01\n\x1e\x44istributedPyTorchTrainingTask\x12\x18\n\x07workers\x18\x01 \x01(\x05R\x07workers\x12 \n\x0bRDZVBackend\x18\x02 \x01(\tR\x0bRDZVBackend\x12 \n\x0bminReplicas\x18\x03 \x01(\x05R\x0bminReplicas\x12 \n\x0bmaxReplicas\x18\x04 \x01(\x05R\x0bmaxReplicas\x12\"\n\x0cnProcPerNode\x18\x05 \x01(\x05R\x0cnProcPerNode\x12 \n\x0bmaxRestarts\x18\x06 \x01(\x05R\x0bmaxRestartsB\xbe\x01\n\x14\x63om.flyteidl.pluginsB\x0cPytorchProtoP\x01Z7github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1e\x66lyteidl/plugins/pytorch.proto\x12\x10\x66lyteidl.plugins\"\xc1\x01\n\rElasticConfig\x12!\n\x0crdzv_backend\x18\x01 \x01(\tR\x0brdzvBackend\x12!\n\x0cmin_replicas\x18\x02 \x01(\x05R\x0bminReplicas\x12!\n\x0cmax_replicas\x18\x03 \x01(\x05R\x0bmaxReplicas\x12$\n\x0enproc_per_node\x18\x04 \x01(\x05R\x0cnprocPerNode\x12!\n\x0cmax_restarts\x18\x05 \x01(\x05R\x0bmaxRestarts\"\x82\x01\n\x1e\x44istributedPyTorchTrainingTask\x12\x18\n\x07workers\x18\x01 \x01(\x05R\x07workers\x12\x46\n\x0e\x65lastic_config\x18\x02 \x01(\x0b\x32\x1f.flyteidl.plugins.ElasticConfigR\relasticConfigB\xbe\x01\n\x14\x63om.flyteidl.pluginsB\x0cPytorchProtoP\x01Z7github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flyteidl.plugins.pytorch_pb2', globals()) @@ -21,6 +21,8 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\n\024com.flyteidl.pluginsB\014PytorchProtoP\001Z7github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPX\252\002\020Flyteidl.Plugins\312\002\020Flyteidl\\Plugins\342\002\034Flyteidl\\Plugins\\GPBMetadata\352\002\021Flyteidl::Plugins' - _DISTRIBUTEDPYTORCHTRAININGTASK._serialized_start=53 - _DISTRIBUTEDPYTORCHTRAININGTASK._serialized_end=283 + _ELASTICCONFIG._serialized_start=53 + _ELASTICCONFIG._serialized_end=246 + _DISTRIBUTEDPYTORCHTRAININGTASK._serialized_start=249 + _DISTRIBUTEDPYTORCHTRAININGTASK._serialized_end=379 # @@protoc_insertion_point(module_scope) diff --git a/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi b/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi index bde6ac1d2..51aa2d4cf 100644 --- a/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi +++ b/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi @@ -1,21 +1,27 @@ from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Optional as _Optional +from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor class DistributedPyTorchTrainingTask(_message.Message): - __slots__ = ["RDZVBackend", "maxReplicas", "maxRestarts", "minReplicas", "nProcPerNode", "workers"] - MAXREPLICAS_FIELD_NUMBER: _ClassVar[int] - MAXRESTARTS_FIELD_NUMBER: _ClassVar[int] - MINREPLICAS_FIELD_NUMBER: _ClassVar[int] - NPROCPERNODE_FIELD_NUMBER: _ClassVar[int] - RDZVBACKEND_FIELD_NUMBER: _ClassVar[int] - RDZVBackend: str + __slots__ = ["elastic_config", "workers"] + ELASTIC_CONFIG_FIELD_NUMBER: _ClassVar[int] WORKERS_FIELD_NUMBER: _ClassVar[int] - maxReplicas: int - maxRestarts: int - minReplicas: int - nProcPerNode: int + elastic_config: ElasticConfig workers: int - def __init__(self, workers: _Optional[int] = ..., RDZVBackend: _Optional[str] = ..., minReplicas: _Optional[int] = ..., maxReplicas: _Optional[int] = ..., nProcPerNode: _Optional[int] = ..., maxRestarts: _Optional[int] = ...) -> None: ... + def __init__(self, workers: _Optional[int] = ..., elastic_config: _Optional[_Union[ElasticConfig, _Mapping]] = ...) -> None: ... + +class ElasticConfig(_message.Message): + __slots__ = ["max_replicas", "max_restarts", "min_replicas", "nproc_per_node", "rdzv_backend"] + MAX_REPLICAS_FIELD_NUMBER: _ClassVar[int] + MAX_RESTARTS_FIELD_NUMBER: _ClassVar[int] + MIN_REPLICAS_FIELD_NUMBER: _ClassVar[int] + NPROC_PER_NODE_FIELD_NUMBER: _ClassVar[int] + RDZV_BACKEND_FIELD_NUMBER: _ClassVar[int] + max_replicas: int + max_restarts: int + min_replicas: int + nproc_per_node: int + rdzv_backend: str + def __init__(self, rdzv_backend: _Optional[str] = ..., min_replicas: _Optional[int] = ..., max_replicas: _Optional[int] = ..., nproc_per_node: _Optional[int] = ..., max_restarts: _Optional[int] = ...) -> None: ... diff --git a/gen/pb_rust/flyteidl.plugins.rs b/gen/pb_rust/flyteidl.plugins.rs index 922cd86ab..2bc8b08c2 100644 --- a/gen/pb_rust/flyteidl.plugins.rs +++ b/gen/pb_rust/flyteidl.plugins.rs @@ -103,6 +103,22 @@ pub struct PrestoQuery { #[prost(string, tag="4")] pub statement: ::prost::alloc::string::String, } +/// Custom proto for torch elastic config for distributed training using +/// +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ElasticConfig { + #[prost(string, tag="1")] + pub rdzv_backend: ::prost::alloc::string::String, + #[prost(int32, tag="2")] + pub min_replicas: i32, + #[prost(int32, tag="3")] + pub max_replicas: i32, + #[prost(int32, tag="4")] + pub nproc_per_node: i32, + #[prost(int32, tag="5")] + pub max_restarts: i32, +} /// Custom proto for plugin that enables distributed training using #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -111,17 +127,9 @@ pub struct DistributedPyTorchTrainingTask { #[prost(int32, tag="1")] pub workers: i32, /// config for an elastic pytorch job - /// - #[prost(string, tag="2")] - pub rdzv_backend: ::prost::alloc::string::String, - #[prost(int32, tag="3")] - pub min_replicas: i32, - #[prost(int32, tag="4")] - pub max_replicas: i32, - #[prost(int32, tag="5")] - pub n_proc_per_node: i32, - #[prost(int32, tag="6")] - pub max_restarts: i32, + /// + #[prost(message, optional, tag="2")] + pub elastic_config: ::core::option::Option, } /// Defines a query to execute on a hive cluster. #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/protos/flyteidl/plugins/pytorch.proto b/protos/flyteidl/plugins/pytorch.proto index 23191b276..2e219d82b 100644 --- a/protos/flyteidl/plugins/pytorch.proto +++ b/protos/flyteidl/plugins/pytorch.proto @@ -4,16 +4,22 @@ package flyteidl.plugins; option go_package = "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"; +// Custom proto for torch elastic config for distributed training using +// https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go +message ElasticConfig { + string rdzv_backend = 1; + int32 min_replicas = 2; + int32 max_replicas = 3; + int32 nproc_per_node = 4; + int32 max_restarts = 5; +} + // Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator message DistributedPyTorchTrainingTask { // number of worker replicas spawned in the cluster for this job int32 workers = 1; // config for an elastic pytorch job - // https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go - string RDZVBackend = 2; - int32 minReplicas = 3; - int32 maxReplicas = 4; - int32 nProcPerNode = 5; - int32 maxRestarts = 6; + // + ElasticConfig elastic_config = 2; }