diff --git a/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc b/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc index 9b1951b38..4f4096218 100644 --- a/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc +++ b/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc @@ -53,6 +53,11 @@ const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2fpytorch_2eprot ~0u, // no _oneof_case_ ~0u, // no _weak_field_map_ PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, workers_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, rdzvbackend_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, minreplicas_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, maxreplicas_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, nprocpernode_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, maxrestarts_), }; static const ::google::protobuf::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { { 0, -1, sizeof(::flyteidl::plugins::DistributedPyTorchTrainingTask)}, @@ -70,15 +75,17 @@ ::google::protobuf::internal::AssignDescriptorsTable assign_descriptors_table_fl const char descriptor_table_protodef_flyteidl_2fplugins_2fpytorch_2eproto[] = "\n\036flyteidl/plugins/pytorch.proto\022\020flytei" - "dl.plugins\"1\n\036DistributedPyTorchTraining" - "Task\022\017\n\007workers\030\001 \001(\005B9Z7github.com/flyt" - "eorg/flyteidl/gen/pb-go/flyteidl/plugins" - "b\006proto3" + "dl.plugins\"\233\001\n\036DistributedPyTorchTrainin" + "gTask\022\017\n\007workers\030\001 \001(\005\022\023\n\013RDZVBackend\030\002 " + "\001(\t\022\023\n\013minReplicas\030\003 \001(\005\022\023\n\013maxReplicas\030" + "\004 \001(\005\022\024\n\014nProcPerNode\030\005 \001(\005\022\023\n\013maxRestar" + "ts\030\006 \001(\005B9Z7github.com/flyteorg/flyteidl" + "/gen/pb-go/flyteidl/pluginsb\006proto3" ; ::google::protobuf::internal::DescriptorTable descriptor_table_flyteidl_2fplugins_2fpytorch_2eproto = { false, InitDefaults_flyteidl_2fplugins_2fpytorch_2eproto, descriptor_table_protodef_flyteidl_2fplugins_2fpytorch_2eproto, - "flyteidl/plugins/pytorch.proto", &assign_descriptors_table_flyteidl_2fplugins_2fpytorch_2eproto, 168, + "flyteidl/plugins/pytorch.proto", &assign_descriptors_table_flyteidl_2fplugins_2fpytorch_2eproto, 275, }; void AddDescriptors_flyteidl_2fplugins_2fpytorch_2eproto() { @@ -103,6 +110,11 @@ class DistributedPyTorchTrainingTask::HasBitSetters { #if !defined(_MSC_VER) || _MSC_VER >= 1900 const int DistributedPyTorchTrainingTask::kWorkersFieldNumber; +const int DistributedPyTorchTrainingTask::kRDZVBackendFieldNumber; +const int DistributedPyTorchTrainingTask::kMinReplicasFieldNumber; +const int DistributedPyTorchTrainingTask::kMaxReplicasFieldNumber; +const int DistributedPyTorchTrainingTask::kNProcPerNodeFieldNumber; +const int DistributedPyTorchTrainingTask::kMaxRestartsFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 DistributedPyTorchTrainingTask::DistributedPyTorchTrainingTask() @@ -114,12 +126,23 @@ DistributedPyTorchTrainingTask::DistributedPyTorchTrainingTask(const Distributed : ::google::protobuf::Message(), _internal_metadata_(nullptr) { _internal_metadata_.MergeFrom(from._internal_metadata_); - workers_ = from.workers_; + rdzvbackend_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + if (from.rdzvbackend().size() > 0) { + rdzvbackend_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.rdzvbackend_); + } + ::memcpy(&workers_, &from.workers_, + static_cast(reinterpret_cast(&maxrestarts_) - + reinterpret_cast(&workers_)) + sizeof(maxrestarts_)); // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.DistributedPyTorchTrainingTask) } void DistributedPyTorchTrainingTask::SharedCtor() { - workers_ = 0; + ::google::protobuf::internal::InitSCC( + &scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto.base); + rdzvbackend_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + ::memset(&workers_, 0, static_cast( + reinterpret_cast(&maxrestarts_) - + reinterpret_cast(&workers_)) + sizeof(maxrestarts_)); } DistributedPyTorchTrainingTask::~DistributedPyTorchTrainingTask() { @@ -128,6 +151,7 @@ DistributedPyTorchTrainingTask::~DistributedPyTorchTrainingTask() { } void DistributedPyTorchTrainingTask::SharedDtor() { + rdzvbackend_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void DistributedPyTorchTrainingTask::SetCachedSize(int size) const { @@ -145,7 +169,10 @@ void DistributedPyTorchTrainingTask::Clear() { // Prevent compiler warnings about cached_has_bits being unused (void) cached_has_bits; - workers_ = 0; + rdzvbackend_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + ::memset(&workers_, 0, static_cast( + reinterpret_cast(&maxrestarts_) - + reinterpret_cast(&workers_)) + sizeof(maxrestarts_)); _internal_metadata_.Clear(); } @@ -169,6 +196,50 @@ const char* DistributedPyTorchTrainingTask::_InternalParse(const char* begin, co GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); break; } + // string RDZVBackend = 2; + case 2: { + if (static_cast<::google::protobuf::uint8>(tag) != 18) goto handle_unusual; + ptr = ::google::protobuf::io::ReadSize(ptr, &size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + ctx->extra_parse_data().SetFieldName("flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend"); + object = msg->mutable_rdzvbackend(); + if (size > end - ptr + ::google::protobuf::internal::ParseContext::kSlopBytes) { + parser_till_end = ::google::protobuf::internal::GreedyStringParserUTF8; + goto string_till_end; + } + GOOGLE_PROTOBUF_PARSER_ASSERT(::google::protobuf::internal::StringCheckUTF8(ptr, size, ctx)); + ::google::protobuf::internal::InlineGreedyStringParser(object, ptr, size, ctx); + ptr += size; + break; + } + // int32 minReplicas = 3; + case 3: { + if (static_cast<::google::protobuf::uint8>(tag) != 24) goto handle_unusual; + msg->set_minreplicas(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + // int32 maxReplicas = 4; + case 4: { + if (static_cast<::google::protobuf::uint8>(tag) != 32) goto handle_unusual; + msg->set_maxreplicas(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + // int32 nProcPerNode = 5; + case 5: { + if (static_cast<::google::protobuf::uint8>(tag) != 40) goto handle_unusual; + msg->set_nprocpernode(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + // int32 maxRestarts = 6; + case 6: { + if (static_cast<::google::protobuf::uint8>(tag) != 48) goto handle_unusual; + msg->set_maxrestarts(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } default: { handle_unusual: if ((tag & 7) == 4 || tag == 0) { @@ -184,6 +255,13 @@ const char* DistributedPyTorchTrainingTask::_InternalParse(const char* begin, co } // switch } // while return ptr; +string_till_end: + static_cast<::std::string*>(object)->clear(); + static_cast<::std::string*>(object)->reserve(size); + goto len_delim_till_end; +len_delim_till_end: + return ctx->StoreAndTailCall(ptr, end, {_InternalParse, msg}, + {parser_till_end, object}, size); } #else // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER bool DistributedPyTorchTrainingTask::MergePartialFromCodedStream( @@ -209,6 +287,73 @@ bool DistributedPyTorchTrainingTask::MergePartialFromCodedStream( break; } + // string RDZVBackend = 2; + case 2: { + if (static_cast< ::google::protobuf::uint8>(tag) == (18 & 0xFF)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadString( + input, this->mutable_rdzvbackend())); + DO_(::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->rdzvbackend().data(), static_cast(this->rdzvbackend().length()), + ::google::protobuf::internal::WireFormatLite::PARSE, + "flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend")); + } else { + goto handle_unusual; + } + break; + } + + // int32 minReplicas = 3; + case 3: { + if (static_cast< ::google::protobuf::uint8>(tag) == (24 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &minreplicas_))); + } else { + goto handle_unusual; + } + break; + } + + // int32 maxReplicas = 4; + case 4: { + if (static_cast< ::google::protobuf::uint8>(tag) == (32 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &maxreplicas_))); + } else { + goto handle_unusual; + } + break; + } + + // int32 nProcPerNode = 5; + case 5: { + if (static_cast< ::google::protobuf::uint8>(tag) == (40 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &nprocpernode_))); + } else { + goto handle_unusual; + } + break; + } + + // int32 maxRestarts = 6; + case 6: { + if (static_cast< ::google::protobuf::uint8>(tag) == (48 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &maxrestarts_))); + } else { + goto handle_unusual; + } + break; + } + default: { handle_unusual: if (tag == 0) { @@ -241,6 +386,36 @@ void DistributedPyTorchTrainingTask::SerializeWithCachedSizes( ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->workers(), output); } + // string RDZVBackend = 2; + if (this->rdzvbackend().size() > 0) { + ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->rdzvbackend().data(), static_cast(this->rdzvbackend().length()), + ::google::protobuf::internal::WireFormatLite::SERIALIZE, + "flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend"); + ::google::protobuf::internal::WireFormatLite::WriteStringMaybeAliased( + 2, this->rdzvbackend(), output); + } + + // int32 minReplicas = 3; + if (this->minreplicas() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->minreplicas(), output); + } + + // int32 maxReplicas = 4; + if (this->maxreplicas() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(4, this->maxreplicas(), output); + } + + // int32 nProcPerNode = 5; + if (this->nprocpernode() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(5, this->nprocpernode(), output); + } + + // int32 maxRestarts = 6; + if (this->maxrestarts() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(6, this->maxrestarts(), output); + } + if (_internal_metadata_.have_unknown_fields()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( _internal_metadata_.unknown_fields(), output); @@ -259,6 +434,37 @@ ::google::protobuf::uint8* DistributedPyTorchTrainingTask::InternalSerializeWith target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->workers(), target); } + // string RDZVBackend = 2; + if (this->rdzvbackend().size() > 0) { + ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->rdzvbackend().data(), static_cast(this->rdzvbackend().length()), + ::google::protobuf::internal::WireFormatLite::SERIALIZE, + "flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend"); + target = + ::google::protobuf::internal::WireFormatLite::WriteStringToArray( + 2, this->rdzvbackend(), target); + } + + // int32 minReplicas = 3; + if (this->minreplicas() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->minreplicas(), target); + } + + // int32 maxReplicas = 4; + if (this->maxreplicas() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(4, this->maxreplicas(), target); + } + + // int32 nProcPerNode = 5; + if (this->nprocpernode() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(5, this->nprocpernode(), target); + } + + // int32 maxRestarts = 6; + if (this->maxrestarts() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(6, this->maxrestarts(), target); + } + if (_internal_metadata_.have_unknown_fields()) { target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields(), target); @@ -280,6 +486,13 @@ size_t DistributedPyTorchTrainingTask::ByteSizeLong() const { // Prevent compiler warnings about cached_has_bits being unused (void) cached_has_bits; + // string RDZVBackend = 2; + if (this->rdzvbackend().size() > 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::StringSize( + this->rdzvbackend()); + } + // int32 workers = 1; if (this->workers() != 0) { total_size += 1 + @@ -287,6 +500,34 @@ size_t DistributedPyTorchTrainingTask::ByteSizeLong() const { this->workers()); } + // int32 minReplicas = 3; + if (this->minreplicas() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->minreplicas()); + } + + // int32 maxReplicas = 4; + if (this->maxreplicas() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->maxreplicas()); + } + + // int32 nProcPerNode = 5; + if (this->nprocpernode() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->nprocpernode()); + } + + // int32 maxRestarts = 6; + if (this->maxrestarts() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->maxrestarts()); + } + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); SetCachedSize(cached_size); return total_size; @@ -314,9 +555,25 @@ void DistributedPyTorchTrainingTask::MergeFrom(const DistributedPyTorchTrainingT ::google::protobuf::uint32 cached_has_bits = 0; (void) cached_has_bits; + if (from.rdzvbackend().size() > 0) { + + rdzvbackend_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.rdzvbackend_); + } if (from.workers() != 0) { set_workers(from.workers()); } + if (from.minreplicas() != 0) { + set_minreplicas(from.minreplicas()); + } + if (from.maxreplicas() != 0) { + set_maxreplicas(from.maxreplicas()); + } + if (from.nprocpernode() != 0) { + set_nprocpernode(from.nprocpernode()); + } + if (from.maxrestarts() != 0) { + set_maxrestarts(from.maxrestarts()); + } } void DistributedPyTorchTrainingTask::CopyFrom(const ::google::protobuf::Message& from) { @@ -344,7 +601,13 @@ void DistributedPyTorchTrainingTask::Swap(DistributedPyTorchTrainingTask* other) void DistributedPyTorchTrainingTask::InternalSwap(DistributedPyTorchTrainingTask* other) { using std::swap; _internal_metadata_.Swap(&other->_internal_metadata_); + rdzvbackend_.Swap(&other->rdzvbackend_, &::google::protobuf::internal::GetEmptyStringAlreadyInited(), + GetArenaNoVirtual()); swap(workers_, other->workers_); + swap(minreplicas_, other->minreplicas_); + swap(maxreplicas_, other->maxreplicas_); + swap(nprocpernode_, other->nprocpernode_); + swap(maxrestarts_, other->maxrestarts_); } ::google::protobuf::Metadata DistributedPyTorchTrainingTask::GetMetadata() const { diff --git a/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h b/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h index c546d0f48..7a0d15a76 100644 --- a/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h +++ b/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h @@ -160,18 +160,61 @@ class DistributedPyTorchTrainingTask final : // accessors ------------------------------------------------------- + // string RDZVBackend = 2; + void clear_rdzvbackend(); + static const int kRDZVBackendFieldNumber = 2; + const ::std::string& rdzvbackend() const; + void set_rdzvbackend(const ::std::string& value); + #if LANG_CXX11 + void set_rdzvbackend(::std::string&& value); + #endif + void set_rdzvbackend(const char* value); + void set_rdzvbackend(const char* value, size_t size); + ::std::string* mutable_rdzvbackend(); + ::std::string* release_rdzvbackend(); + void set_allocated_rdzvbackend(::std::string* rdzvbackend); + // int32 workers = 1; void clear_workers(); static const int kWorkersFieldNumber = 1; ::google::protobuf::int32 workers() const; void set_workers(::google::protobuf::int32 value); + // int32 minReplicas = 3; + void clear_minreplicas(); + static const int kMinReplicasFieldNumber = 3; + ::google::protobuf::int32 minreplicas() const; + void set_minreplicas(::google::protobuf::int32 value); + + // int32 maxReplicas = 4; + void clear_maxreplicas(); + static const int kMaxReplicasFieldNumber = 4; + ::google::protobuf::int32 maxreplicas() const; + void set_maxreplicas(::google::protobuf::int32 value); + + // int32 nProcPerNode = 5; + void clear_nprocpernode(); + static const int kNProcPerNodeFieldNumber = 5; + ::google::protobuf::int32 nprocpernode() const; + void set_nprocpernode(::google::protobuf::int32 value); + + // int32 maxRestarts = 6; + void clear_maxrestarts(); + static const int kMaxRestartsFieldNumber = 6; + ::google::protobuf::int32 maxrestarts() const; + void set_maxrestarts(::google::protobuf::int32 value); + // @@protoc_insertion_point(class_scope:flyteidl.plugins.DistributedPyTorchTrainingTask) private: class HasBitSetters; ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_; + ::google::protobuf::internal::ArenaStringPtr rdzvbackend_; ::google::protobuf::int32 workers_; + ::google::protobuf::int32 minreplicas_; + ::google::protobuf::int32 maxreplicas_; + ::google::protobuf::int32 nprocpernode_; + ::google::protobuf::int32 maxrestarts_; mutable ::google::protobuf::internal::CachedSize _cached_size_; friend struct ::TableStruct_flyteidl_2fplugins_2fpytorch_2eproto; }; @@ -200,6 +243,115 @@ inline void DistributedPyTorchTrainingTask::set_workers(::google::protobuf::int3 // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.workers) } +// string RDZVBackend = 2; +inline void DistributedPyTorchTrainingTask::clear_rdzvbackend() { + rdzvbackend_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline const ::std::string& DistributedPyTorchTrainingTask::rdzvbackend() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) + return rdzvbackend_.GetNoArena(); +} +inline void DistributedPyTorchTrainingTask::set_rdzvbackend(const ::std::string& value) { + + rdzvbackend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) +} +#if LANG_CXX11 +inline void DistributedPyTorchTrainingTask::set_rdzvbackend(::std::string&& value) { + + rdzvbackend_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) +} +#endif +inline void DistributedPyTorchTrainingTask::set_rdzvbackend(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + rdzvbackend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) +} +inline void DistributedPyTorchTrainingTask::set_rdzvbackend(const char* value, size_t size) { + + rdzvbackend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) +} +inline ::std::string* DistributedPyTorchTrainingTask::mutable_rdzvbackend() { + + // @@protoc_insertion_point(field_mutable:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) + return rdzvbackend_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline ::std::string* DistributedPyTorchTrainingTask::release_rdzvbackend() { + // @@protoc_insertion_point(field_release:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) + + return rdzvbackend_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline void DistributedPyTorchTrainingTask::set_allocated_rdzvbackend(::std::string* rdzvbackend) { + if (rdzvbackend != nullptr) { + + } else { + + } + rdzvbackend_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), rdzvbackend); + // @@protoc_insertion_point(field_set_allocated:flyteidl.plugins.DistributedPyTorchTrainingTask.RDZVBackend) +} + +// int32 minReplicas = 3; +inline void DistributedPyTorchTrainingTask::clear_minreplicas() { + minreplicas_ = 0; +} +inline ::google::protobuf::int32 DistributedPyTorchTrainingTask::minreplicas() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.minReplicas) + return minreplicas_; +} +inline void DistributedPyTorchTrainingTask::set_minreplicas(::google::protobuf::int32 value) { + + minreplicas_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.minReplicas) +} + +// int32 maxReplicas = 4; +inline void DistributedPyTorchTrainingTask::clear_maxreplicas() { + maxreplicas_ = 0; +} +inline ::google::protobuf::int32 DistributedPyTorchTrainingTask::maxreplicas() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.maxReplicas) + return maxreplicas_; +} +inline void DistributedPyTorchTrainingTask::set_maxreplicas(::google::protobuf::int32 value) { + + maxreplicas_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.maxReplicas) +} + +// int32 nProcPerNode = 5; +inline void DistributedPyTorchTrainingTask::clear_nprocpernode() { + nprocpernode_ = 0; +} +inline ::google::protobuf::int32 DistributedPyTorchTrainingTask::nprocpernode() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.nProcPerNode) + return nprocpernode_; +} +inline void DistributedPyTorchTrainingTask::set_nprocpernode(::google::protobuf::int32 value) { + + nprocpernode_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.nProcPerNode) +} + +// int32 maxRestarts = 6; +inline void DistributedPyTorchTrainingTask::clear_maxrestarts() { + maxrestarts_ = 0; +} +inline ::google::protobuf::int32 DistributedPyTorchTrainingTask::maxrestarts() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.maxRestarts) + return maxrestarts_; +} +inline void DistributedPyTorchTrainingTask::set_maxrestarts(::google::protobuf::int32 value) { + + maxrestarts_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.maxRestarts) +} + #ifdef __GNUC__ #pragma GCC diagnostic pop #endif // __GNUC__ diff --git a/gen/pb-go/flyteidl/plugins/pytorch.pb.go b/gen/pb-go/flyteidl/plugins/pytorch.pb.go index 79138e568..c25b64024 100644 --- a/gen/pb-go/flyteidl/plugins/pytorch.pb.go +++ b/gen/pb-go/flyteidl/plugins/pytorch.pb.go @@ -23,7 +23,14 @@ const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package // Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator type DistributedPyTorchTrainingTask struct { // number of worker replicas spawned in the cluster for this job - Workers int32 `protobuf:"varint,1,opt,name=workers,proto3" json:"workers,omitempty"` + Workers int32 `protobuf:"varint,1,opt,name=workers,proto3" json:"workers,omitempty"` + // config for an elastic pytorch job + // https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go + RDZVBackend string `protobuf:"bytes,2,opt,name=RDZVBackend,proto3" json:"RDZVBackend,omitempty"` + MinReplicas int32 `protobuf:"varint,3,opt,name=minReplicas,proto3" json:"minReplicas,omitempty"` + MaxReplicas int32 `protobuf:"varint,4,opt,name=maxReplicas,proto3" json:"maxReplicas,omitempty"` + NProcPerNode int32 `protobuf:"varint,5,opt,name=nProcPerNode,proto3" json:"nProcPerNode,omitempty"` + MaxRestarts int32 `protobuf:"varint,6,opt,name=maxRestarts,proto3" json:"maxRestarts,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -61,6 +68,41 @@ func (m *DistributedPyTorchTrainingTask) GetWorkers() int32 { return 0 } +func (m *DistributedPyTorchTrainingTask) GetRDZVBackend() string { + if m != nil { + return m.RDZVBackend + } + return "" +} + +func (m *DistributedPyTorchTrainingTask) GetMinReplicas() int32 { + if m != nil { + return m.MinReplicas + } + return 0 +} + +func (m *DistributedPyTorchTrainingTask) GetMaxReplicas() int32 { + if m != nil { + return m.MaxReplicas + } + return 0 +} + +func (m *DistributedPyTorchTrainingTask) GetNProcPerNode() int32 { + if m != nil { + return m.NProcPerNode + } + return 0 +} + +func (m *DistributedPyTorchTrainingTask) GetMaxRestarts() int32 { + if m != nil { + return m.MaxRestarts + } + return 0 +} + func init() { proto.RegisterType((*DistributedPyTorchTrainingTask)(nil), "flyteidl.plugins.DistributedPyTorchTrainingTask") } @@ -68,15 +110,20 @@ func init() { func init() { proto.RegisterFile("flyteidl/plugins/pytorch.proto", fileDescriptor_4df8a9374b28b766) } var fileDescriptor_4df8a9374b28b766 = []byte{ - // 156 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4b, 0xcb, 0xa9, 0x2c, - 0x49, 0xcd, 0x4c, 0xc9, 0xd1, 0x2f, 0xc8, 0x29, 0x4d, 0xcf, 0xcc, 0x2b, 0xd6, 0x2f, 0xa8, 0x2c, - 0xc9, 0x2f, 0x4a, 0xce, 0xd0, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x80, 0xc9, 0xeb, 0x41, - 0xe5, 0x95, 0xac, 0xb8, 0xe4, 0x5c, 0x32, 0x8b, 0x4b, 0x8a, 0x32, 0x93, 0x4a, 0x4b, 0x52, 0x53, - 0x02, 0x2a, 0x43, 0x40, 0xaa, 0x43, 0x8a, 0x12, 0x33, 0xf3, 0x32, 0xf3, 0xd2, 0x43, 0x12, 0x8b, - 0xb3, 0x85, 0x24, 0xb8, 0xd8, 0xcb, 0xf3, 0x8b, 0xb2, 0x53, 0x8b, 0x8a, 0x25, 0x18, 0x15, 0x18, - 0x35, 0x58, 0x83, 0x60, 0x5c, 0x27, 0xcb, 0x28, 0xf3, 0xf4, 0xcc, 0x92, 0x8c, 0xd2, 0x24, 0xbd, - 0xe4, 0xfc, 0x5c, 0x7d, 0xb0, 0xd1, 0xf9, 0x45, 0xe9, 0xfa, 0x70, 0x37, 0xa4, 0xa7, 0xe6, 0xe9, - 0x17, 0x24, 0xe9, 0xa6, 0xe7, 0xeb, 0xa3, 0x3b, 0x2b, 0x89, 0x0d, 0xec, 0x1e, 0x63, 0x40, 0x00, - 0x00, 0x00, 0xff, 0xff, 0x91, 0x53, 0x3a, 0xa1, 0xb1, 0x00, 0x00, 0x00, + // 238 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0x90, 0x31, 0x4f, 0xc3, 0x30, + 0x10, 0x46, 0x15, 0xa0, 0x45, 0x18, 0x06, 0x94, 0xc9, 0x53, 0x15, 0x75, 0xea, 0x42, 0x3c, 0x30, + 0x20, 0xd6, 0xaa, 0x33, 0x8a, 0xa2, 0x88, 0xa1, 0x9b, 0xe3, 0x18, 0xf7, 0x94, 0xd4, 0x67, 0x9d, + 0x1d, 0x41, 0xfe, 0x30, 0xbf, 0x03, 0xd5, 0x34, 0x21, 0x74, 0xf4, 0xbb, 0xe7, 0x6f, 0x78, 0x6c, + 0xf5, 0xd1, 0x0d, 0x41, 0x43, 0xd3, 0x09, 0xd7, 0xf5, 0x06, 0xac, 0x17, 0x6e, 0x08, 0x48, 0xea, + 0x90, 0x3b, 0xc2, 0x80, 0xe9, 0xe3, 0x78, 0xcf, 0xcf, 0xf7, 0xf5, 0x77, 0xc2, 0x56, 0x3b, 0xf0, + 0x81, 0xa0, 0xee, 0x83, 0x6e, 0x8a, 0xa1, 0x3a, 0xe9, 0x15, 0x49, 0xb0, 0x60, 0x4d, 0x25, 0x7d, + 0x9b, 0x72, 0x76, 0xfb, 0x89, 0xd4, 0x6a, 0xf2, 0x3c, 0xc9, 0x92, 0xcd, 0xa2, 0x1c, 0x9f, 0x69, + 0xc6, 0xee, 0xcb, 0xdd, 0xfe, 0x7d, 0x2b, 0x55, 0xab, 0x6d, 0xc3, 0xaf, 0xb2, 0x64, 0x73, 0x57, + 0xce, 0xd1, 0xc9, 0x38, 0x82, 0x2d, 0xb5, 0xeb, 0x40, 0x49, 0xcf, 0xaf, 0xe3, 0xff, 0x39, 0x8a, + 0x86, 0xfc, 0x9a, 0x8c, 0x9b, 0xb3, 0xf1, 0x87, 0xd2, 0x35, 0x7b, 0xb0, 0x05, 0xa1, 0x2a, 0x34, + 0xbd, 0x61, 0xa3, 0xf9, 0x22, 0x2a, 0xff, 0xd8, 0xb4, 0xe2, 0x83, 0xa4, 0xe0, 0xf9, 0x72, 0xb6, + 0xf2, 0x8b, 0xb6, 0xaf, 0xfb, 0x17, 0x03, 0xe1, 0xd0, 0xd7, 0xb9, 0xc2, 0xa3, 0x88, 0x1d, 0x90, + 0x8c, 0x98, 0x82, 0x19, 0x6d, 0x85, 0xab, 0x9f, 0x0c, 0x8a, 0xcb, 0x86, 0xf5, 0x32, 0xc6, 0x7b, + 0xfe, 0x09, 0x00, 0x00, 0xff, 0xff, 0x5b, 0x03, 0x64, 0x09, 0x5e, 0x01, 0x00, 0x00, } diff --git a/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go b/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go index 8e6af9852..ae49a4b14 100644 --- a/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go +++ b/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go @@ -46,6 +46,16 @@ func (m *DistributedPyTorchTrainingTask) Validate() error { // no validation rules for Workers + // no validation rules for RDZVBackend + + // no validation rules for MinReplicas + + // no validation rules for MaxReplicas + + // no validation rules for NProcPerNode + + // no validation rules for MaxRestarts + return nil } diff --git a/gen/pb-java/flyteidl/plugins/Pytorch.java b/gen/pb-java/flyteidl/plugins/Pytorch.java index a7709263f..feba13fb8 100644 --- a/gen/pb-java/flyteidl/plugins/Pytorch.java +++ b/gen/pb-java/flyteidl/plugins/Pytorch.java @@ -26,6 +26,46 @@ public interface DistributedPyTorchTrainingTaskOrBuilder extends * int32 workers = 1; */ int getWorkers(); + + /** + *
+     * config for an elastic pytorch job
+     * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+     * 
+ * + * string RDZVBackend = 2; + */ + java.lang.String getRDZVBackend(); + /** + *
+     * config for an elastic pytorch job
+     * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+     * 
+ * + * string RDZVBackend = 2; + */ + com.google.protobuf.ByteString + getRDZVBackendBytes(); + + /** + * int32 minReplicas = 3; + */ + int getMinReplicas(); + + /** + * int32 maxReplicas = 4; + */ + int getMaxReplicas(); + + /** + * int32 nProcPerNode = 5; + */ + int getNProcPerNode(); + + /** + * int32 maxRestarts = 6; + */ + int getMaxRestarts(); } /** *
@@ -44,6 +84,7 @@ private DistributedPyTorchTrainingTask(com.google.protobuf.GeneratedMessageV3.Bu
       super(builder);
     }
     private DistributedPyTorchTrainingTask() {
+      rDZVBackend_ = "";
     }
 
     @java.lang.Override
@@ -75,6 +116,32 @@ private DistributedPyTorchTrainingTask(
               workers_ = input.readInt32();
               break;
             }
+            case 18: {
+              java.lang.String s = input.readStringRequireUtf8();
+
+              rDZVBackend_ = s;
+              break;
+            }
+            case 24: {
+
+              minReplicas_ = input.readInt32();
+              break;
+            }
+            case 32: {
+
+              maxReplicas_ = input.readInt32();
+              break;
+            }
+            case 40: {
+
+              nProcPerNode_ = input.readInt32();
+              break;
+            }
+            case 48: {
+
+              maxRestarts_ = input.readInt32();
+              break;
+            }
             default: {
               if (!parseUnknownField(
                   input, unknownFields, extensionRegistry, tag)) {
@@ -120,6 +187,86 @@ public int getWorkers() {
       return workers_;
     }
 
+    public static final int RDZVBACKEND_FIELD_NUMBER = 2;
+    private volatile java.lang.Object rDZVBackend_;
+    /**
+     * 
+     * config for an elastic pytorch job
+     * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+     * 
+ * + * string RDZVBackend = 2; + */ + public java.lang.String getRDZVBackend() { + java.lang.Object ref = rDZVBackend_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + rDZVBackend_ = s; + return s; + } + } + /** + *
+     * config for an elastic pytorch job
+     * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+     * 
+ * + * string RDZVBackend = 2; + */ + public com.google.protobuf.ByteString + getRDZVBackendBytes() { + java.lang.Object ref = rDZVBackend_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + rDZVBackend_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int MINREPLICAS_FIELD_NUMBER = 3; + private int minReplicas_; + /** + * int32 minReplicas = 3; + */ + public int getMinReplicas() { + return minReplicas_; + } + + public static final int MAXREPLICAS_FIELD_NUMBER = 4; + private int maxReplicas_; + /** + * int32 maxReplicas = 4; + */ + public int getMaxReplicas() { + return maxReplicas_; + } + + public static final int NPROCPERNODE_FIELD_NUMBER = 5; + private int nProcPerNode_; + /** + * int32 nProcPerNode = 5; + */ + public int getNProcPerNode() { + return nProcPerNode_; + } + + public static final int MAXRESTARTS_FIELD_NUMBER = 6; + private int maxRestarts_; + /** + * int32 maxRestarts = 6; + */ + public int getMaxRestarts() { + return maxRestarts_; + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -137,6 +284,21 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (workers_ != 0) { output.writeInt32(1, workers_); } + if (!getRDZVBackendBytes().isEmpty()) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 2, rDZVBackend_); + } + if (minReplicas_ != 0) { + output.writeInt32(3, minReplicas_); + } + if (maxReplicas_ != 0) { + output.writeInt32(4, maxReplicas_); + } + if (nProcPerNode_ != 0) { + output.writeInt32(5, nProcPerNode_); + } + if (maxRestarts_ != 0) { + output.writeInt32(6, maxRestarts_); + } unknownFields.writeTo(output); } @@ -150,6 +312,25 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeInt32Size(1, workers_); } + if (!getRDZVBackendBytes().isEmpty()) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(2, rDZVBackend_); + } + if (minReplicas_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(3, minReplicas_); + } + if (maxReplicas_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(4, maxReplicas_); + } + if (nProcPerNode_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(5, nProcPerNode_); + } + if (maxRestarts_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(6, maxRestarts_); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -167,6 +348,16 @@ public boolean equals(final java.lang.Object obj) { if (getWorkers() != other.getWorkers()) return false; + if (!getRDZVBackend() + .equals(other.getRDZVBackend())) return false; + if (getMinReplicas() + != other.getMinReplicas()) return false; + if (getMaxReplicas() + != other.getMaxReplicas()) return false; + if (getNProcPerNode() + != other.getNProcPerNode()) return false; + if (getMaxRestarts() + != other.getMaxRestarts()) return false; if (!unknownFields.equals(other.unknownFields)) return false; return true; } @@ -180,6 +371,16 @@ public int hashCode() { hash = (19 * hash) + getDescriptor().hashCode(); hash = (37 * hash) + WORKERS_FIELD_NUMBER; hash = (53 * hash) + getWorkers(); + hash = (37 * hash) + RDZVBACKEND_FIELD_NUMBER; + hash = (53 * hash) + getRDZVBackend().hashCode(); + hash = (37 * hash) + MINREPLICAS_FIELD_NUMBER; + hash = (53 * hash) + getMinReplicas(); + hash = (37 * hash) + MAXREPLICAS_FIELD_NUMBER; + hash = (53 * hash) + getMaxReplicas(); + hash = (37 * hash) + NPROCPERNODE_FIELD_NUMBER; + hash = (53 * hash) + getNProcPerNode(); + hash = (37 * hash) + MAXRESTARTS_FIELD_NUMBER; + hash = (53 * hash) + getMaxRestarts(); hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -319,6 +520,16 @@ public Builder clear() { super.clear(); workers_ = 0; + rDZVBackend_ = ""; + + minReplicas_ = 0; + + maxReplicas_ = 0; + + nProcPerNode_ = 0; + + maxRestarts_ = 0; + return this; } @@ -346,6 +557,11 @@ public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask build() { public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask buildPartial() { flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask result = new flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask(this); result.workers_ = workers_; + result.rDZVBackend_ = rDZVBackend_; + result.minReplicas_ = minReplicas_; + result.maxReplicas_ = maxReplicas_; + result.nProcPerNode_ = nProcPerNode_; + result.maxRestarts_ = maxRestarts_; onBuilt(); return result; } @@ -397,6 +613,22 @@ public Builder mergeFrom(flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask if (other.getWorkers() != 0) { setWorkers(other.getWorkers()); } + if (!other.getRDZVBackend().isEmpty()) { + rDZVBackend_ = other.rDZVBackend_; + onChanged(); + } + if (other.getMinReplicas() != 0) { + setMinReplicas(other.getMinReplicas()); + } + if (other.getMaxReplicas() != 0) { + setMaxReplicas(other.getMaxReplicas()); + } + if (other.getNProcPerNode() != 0) { + setNProcPerNode(other.getNProcPerNode()); + } + if (other.getMaxRestarts() != 0) { + setMaxRestarts(other.getMaxRestarts()); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -463,6 +695,204 @@ public Builder clearWorkers() { onChanged(); return this; } + + private java.lang.Object rDZVBackend_ = ""; + /** + *
+       * config for an elastic pytorch job
+       * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+       * 
+ * + * string RDZVBackend = 2; + */ + public java.lang.String getRDZVBackend() { + java.lang.Object ref = rDZVBackend_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + rDZVBackend_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * config for an elastic pytorch job
+       * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+       * 
+ * + * string RDZVBackend = 2; + */ + public com.google.protobuf.ByteString + getRDZVBackendBytes() { + java.lang.Object ref = rDZVBackend_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + rDZVBackend_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * config for an elastic pytorch job
+       * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+       * 
+ * + * string RDZVBackend = 2; + */ + public Builder setRDZVBackend( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + rDZVBackend_ = value; + onChanged(); + return this; + } + /** + *
+       * config for an elastic pytorch job
+       * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+       * 
+ * + * string RDZVBackend = 2; + */ + public Builder clearRDZVBackend() { + + rDZVBackend_ = getDefaultInstance().getRDZVBackend(); + onChanged(); + return this; + } + /** + *
+       * config for an elastic pytorch job
+       * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+       * 
+ * + * string RDZVBackend = 2; + */ + public Builder setRDZVBackendBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + rDZVBackend_ = value; + onChanged(); + return this; + } + + private int minReplicas_ ; + /** + * int32 minReplicas = 3; + */ + public int getMinReplicas() { + return minReplicas_; + } + /** + * int32 minReplicas = 3; + */ + public Builder setMinReplicas(int value) { + + minReplicas_ = value; + onChanged(); + return this; + } + /** + * int32 minReplicas = 3; + */ + public Builder clearMinReplicas() { + + minReplicas_ = 0; + onChanged(); + return this; + } + + private int maxReplicas_ ; + /** + * int32 maxReplicas = 4; + */ + public int getMaxReplicas() { + return maxReplicas_; + } + /** + * int32 maxReplicas = 4; + */ + public Builder setMaxReplicas(int value) { + + maxReplicas_ = value; + onChanged(); + return this; + } + /** + * int32 maxReplicas = 4; + */ + public Builder clearMaxReplicas() { + + maxReplicas_ = 0; + onChanged(); + return this; + } + + private int nProcPerNode_ ; + /** + * int32 nProcPerNode = 5; + */ + public int getNProcPerNode() { + return nProcPerNode_; + } + /** + * int32 nProcPerNode = 5; + */ + public Builder setNProcPerNode(int value) { + + nProcPerNode_ = value; + onChanged(); + return this; + } + /** + * int32 nProcPerNode = 5; + */ + public Builder clearNProcPerNode() { + + nProcPerNode_ = 0; + onChanged(); + return this; + } + + private int maxRestarts_ ; + /** + * int32 maxRestarts = 6; + */ + public int getMaxRestarts() { + return maxRestarts_; + } + /** + * int32 maxRestarts = 6; + */ + public Builder setMaxRestarts(int value) { + + maxRestarts_ = value; + onChanged(); + return this; + } + /** + * int32 maxRestarts = 6; + */ + public Builder clearMaxRestarts() { + + maxRestarts_ = 0; + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -531,10 +961,12 @@ public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask getDefaultInstanc static { java.lang.String[] descriptorData = { "\n\036flyteidl/plugins/pytorch.proto\022\020flytei" + - "dl.plugins\"1\n\036DistributedPyTorchTraining" + - "Task\022\017\n\007workers\030\001 \001(\005B9Z7github.com/flyt" + - "eorg/flyteidl/gen/pb-go/flyteidl/plugins" + - "b\006proto3" + "dl.plugins\"\233\001\n\036DistributedPyTorchTrainin" + + "gTask\022\017\n\007workers\030\001 \001(\005\022\023\n\013RDZVBackend\030\002 " + + "\001(\t\022\023\n\013minReplicas\030\003 \001(\005\022\023\n\013maxReplicas\030" + + "\004 \001(\005\022\024\n\014nProcPerNode\030\005 \001(\005\022\023\n\013maxRestar" + + "ts\030\006 \001(\005B9Z7github.com/flyteorg/flyteidl" + + "/gen/pb-go/flyteidl/pluginsb\006proto3" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { @@ -553,7 +985,7 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor, - new java.lang.String[] { "Workers", }); + new java.lang.String[] { "Workers", "RDZVBackend", "MinReplicas", "MaxReplicas", "NProcPerNode", "MaxRestarts", }); } // @@protoc_insertion_point(outer_class_scope) diff --git a/gen/pb_python/flyteidl/plugins/pytorch_pb2.py b/gen/pb_python/flyteidl/plugins/pytorch_pb2.py index c2afe8cbb..7bfa0492d 100644 --- a/gen/pb_python/flyteidl/plugins/pytorch_pb2.py +++ b/gen/pb_python/flyteidl/plugins/pytorch_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1e\x66lyteidl/plugins/pytorch.proto\x12\x10\x66lyteidl.plugins\":\n\x1e\x44istributedPyTorchTrainingTask\x12\x18\n\x07workers\x18\x01 \x01(\x05R\x07workersB\xbe\x01\n\x14\x63om.flyteidl.pluginsB\x0cPytorchProtoP\x01Z7github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1e\x66lyteidl/plugins/pytorch.proto\x12\x10\x66lyteidl.plugins\"\xe6\x01\n\x1e\x44istributedPyTorchTrainingTask\x12\x18\n\x07workers\x18\x01 \x01(\x05R\x07workers\x12 \n\x0bRDZVBackend\x18\x02 \x01(\tR\x0bRDZVBackend\x12 \n\x0bminReplicas\x18\x03 \x01(\x05R\x0bminReplicas\x12 \n\x0bmaxReplicas\x18\x04 \x01(\x05R\x0bmaxReplicas\x12\"\n\x0cnProcPerNode\x18\x05 \x01(\x05R\x0cnProcPerNode\x12 \n\x0bmaxRestarts\x18\x06 \x01(\x05R\x0bmaxRestartsB\xbe\x01\n\x14\x63om.flyteidl.pluginsB\x0cPytorchProtoP\x01Z7github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flyteidl.plugins.pytorch_pb2', globals()) @@ -21,6 +21,6 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\n\024com.flyteidl.pluginsB\014PytorchProtoP\001Z7github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPX\252\002\020Flyteidl.Plugins\312\002\020Flyteidl\\Plugins\342\002\034Flyteidl\\Plugins\\GPBMetadata\352\002\021Flyteidl::Plugins' - _DISTRIBUTEDPYTORCHTRAININGTASK._serialized_start=52 - _DISTRIBUTEDPYTORCHTRAININGTASK._serialized_end=110 + _DISTRIBUTEDPYTORCHTRAININGTASK._serialized_start=53 + _DISTRIBUTEDPYTORCHTRAININGTASK._serialized_end=283 # @@protoc_insertion_point(module_scope) diff --git a/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi b/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi index c7f5cef66..bde6ac1d2 100644 --- a/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi +++ b/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi @@ -5,7 +5,17 @@ from typing import ClassVar as _ClassVar, Optional as _Optional DESCRIPTOR: _descriptor.FileDescriptor class DistributedPyTorchTrainingTask(_message.Message): - __slots__ = ["workers"] + __slots__ = ["RDZVBackend", "maxReplicas", "maxRestarts", "minReplicas", "nProcPerNode", "workers"] + MAXREPLICAS_FIELD_NUMBER: _ClassVar[int] + MAXRESTARTS_FIELD_NUMBER: _ClassVar[int] + MINREPLICAS_FIELD_NUMBER: _ClassVar[int] + NPROCPERNODE_FIELD_NUMBER: _ClassVar[int] + RDZVBACKEND_FIELD_NUMBER: _ClassVar[int] + RDZVBackend: str WORKERS_FIELD_NUMBER: _ClassVar[int] + maxReplicas: int + maxRestarts: int + minReplicas: int + nProcPerNode: int workers: int - def __init__(self, workers: _Optional[int] = ...) -> None: ... + def __init__(self, workers: _Optional[int] = ..., RDZVBackend: _Optional[str] = ..., minReplicas: _Optional[int] = ..., maxReplicas: _Optional[int] = ..., nProcPerNode: _Optional[int] = ..., maxRestarts: _Optional[int] = ...) -> None: ... diff --git a/gen/pb_rust/flyteidl.plugins.rs b/gen/pb_rust/flyteidl.plugins.rs index 14251aa64..922cd86ab 100644 --- a/gen/pb_rust/flyteidl.plugins.rs +++ b/gen/pb_rust/flyteidl.plugins.rs @@ -110,6 +110,18 @@ pub struct DistributedPyTorchTrainingTask { /// number of worker replicas spawned in the cluster for this job #[prost(int32, tag="1")] pub workers: i32, + /// config for an elastic pytorch job + /// + #[prost(string, tag="2")] + pub rdzv_backend: ::prost::alloc::string::String, + #[prost(int32, tag="3")] + pub min_replicas: i32, + #[prost(int32, tag="4")] + pub max_replicas: i32, + #[prost(int32, tag="5")] + pub n_proc_per_node: i32, + #[prost(int32, tag="6")] + pub max_restarts: i32, } /// Defines a query to execute on a hive cluster. #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/protos/flyteidl/plugins/pytorch.proto b/protos/flyteidl/plugins/pytorch.proto index 603de00c3..23191b276 100644 --- a/protos/flyteidl/plugins/pytorch.proto +++ b/protos/flyteidl/plugins/pytorch.proto @@ -8,4 +8,12 @@ option go_package = "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"; message DistributedPyTorchTrainingTask { // number of worker replicas spawned in the cluster for this job int32 workers = 1; + + // config for an elastic pytorch job + // https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go + string RDZVBackend = 2; + int32 minReplicas = 3; + int32 maxReplicas = 4; + int32 nProcPerNode = 5; + int32 maxRestarts = 6; }