From 771d809391fe920dff86fd6cdfd87b8e5c37d012 Mon Sep 17 00:00:00 2001 From: HUAN-PING SU Date: Tue, 7 Jul 2020 01:27:21 +0800 Subject: [PATCH] feature; tfoperator for tensorflow distributed training plugin (#71) --- .../flyteidl/plugins/tensorflow.grpc.pb.cc | 24 + .../flyteidl/plugins/tensorflow.grpc.pb.h | 47 ++ gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc | 461 ++++++++++++ gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h | 257 +++++++ gen/pb-go/flyteidl/plugins/tensorflow.pb.go | 102 +++ .../plugins/tensorflow.pb.validate.go | 111 +++ gen/pb-java/flyteidl/plugins/Tensorflow.java | 705 ++++++++++++++++++ gen/pb-protodoc/flyteidl/plugins/index.rst | 1 + .../flyteidl/plugins/tensorflow.proto.rst | 40 + .../flyteidl/plugins/tensorflow_pb2.py | 85 +++ .../flyteidl/plugins/tensorflow_pb2_grpc.py | 3 + package.json | 2 +- protos/flyteidl/plugins/tensorflow.proto | 14 + setup.py | 2 +- 14 files changed, 1852 insertions(+), 2 deletions(-) create mode 100644 gen/pb-cpp/flyteidl/plugins/tensorflow.grpc.pb.cc create mode 100644 gen/pb-cpp/flyteidl/plugins/tensorflow.grpc.pb.h create mode 100644 gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc create mode 100644 gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h create mode 100644 gen/pb-go/flyteidl/plugins/tensorflow.pb.go create mode 100644 gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go create mode 100644 gen/pb-java/flyteidl/plugins/Tensorflow.java create mode 100644 gen/pb-protodoc/flyteidl/plugins/tensorflow.proto.rst create mode 100644 gen/pb_python/flyteidl/plugins/tensorflow_pb2.py create mode 100644 gen/pb_python/flyteidl/plugins/tensorflow_pb2_grpc.py create mode 100644 protos/flyteidl/plugins/tensorflow.proto diff --git a/gen/pb-cpp/flyteidl/plugins/tensorflow.grpc.pb.cc b/gen/pb-cpp/flyteidl/plugins/tensorflow.grpc.pb.cc new file mode 100644 index 000000000..f3a3c5622 --- /dev/null +++ b/gen/pb-cpp/flyteidl/plugins/tensorflow.grpc.pb.cc @@ -0,0 +1,24 @@ +// Generated by the gRPC C++ plugin. +// If you make any local change, they will be lost. +// source: flyteidl/plugins/tensorflow.proto + +#include "flyteidl/plugins/tensorflow.pb.h" +#include "flyteidl/plugins/tensorflow.grpc.pb.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace flyteidl { +namespace plugins { + +} // namespace flyteidl +} // namespace plugins + diff --git a/gen/pb-cpp/flyteidl/plugins/tensorflow.grpc.pb.h b/gen/pb-cpp/flyteidl/plugins/tensorflow.grpc.pb.h new file mode 100644 index 000000000..1bc80de44 --- /dev/null +++ b/gen/pb-cpp/flyteidl/plugins/tensorflow.grpc.pb.h @@ -0,0 +1,47 @@ +// Generated by the gRPC C++ plugin. +// If you make any local change, they will be lost. +// source: flyteidl/plugins/tensorflow.proto +#ifndef GRPC_flyteidl_2fplugins_2ftensorflow_2eproto__INCLUDED +#define GRPC_flyteidl_2fplugins_2ftensorflow_2eproto__INCLUDED + +#include "flyteidl/plugins/tensorflow.pb.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace grpc_impl { +class Channel; +class CompletionQueue; +class ServerCompletionQueue; +} // namespace grpc_impl + +namespace grpc { +namespace experimental { +template +class MessageAllocator; +} // namespace experimental +} // namespace grpc_impl + +namespace grpc { +class ServerContext; +} // namespace grpc + +namespace flyteidl { +namespace plugins { + +} // namespace plugins +} // namespace flyteidl + + +#endif // GRPC_flyteidl_2fplugins_2ftensorflow_2eproto__INCLUDED diff --git a/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc b/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc new file mode 100644 index 000000000..702225b61 --- /dev/null +++ b/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc @@ -0,0 +1,461 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: flyteidl/plugins/tensorflow.proto + +#include "flyteidl/plugins/tensorflow.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include + +namespace flyteidl { +namespace plugins { +class DistributedTensorflowTrainingTaskDefaultTypeInternal { + public: + ::google::protobuf::internal::ExplicitlyConstructed _instance; +} _DistributedTensorflowTrainingTask_default_instance_; +} // namespace plugins +} // namespace flyteidl +static void InitDefaultsDistributedTensorflowTrainingTask_flyteidl_2fplugins_2ftensorflow_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::flyteidl::plugins::_DistributedTensorflowTrainingTask_default_instance_; + new (ptr) ::flyteidl::plugins::DistributedTensorflowTrainingTask(); + ::google::protobuf::internal::OnShutdownDestroyMessage(ptr); + } + ::flyteidl::plugins::DistributedTensorflowTrainingTask::InitAsDefaultInstance(); +} + +::google::protobuf::internal::SCCInfo<0> scc_info_DistributedTensorflowTrainingTask_flyteidl_2fplugins_2ftensorflow_2eproto = + {{ATOMIC_VAR_INIT(::google::protobuf::internal::SCCInfoBase::kUninitialized), 0, InitDefaultsDistributedTensorflowTrainingTask_flyteidl_2fplugins_2ftensorflow_2eproto}, {}}; + +void InitDefaults_flyteidl_2fplugins_2ftensorflow_2eproto() { + ::google::protobuf::internal::InitSCC(&scc_info_DistributedTensorflowTrainingTask_flyteidl_2fplugins_2ftensorflow_2eproto.base); +} + +::google::protobuf::Metadata file_level_metadata_flyteidl_2fplugins_2ftensorflow_2eproto[1]; +constexpr ::google::protobuf::EnumDescriptor const** file_level_enum_descriptors_flyteidl_2fplugins_2ftensorflow_2eproto = nullptr; +constexpr ::google::protobuf::ServiceDescriptor const** file_level_service_descriptors_flyteidl_2fplugins_2ftensorflow_2eproto = nullptr; + +const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2ftensorflow_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, workers_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, ps_replicas_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, chief_replicas_), +}; +static const ::google::protobuf::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, -1, sizeof(::flyteidl::plugins::DistributedTensorflowTrainingTask)}, +}; + +static ::google::protobuf::Message const * const file_default_instances[] = { + reinterpret_cast(&::flyteidl::plugins::_DistributedTensorflowTrainingTask_default_instance_), +}; + +::google::protobuf::internal::AssignDescriptorsTable assign_descriptors_table_flyteidl_2fplugins_2ftensorflow_2eproto = { + {}, AddDescriptors_flyteidl_2fplugins_2ftensorflow_2eproto, "flyteidl/plugins/tensorflow.proto", schemas, + file_default_instances, TableStruct_flyteidl_2fplugins_2ftensorflow_2eproto::offsets, + file_level_metadata_flyteidl_2fplugins_2ftensorflow_2eproto, 1, file_level_enum_descriptors_flyteidl_2fplugins_2ftensorflow_2eproto, file_level_service_descriptors_flyteidl_2fplugins_2ftensorflow_2eproto, +}; + +const char descriptor_table_protodef_flyteidl_2fplugins_2ftensorflow_2eproto[] = + "\n!flyteidl/plugins/tensorflow.proto\022\020fly" + "teidl.plugins\"a\n!DistributedTensorflowTr" + "ainingTask\022\017\n\007workers\030\001 \001(\005\022\023\n\013ps_replic" + "as\030\002 \001(\005\022\026\n\016chief_replicas\030\003 \001(\005B5Z3gith" + "ub.com/lyft/flyteidl/gen/pb-go/flyteidl/" + "pluginsb\006proto3" + ; +::google::protobuf::internal::DescriptorTable descriptor_table_flyteidl_2fplugins_2ftensorflow_2eproto = { + false, InitDefaults_flyteidl_2fplugins_2ftensorflow_2eproto, + descriptor_table_protodef_flyteidl_2fplugins_2ftensorflow_2eproto, + "flyteidl/plugins/tensorflow.proto", &assign_descriptors_table_flyteidl_2fplugins_2ftensorflow_2eproto, 215, +}; + +void AddDescriptors_flyteidl_2fplugins_2ftensorflow_2eproto() { + static constexpr ::google::protobuf::internal::InitFunc deps[1] = + { + }; + ::google::protobuf::internal::AddDescriptors(&descriptor_table_flyteidl_2fplugins_2ftensorflow_2eproto, deps, 0); +} + +// Force running AddDescriptors() at dynamic initialization time. +static bool dynamic_init_dummy_flyteidl_2fplugins_2ftensorflow_2eproto = []() { AddDescriptors_flyteidl_2fplugins_2ftensorflow_2eproto(); return true; }(); +namespace flyteidl { +namespace plugins { + +// =================================================================== + +void DistributedTensorflowTrainingTask::InitAsDefaultInstance() { +} +class DistributedTensorflowTrainingTask::HasBitSetters { + public: +}; + +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +const int DistributedTensorflowTrainingTask::kWorkersFieldNumber; +const int DistributedTensorflowTrainingTask::kPsReplicasFieldNumber; +const int DistributedTensorflowTrainingTask::kChiefReplicasFieldNumber; +#endif // !defined(_MSC_VER) || _MSC_VER >= 1900 + +DistributedTensorflowTrainingTask::DistributedTensorflowTrainingTask() + : ::google::protobuf::Message(), _internal_metadata_(nullptr) { + SharedCtor(); + // @@protoc_insertion_point(constructor:flyteidl.plugins.DistributedTensorflowTrainingTask) +} +DistributedTensorflowTrainingTask::DistributedTensorflowTrainingTask(const DistributedTensorflowTrainingTask& from) + : ::google::protobuf::Message(), + _internal_metadata_(nullptr) { + _internal_metadata_.MergeFrom(from._internal_metadata_); + ::memcpy(&workers_, &from.workers_, + static_cast(reinterpret_cast(&chief_replicas_) - + reinterpret_cast(&workers_)) + sizeof(chief_replicas_)); + // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.DistributedTensorflowTrainingTask) +} + +void DistributedTensorflowTrainingTask::SharedCtor() { + ::memset(&workers_, 0, static_cast( + reinterpret_cast(&chief_replicas_) - + reinterpret_cast(&workers_)) + sizeof(chief_replicas_)); +} + +DistributedTensorflowTrainingTask::~DistributedTensorflowTrainingTask() { + // @@protoc_insertion_point(destructor:flyteidl.plugins.DistributedTensorflowTrainingTask) + SharedDtor(); +} + +void DistributedTensorflowTrainingTask::SharedDtor() { +} + +void DistributedTensorflowTrainingTask::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const DistributedTensorflowTrainingTask& DistributedTensorflowTrainingTask::default_instance() { + ::google::protobuf::internal::InitSCC(&::scc_info_DistributedTensorflowTrainingTask_flyteidl_2fplugins_2ftensorflow_2eproto.base); + return *internal_default_instance(); +} + + +void DistributedTensorflowTrainingTask::Clear() { +// @@protoc_insertion_point(message_clear_start:flyteidl.plugins.DistributedTensorflowTrainingTask) + ::google::protobuf::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + ::memset(&workers_, 0, static_cast( + reinterpret_cast(&chief_replicas_) - + reinterpret_cast(&workers_)) + sizeof(chief_replicas_)); + _internal_metadata_.Clear(); +} + +#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +const char* DistributedTensorflowTrainingTask::_InternalParse(const char* begin, const char* end, void* object, + ::google::protobuf::internal::ParseContext* ctx) { + auto msg = static_cast(object); + ::google::protobuf::int32 size; (void)size; + int depth; (void)depth; + ::google::protobuf::uint32 tag; + ::google::protobuf::internal::ParseFunc parser_till_end; (void)parser_till_end; + auto ptr = begin; + while (ptr < end) { + ptr = ::google::protobuf::io::Parse32(ptr, &tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + switch (tag >> 3) { + // int32 workers = 1; + case 1: { + if (static_cast<::google::protobuf::uint8>(tag) != 8) goto handle_unusual; + msg->set_workers(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + // int32 ps_replicas = 2; + case 2: { + if (static_cast<::google::protobuf::uint8>(tag) != 16) goto handle_unusual; + msg->set_ps_replicas(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + // int32 chief_replicas = 3; + case 3: { + if (static_cast<::google::protobuf::uint8>(tag) != 24) goto handle_unusual; + msg->set_chief_replicas(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->EndGroup(tag); + return ptr; + } + auto res = UnknownFieldParse(tag, {_InternalParse, msg}, + ptr, end, msg->_internal_metadata_.mutable_unknown_fields(), ctx); + ptr = res.first; + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); + if (res.second) return ptr; + } + } // switch + } // while + return ptr; +} +#else // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +bool DistributedTensorflowTrainingTask::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!PROTOBUF_PREDICT_TRUE(EXPRESSION)) goto failure + ::google::protobuf::uint32 tag; + // @@protoc_insertion_point(parse_start:flyteidl.plugins.DistributedTensorflowTrainingTask) + for (;;) { + ::std::pair<::google::protobuf::uint32, bool> p = input->ReadTagWithCutoffNoLastTag(127u); + tag = p.first; + if (!p.second) goto handle_unusual; + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // int32 workers = 1; + case 1: { + if (static_cast< ::google::protobuf::uint8>(tag) == (8 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &workers_))); + } else { + goto handle_unusual; + } + break; + } + + // int32 ps_replicas = 2; + case 2: { + if (static_cast< ::google::protobuf::uint8>(tag) == (16 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &ps_replicas_))); + } else { + goto handle_unusual; + } + break; + } + + // int32 chief_replicas = 3; + case 3: { + if (static_cast< ::google::protobuf::uint8>(tag) == (24 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &chief_replicas_))); + } else { + goto handle_unusual; + } + break; + } + + default: { + handle_unusual: + if (tag == 0) { + goto success; + } + DO_(::google::protobuf::internal::WireFormat::SkipField( + input, tag, _internal_metadata_.mutable_unknown_fields())); + break; + } + } + } +success: + // @@protoc_insertion_point(parse_success:flyteidl.plugins.DistributedTensorflowTrainingTask) + return true; +failure: + // @@protoc_insertion_point(parse_failure:flyteidl.plugins.DistributedTensorflowTrainingTask) + return false; +#undef DO_ +} +#endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + +void DistributedTensorflowTrainingTask::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // @@protoc_insertion_point(serialize_start:flyteidl.plugins.DistributedTensorflowTrainingTask) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // int32 workers = 1; + if (this->workers() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->workers(), output); + } + + // int32 ps_replicas = 2; + if (this->ps_replicas() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(2, this->ps_replicas(), output); + } + + // int32 chief_replicas = 3; + if (this->chief_replicas() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->chief_replicas(), output); + } + + if (_internal_metadata_.have_unknown_fields()) { + ::google::protobuf::internal::WireFormat::SerializeUnknownFields( + _internal_metadata_.unknown_fields(), output); + } + // @@protoc_insertion_point(serialize_end:flyteidl.plugins.DistributedTensorflowTrainingTask) +} + +::google::protobuf::uint8* DistributedTensorflowTrainingTask::InternalSerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // @@protoc_insertion_point(serialize_to_array_start:flyteidl.plugins.DistributedTensorflowTrainingTask) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // int32 workers = 1; + if (this->workers() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->workers(), target); + } + + // int32 ps_replicas = 2; + if (this->ps_replicas() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(2, this->ps_replicas(), target); + } + + // int32 chief_replicas = 3; + if (this->chief_replicas() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->chief_replicas(), target); + } + + if (_internal_metadata_.have_unknown_fields()) { + target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields(), target); + } + // @@protoc_insertion_point(serialize_to_array_end:flyteidl.plugins.DistributedTensorflowTrainingTask) + return target; +} + +size_t DistributedTensorflowTrainingTask::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:flyteidl.plugins.DistributedTensorflowTrainingTask) + size_t total_size = 0; + + if (_internal_metadata_.have_unknown_fields()) { + total_size += + ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize( + _internal_metadata_.unknown_fields()); + } + ::google::protobuf::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // int32 workers = 1; + if (this->workers() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->workers()); + } + + // int32 ps_replicas = 2; + if (this->ps_replicas() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->ps_replicas()); + } + + // int32 chief_replicas = 3; + if (this->chief_replicas() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->chief_replicas()); + } + + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void DistributedTensorflowTrainingTask::MergeFrom(const ::google::protobuf::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:flyteidl.plugins.DistributedTensorflowTrainingTask) + GOOGLE_DCHECK_NE(&from, this); + const DistributedTensorflowTrainingTask* source = + ::google::protobuf::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:flyteidl.plugins.DistributedTensorflowTrainingTask) + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:flyteidl.plugins.DistributedTensorflowTrainingTask) + MergeFrom(*source); + } +} + +void DistributedTensorflowTrainingTask::MergeFrom(const DistributedTensorflowTrainingTask& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:flyteidl.plugins.DistributedTensorflowTrainingTask) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom(from._internal_metadata_); + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.workers() != 0) { + set_workers(from.workers()); + } + if (from.ps_replicas() != 0) { + set_ps_replicas(from.ps_replicas()); + } + if (from.chief_replicas() != 0) { + set_chief_replicas(from.chief_replicas()); + } +} + +void DistributedTensorflowTrainingTask::CopyFrom(const ::google::protobuf::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:flyteidl.plugins.DistributedTensorflowTrainingTask) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void DistributedTensorflowTrainingTask::CopyFrom(const DistributedTensorflowTrainingTask& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:flyteidl.plugins.DistributedTensorflowTrainingTask) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool DistributedTensorflowTrainingTask::IsInitialized() const { + return true; +} + +void DistributedTensorflowTrainingTask::Swap(DistributedTensorflowTrainingTask* other) { + if (other == this) return; + InternalSwap(other); +} +void DistributedTensorflowTrainingTask::InternalSwap(DistributedTensorflowTrainingTask* other) { + using std::swap; + _internal_metadata_.Swap(&other->_internal_metadata_); + swap(workers_, other->workers_); + swap(ps_replicas_, other->ps_replicas_); + swap(chief_replicas_, other->chief_replicas_); +} + +::google::protobuf::Metadata DistributedTensorflowTrainingTask::GetMetadata() const { + ::google::protobuf::internal::AssignDescriptors(&::assign_descriptors_table_flyteidl_2fplugins_2ftensorflow_2eproto); + return ::file_level_metadata_flyteidl_2fplugins_2ftensorflow_2eproto[kIndexInFileMessages]; +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace plugins +} // namespace flyteidl +namespace google { +namespace protobuf { +template<> PROTOBUF_NOINLINE ::flyteidl::plugins::DistributedTensorflowTrainingTask* Arena::CreateMaybeMessage< ::flyteidl::plugins::DistributedTensorflowTrainingTask >(Arena* arena) { + return Arena::CreateInternal< ::flyteidl::plugins::DistributedTensorflowTrainingTask >(arena); +} +} // namespace protobuf +} // namespace google + +// @@protoc_insertion_point(global_scope) +#include diff --git a/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h b/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h new file mode 100644 index 000000000..613ed31d8 --- /dev/null +++ b/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h @@ -0,0 +1,257 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: flyteidl/plugins/tensorflow.proto + +#ifndef PROTOBUF_INCLUDED_flyteidl_2fplugins_2ftensorflow_2eproto +#define PROTOBUF_INCLUDED_flyteidl_2fplugins_2ftensorflow_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3007000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3007000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_flyteidl_2fplugins_2ftensorflow_2eproto + +// Internal implementation detail -- do not use these members. +struct TableStruct_flyteidl_2fplugins_2ftensorflow_2eproto { + static const ::google::protobuf::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::google::protobuf::internal::AuxillaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::google::protobuf::internal::ParseTable schema[1] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::google::protobuf::internal::FieldMetadata field_metadata[]; + static const ::google::protobuf::internal::SerializationTable serialization_table[]; + static const ::google::protobuf::uint32 offsets[]; +}; +void AddDescriptors_flyteidl_2fplugins_2ftensorflow_2eproto(); +namespace flyteidl { +namespace plugins { +class DistributedTensorflowTrainingTask; +class DistributedTensorflowTrainingTaskDefaultTypeInternal; +extern DistributedTensorflowTrainingTaskDefaultTypeInternal _DistributedTensorflowTrainingTask_default_instance_; +} // namespace plugins +} // namespace flyteidl +namespace google { +namespace protobuf { +template<> ::flyteidl::plugins::DistributedTensorflowTrainingTask* Arena::CreateMaybeMessage<::flyteidl::plugins::DistributedTensorflowTrainingTask>(Arena*); +} // namespace protobuf +} // namespace google +namespace flyteidl { +namespace plugins { + +// =================================================================== + +class DistributedTensorflowTrainingTask final : + public ::google::protobuf::Message /* @@protoc_insertion_point(class_definition:flyteidl.plugins.DistributedTensorflowTrainingTask) */ { + public: + DistributedTensorflowTrainingTask(); + virtual ~DistributedTensorflowTrainingTask(); + + DistributedTensorflowTrainingTask(const DistributedTensorflowTrainingTask& from); + + inline DistributedTensorflowTrainingTask& operator=(const DistributedTensorflowTrainingTask& from) { + CopyFrom(from); + return *this; + } + #if LANG_CXX11 + DistributedTensorflowTrainingTask(DistributedTensorflowTrainingTask&& from) noexcept + : DistributedTensorflowTrainingTask() { + *this = ::std::move(from); + } + + inline DistributedTensorflowTrainingTask& operator=(DistributedTensorflowTrainingTask&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + #endif + static const ::google::protobuf::Descriptor* descriptor() { + return default_instance().GetDescriptor(); + } + static const DistributedTensorflowTrainingTask& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const DistributedTensorflowTrainingTask* internal_default_instance() { + return reinterpret_cast( + &_DistributedTensorflowTrainingTask_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + void Swap(DistributedTensorflowTrainingTask* other); + friend void swap(DistributedTensorflowTrainingTask& a, DistributedTensorflowTrainingTask& b) { + a.Swap(&b); + } + + // implements Message ---------------------------------------------- + + inline DistributedTensorflowTrainingTask* New() const final { + return CreateMaybeMessage(nullptr); + } + + DistributedTensorflowTrainingTask* New(::google::protobuf::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::google::protobuf::Message& from) final; + void MergeFrom(const ::google::protobuf::Message& from) final; + void CopyFrom(const DistributedTensorflowTrainingTask& from); + void MergeFrom(const DistributedTensorflowTrainingTask& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + #if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + static const char* _InternalParse(const char* begin, const char* end, void* object, ::google::protobuf::internal::ParseContext* ctx); + ::google::protobuf::internal::ParseFunc _ParseFunc() const final { return _InternalParse; } + #else + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) final; + #endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const final; + ::google::protobuf::uint8* InternalSerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(DistributedTensorflowTrainingTask* other); + private: + inline ::google::protobuf::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::google::protobuf::Metadata GetMetadata() const final; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // int32 workers = 1; + void clear_workers(); + static const int kWorkersFieldNumber = 1; + ::google::protobuf::int32 workers() const; + void set_workers(::google::protobuf::int32 value); + + // int32 ps_replicas = 2; + void clear_ps_replicas(); + static const int kPsReplicasFieldNumber = 2; + ::google::protobuf::int32 ps_replicas() const; + void set_ps_replicas(::google::protobuf::int32 value); + + // int32 chief_replicas = 3; + void clear_chief_replicas(); + static const int kChiefReplicasFieldNumber = 3; + ::google::protobuf::int32 chief_replicas() const; + void set_chief_replicas(::google::protobuf::int32 value); + + // @@protoc_insertion_point(class_scope:flyteidl.plugins.DistributedTensorflowTrainingTask) + private: + class HasBitSetters; + + ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_; + ::google::protobuf::int32 workers_; + ::google::protobuf::int32 ps_replicas_; + ::google::protobuf::int32 chief_replicas_; + mutable ::google::protobuf::internal::CachedSize _cached_size_; + friend struct ::TableStruct_flyteidl_2fplugins_2ftensorflow_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// DistributedTensorflowTrainingTask + +// int32 workers = 1; +inline void DistributedTensorflowTrainingTask::clear_workers() { + workers_ = 0; +} +inline ::google::protobuf::int32 DistributedTensorflowTrainingTask::workers() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedTensorflowTrainingTask.workers) + return workers_; +} +inline void DistributedTensorflowTrainingTask::set_workers(::google::protobuf::int32 value) { + + workers_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedTensorflowTrainingTask.workers) +} + +// int32 ps_replicas = 2; +inline void DistributedTensorflowTrainingTask::clear_ps_replicas() { + ps_replicas_ = 0; +} +inline ::google::protobuf::int32 DistributedTensorflowTrainingTask::ps_replicas() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedTensorflowTrainingTask.ps_replicas) + return ps_replicas_; +} +inline void DistributedTensorflowTrainingTask::set_ps_replicas(::google::protobuf::int32 value) { + + ps_replicas_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedTensorflowTrainingTask.ps_replicas) +} + +// int32 chief_replicas = 3; +inline void DistributedTensorflowTrainingTask::clear_chief_replicas() { + chief_replicas_ = 0; +} +inline ::google::protobuf::int32 DistributedTensorflowTrainingTask::chief_replicas() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedTensorflowTrainingTask.chief_replicas) + return chief_replicas_; +} +inline void DistributedTensorflowTrainingTask::set_chief_replicas(::google::protobuf::int32 value) { + + chief_replicas_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedTensorflowTrainingTask.chief_replicas) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ + +// @@protoc_insertion_point(namespace_scope) + +} // namespace plugins +} // namespace flyteidl + +// @@protoc_insertion_point(global_scope) + +#include +#endif // PROTOBUF_INCLUDED_flyteidl_2fplugins_2ftensorflow_2eproto diff --git a/gen/pb-go/flyteidl/plugins/tensorflow.pb.go b/gen/pb-go/flyteidl/plugins/tensorflow.pb.go new file mode 100644 index 000000000..a7fa17ca9 --- /dev/null +++ b/gen/pb-go/flyteidl/plugins/tensorflow.pb.go @@ -0,0 +1,102 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: flyteidl/plugins/tensorflow.proto + +package plugins + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +// Custom proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator +type DistributedTensorflowTrainingTask struct { + // number of worker, ps, chief replicas spawned in the cluster for this job + Workers int32 `protobuf:"varint,1,opt,name=workers,proto3" json:"workers,omitempty"` + // PS -> Parameter server + PsReplicas int32 `protobuf:"varint,2,opt,name=ps_replicas,json=psReplicas,proto3" json:"ps_replicas,omitempty"` + ChiefReplicas int32 `protobuf:"varint,3,opt,name=chief_replicas,json=chiefReplicas,proto3" json:"chief_replicas,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *DistributedTensorflowTrainingTask) Reset() { *m = DistributedTensorflowTrainingTask{} } +func (m *DistributedTensorflowTrainingTask) String() string { return proto.CompactTextString(m) } +func (*DistributedTensorflowTrainingTask) ProtoMessage() {} +func (*DistributedTensorflowTrainingTask) Descriptor() ([]byte, []int) { + return fileDescriptor_8da02783614e1bcc, []int{0} +} + +func (m *DistributedTensorflowTrainingTask) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_DistributedTensorflowTrainingTask.Unmarshal(m, b) +} +func (m *DistributedTensorflowTrainingTask) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_DistributedTensorflowTrainingTask.Marshal(b, m, deterministic) +} +func (m *DistributedTensorflowTrainingTask) XXX_Merge(src proto.Message) { + xxx_messageInfo_DistributedTensorflowTrainingTask.Merge(m, src) +} +func (m *DistributedTensorflowTrainingTask) XXX_Size() int { + return xxx_messageInfo_DistributedTensorflowTrainingTask.Size(m) +} +func (m *DistributedTensorflowTrainingTask) XXX_DiscardUnknown() { + xxx_messageInfo_DistributedTensorflowTrainingTask.DiscardUnknown(m) +} + +var xxx_messageInfo_DistributedTensorflowTrainingTask proto.InternalMessageInfo + +func (m *DistributedTensorflowTrainingTask) GetWorkers() int32 { + if m != nil { + return m.Workers + } + return 0 +} + +func (m *DistributedTensorflowTrainingTask) GetPsReplicas() int32 { + if m != nil { + return m.PsReplicas + } + return 0 +} + +func (m *DistributedTensorflowTrainingTask) GetChiefReplicas() int32 { + if m != nil { + return m.ChiefReplicas + } + return 0 +} + +func init() { + proto.RegisterType((*DistributedTensorflowTrainingTask)(nil), "flyteidl.plugins.DistributedTensorflowTrainingTask") +} + +func init() { proto.RegisterFile("flyteidl/plugins/tensorflow.proto", fileDescriptor_8da02783614e1bcc) } + +var fileDescriptor_8da02783614e1bcc = []byte{ + // 200 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0xcf, 0x41, 0x8a, 0x83, 0x30, + 0x14, 0xc6, 0x71, 0x9c, 0x61, 0x66, 0x20, 0xc3, 0x0c, 0x43, 0x56, 0xee, 0xa6, 0x16, 0x0a, 0xdd, + 0xd4, 0x2c, 0xa4, 0x17, 0x28, 0x3d, 0x81, 0xb8, 0xea, 0xa6, 0x18, 0x4d, 0xe2, 0xc3, 0x34, 0x09, + 0x79, 0x11, 0xf1, 0x00, 0xbd, 0x77, 0x21, 0xa8, 0x85, 0x2e, 0xf3, 0xe5, 0xf7, 0x16, 0x7f, 0x92, + 0x49, 0x3d, 0x05, 0x01, 0xad, 0x66, 0x4e, 0x0f, 0x0a, 0x0c, 0xb2, 0x20, 0x0c, 0x5a, 0x2f, 0xb5, + 0x1d, 0x73, 0xe7, 0x6d, 0xb0, 0xf4, 0x6f, 0x21, 0xf9, 0x4c, 0xb6, 0xf7, 0x84, 0x64, 0x67, 0xc0, + 0xe0, 0x81, 0x0f, 0x41, 0xb4, 0xd5, 0x7a, 0x51, 0xf9, 0x1a, 0x0c, 0x18, 0x55, 0xd5, 0xd8, 0xd3, + 0x94, 0x7c, 0x8d, 0xd6, 0xf7, 0xc2, 0x63, 0x9a, 0x6c, 0x92, 0xfd, 0x47, 0xb9, 0x3c, 0xe9, 0x3f, + 0xf9, 0x76, 0x78, 0xf5, 0xc2, 0x69, 0x68, 0x6a, 0x4c, 0xdf, 0xe2, 0x2f, 0x71, 0x58, 0xce, 0x0b, + 0xdd, 0x91, 0xdf, 0xa6, 0x03, 0x21, 0x9f, 0xe6, 0x3d, 0x9a, 0x9f, 0xb8, 0x2e, 0xec, 0x74, 0xbc, + 0x14, 0x0a, 0x42, 0x37, 0xf0, 0xbc, 0xb1, 0x37, 0xa6, 0x27, 0x19, 0xd8, 0x9a, 0xa3, 0x84, 0x61, + 0x8e, 0x1f, 0x94, 0x65, 0xaf, 0x85, 0xfc, 0x33, 0x76, 0x15, 0x8f, 0x00, 0x00, 0x00, 0xff, 0xff, + 0x8f, 0x56, 0x52, 0x81, 0xfc, 0x00, 0x00, 0x00, +} diff --git a/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go b/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go new file mode 100644 index 000000000..ed7a8eeb8 --- /dev/null +++ b/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go @@ -0,0 +1,111 @@ +// Code generated by protoc-gen-validate. DO NOT EDIT. +// source: flyteidl/plugins/tensorflow.proto + +package plugins + +import ( + "bytes" + "errors" + "fmt" + "net" + "net/mail" + "net/url" + "regexp" + "strings" + "time" + "unicode/utf8" + + "github.com/golang/protobuf/ptypes" +) + +// ensure the imports are used +var ( + _ = bytes.MinRead + _ = errors.New("") + _ = fmt.Print + _ = utf8.UTFMax + _ = (*regexp.Regexp)(nil) + _ = (*strings.Reader)(nil) + _ = net.IPv4len + _ = time.Duration(0) + _ = (*url.URL)(nil) + _ = (*mail.Address)(nil) + _ = ptypes.DynamicAny{} +) + +// define the regex for a UUID once up-front +var _tensorflow_uuidPattern = regexp.MustCompile("^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$") + +// Validate checks the field values on DistributedTensorflowTrainingTask with +// the rules defined in the proto definition for this message. If any rules +// are violated, an error is returned. +func (m *DistributedTensorflowTrainingTask) Validate() error { + if m == nil { + return nil + } + + // no validation rules for Workers + + // no validation rules for PsReplicas + + // no validation rules for ChiefReplicas + + return nil +} + +// DistributedTensorflowTrainingTaskValidationError is the validation error +// returned by DistributedTensorflowTrainingTask.Validate if the designated +// constraints aren't met. +type DistributedTensorflowTrainingTaskValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e DistributedTensorflowTrainingTaskValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e DistributedTensorflowTrainingTaskValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e DistributedTensorflowTrainingTaskValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e DistributedTensorflowTrainingTaskValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e DistributedTensorflowTrainingTaskValidationError) ErrorName() string { + return "DistributedTensorflowTrainingTaskValidationError" +} + +// Error satisfies the builtin error interface +func (e DistributedTensorflowTrainingTaskValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sDistributedTensorflowTrainingTask.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = DistributedTensorflowTrainingTaskValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = DistributedTensorflowTrainingTaskValidationError{} diff --git a/gen/pb-java/flyteidl/plugins/Tensorflow.java b/gen/pb-java/flyteidl/plugins/Tensorflow.java new file mode 100644 index 000000000..e8d80be10 --- /dev/null +++ b/gen/pb-java/flyteidl/plugins/Tensorflow.java @@ -0,0 +1,705 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: flyteidl/plugins/tensorflow.proto + +package flyteidl.plugins; + +public final class Tensorflow { + private Tensorflow() {} + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistryLite registry) { + } + + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistry registry) { + registerAllExtensions( + (com.google.protobuf.ExtensionRegistryLite) registry); + } + public interface DistributedTensorflowTrainingTaskOrBuilder extends + // @@protoc_insertion_point(interface_extends:flyteidl.plugins.DistributedTensorflowTrainingTask) + com.google.protobuf.MessageOrBuilder { + + /** + *
+     * number of worker, ps, chief replicas spawned in the cluster for this job
+     * 
+ * + * int32 workers = 1; + */ + int getWorkers(); + + /** + *
+     * PS -> Parameter server
+     * 
+ * + * int32 ps_replicas = 2; + */ + int getPsReplicas(); + + /** + * int32 chief_replicas = 3; + */ + int getChiefReplicas(); + } + /** + *
+   * Custom proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator
+   * 
+ * + * Protobuf type {@code flyteidl.plugins.DistributedTensorflowTrainingTask} + */ + public static final class DistributedTensorflowTrainingTask extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:flyteidl.plugins.DistributedTensorflowTrainingTask) + DistributedTensorflowTrainingTaskOrBuilder { + private static final long serialVersionUID = 0L; + // Use DistributedTensorflowTrainingTask.newBuilder() to construct. + private DistributedTensorflowTrainingTask(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private DistributedTensorflowTrainingTask() { + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private DistributedTensorflowTrainingTask( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 8: { + + workers_ = input.readInt32(); + break; + } + case 16: { + + psReplicas_ = input.readInt32(); + break; + } + case 24: { + + chiefReplicas_ = input.readInt32(); + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return flyteidl.plugins.Tensorflow.internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return flyteidl.plugins.Tensorflow.internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_fieldAccessorTable + .ensureFieldAccessorsInitialized( + flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask.class, flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask.Builder.class); + } + + public static final int WORKERS_FIELD_NUMBER = 1; + private int workers_; + /** + *
+     * number of worker, ps, chief replicas spawned in the cluster for this job
+     * 
+ * + * int32 workers = 1; + */ + public int getWorkers() { + return workers_; + } + + public static final int PS_REPLICAS_FIELD_NUMBER = 2; + private int psReplicas_; + /** + *
+     * PS -> Parameter server
+     * 
+ * + * int32 ps_replicas = 2; + */ + public int getPsReplicas() { + return psReplicas_; + } + + public static final int CHIEF_REPLICAS_FIELD_NUMBER = 3; + private int chiefReplicas_; + /** + * int32 chief_replicas = 3; + */ + public int getChiefReplicas() { + return chiefReplicas_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (workers_ != 0) { + output.writeInt32(1, workers_); + } + if (psReplicas_ != 0) { + output.writeInt32(2, psReplicas_); + } + if (chiefReplicas_ != 0) { + output.writeInt32(3, chiefReplicas_); + } + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (workers_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(1, workers_); + } + if (psReplicas_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(2, psReplicas_); + } + if (chiefReplicas_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(3, chiefReplicas_); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask)) { + return super.equals(obj); + } + flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask other = (flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask) obj; + + if (getWorkers() + != other.getWorkers()) return false; + if (getPsReplicas() + != other.getPsReplicas()) return false; + if (getChiefReplicas() + != other.getChiefReplicas()) return false; + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + WORKERS_FIELD_NUMBER; + hash = (53 * hash) + getWorkers(); + hash = (37 * hash) + PS_REPLICAS_FIELD_NUMBER; + hash = (53 * hash) + getPsReplicas(); + hash = (37 * hash) + CHIEF_REPLICAS_FIELD_NUMBER; + hash = (53 * hash) + getChiefReplicas(); + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + *
+     * Custom proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator
+     * 
+ * + * Protobuf type {@code flyteidl.plugins.DistributedTensorflowTrainingTask} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:flyteidl.plugins.DistributedTensorflowTrainingTask) + flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTaskOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return flyteidl.plugins.Tensorflow.internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return flyteidl.plugins.Tensorflow.internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_fieldAccessorTable + .ensureFieldAccessorsInitialized( + flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask.class, flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask.Builder.class); + } + + // Construct using flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + workers_ = 0; + + psReplicas_ = 0; + + chiefReplicas_ = 0; + + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return flyteidl.plugins.Tensorflow.internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_descriptor; + } + + @java.lang.Override + public flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask getDefaultInstanceForType() { + return flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask.getDefaultInstance(); + } + + @java.lang.Override + public flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask build() { + flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask buildPartial() { + flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask result = new flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask(this); + result.workers_ = workers_; + result.psReplicas_ = psReplicas_; + result.chiefReplicas_ = chiefReplicas_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask) { + return mergeFrom((flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask other) { + if (other == flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask.getDefaultInstance()) return this; + if (other.getWorkers() != 0) { + setWorkers(other.getWorkers()); + } + if (other.getPsReplicas() != 0) { + setPsReplicas(other.getPsReplicas()); + } + if (other.getChiefReplicas() != 0) { + setChiefReplicas(other.getChiefReplicas()); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private int workers_ ; + /** + *
+       * number of worker, ps, chief replicas spawned in the cluster for this job
+       * 
+ * + * int32 workers = 1; + */ + public int getWorkers() { + return workers_; + } + /** + *
+       * number of worker, ps, chief replicas spawned in the cluster for this job
+       * 
+ * + * int32 workers = 1; + */ + public Builder setWorkers(int value) { + + workers_ = value; + onChanged(); + return this; + } + /** + *
+       * number of worker, ps, chief replicas spawned in the cluster for this job
+       * 
+ * + * int32 workers = 1; + */ + public Builder clearWorkers() { + + workers_ = 0; + onChanged(); + return this; + } + + private int psReplicas_ ; + /** + *
+       * PS -> Parameter server
+       * 
+ * + * int32 ps_replicas = 2; + */ + public int getPsReplicas() { + return psReplicas_; + } + /** + *
+       * PS -> Parameter server
+       * 
+ * + * int32 ps_replicas = 2; + */ + public Builder setPsReplicas(int value) { + + psReplicas_ = value; + onChanged(); + return this; + } + /** + *
+       * PS -> Parameter server
+       * 
+ * + * int32 ps_replicas = 2; + */ + public Builder clearPsReplicas() { + + psReplicas_ = 0; + onChanged(); + return this; + } + + private int chiefReplicas_ ; + /** + * int32 chief_replicas = 3; + */ + public int getChiefReplicas() { + return chiefReplicas_; + } + /** + * int32 chief_replicas = 3; + */ + public Builder setChiefReplicas(int value) { + + chiefReplicas_ = value; + onChanged(); + return this; + } + /** + * int32 chief_replicas = 3; + */ + public Builder clearChiefReplicas() { + + chiefReplicas_ = 0; + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:flyteidl.plugins.DistributedTensorflowTrainingTask) + } + + // @@protoc_insertion_point(class_scope:flyteidl.plugins.DistributedTensorflowTrainingTask) + private static final flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask(); + } + + public static flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public DistributedTensorflowTrainingTask parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new DistributedTensorflowTrainingTask(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_fieldAccessorTable; + + public static com.google.protobuf.Descriptors.FileDescriptor + getDescriptor() { + return descriptor; + } + private static com.google.protobuf.Descriptors.FileDescriptor + descriptor; + static { + java.lang.String[] descriptorData = { + "\n!flyteidl/plugins/tensorflow.proto\022\020fly" + + "teidl.plugins\"a\n!DistributedTensorflowTr" + + "ainingTask\022\017\n\007workers\030\001 \001(\005\022\023\n\013ps_replic" + + "as\030\002 \001(\005\022\026\n\016chief_replicas\030\003 \001(\005B5Z3gith" + + "ub.com/lyft/flyteidl/gen/pb-go/flyteidl/" + + "pluginsb\006proto3" + }; + com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = + new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { + public com.google.protobuf.ExtensionRegistry assignDescriptors( + com.google.protobuf.Descriptors.FileDescriptor root) { + descriptor = root; + return null; + } + }; + com.google.protobuf.Descriptors.FileDescriptor + .internalBuildGeneratedFileFrom(descriptorData, + new com.google.protobuf.Descriptors.FileDescriptor[] { + }, assigner); + internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_descriptor = + getDescriptor().getMessageTypes().get(0); + internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_descriptor, + new java.lang.String[] { "Workers", "PsReplicas", "ChiefReplicas", }); + } + + // @@protoc_insertion_point(outer_class_scope) +} diff --git a/gen/pb-protodoc/flyteidl/plugins/index.rst b/gen/pb-protodoc/flyteidl/plugins/index.rst index 43ca14335..a44c59d49 100644 --- a/gen/pb-protodoc/flyteidl/plugins/index.rst +++ b/gen/pb-protodoc/flyteidl/plugins/index.rst @@ -16,4 +16,5 @@ Plugins available in the Flyte system. qubole.proto sidecar.proto spark.proto + tensorflow.proto waitable.proto diff --git a/gen/pb-protodoc/flyteidl/plugins/tensorflow.proto.rst b/gen/pb-protodoc/flyteidl/plugins/tensorflow.proto.rst new file mode 100644 index 000000000..33b9fff66 --- /dev/null +++ b/gen/pb-protodoc/flyteidl/plugins/tensorflow.proto.rst @@ -0,0 +1,40 @@ +.. _api_file_flyteidl/plugins/tensorflow.proto: + +tensorflow.proto +================================= + +.. _api_msg_flyteidl.plugins.DistributedTensorflowTrainingTask: + +flyteidl.plugins.DistributedTensorflowTrainingTask +-------------------------------------------------- + +`[flyteidl.plugins.DistributedTensorflowTrainingTask proto] `_ + +Custom proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator + +.. code-block:: json + + { + "workers": "...", + "ps_replicas": "...", + "chief_replicas": "..." + } + +.. _api_field_flyteidl.plugins.DistributedTensorflowTrainingTask.workers: + +workers + (`int32 `_) number of worker, ps, chief replicas spawned in the cluster for this job + + +.. _api_field_flyteidl.plugins.DistributedTensorflowTrainingTask.ps_replicas: + +ps_replicas + (`int32 `_) PS -> Parameter server + + +.. _api_field_flyteidl.plugins.DistributedTensorflowTrainingTask.chief_replicas: + +chief_replicas + (`int32 `_) + + diff --git a/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py b/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py new file mode 100644 index 000000000..aae0170ad --- /dev/null +++ b/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flyteidl/plugins/tensorflow.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='flyteidl/plugins/tensorflow.proto', + package='flyteidl.plugins', + syntax='proto3', + serialized_options=_b('Z3github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins'), + serialized_pb=_b('\n!flyteidl/plugins/tensorflow.proto\x12\x10\x66lyteidl.plugins\"a\n!DistributedTensorflowTrainingTask\x12\x0f\n\x07workers\x18\x01 \x01(\x05\x12\x13\n\x0bps_replicas\x18\x02 \x01(\x05\x12\x16\n\x0e\x63hief_replicas\x18\x03 \x01(\x05\x42\x35Z3github.com/lyft/flyteidl/gen/pb-go/flyteidl/pluginsb\x06proto3') +) + + + + +_DISTRIBUTEDTENSORFLOWTRAININGTASK = _descriptor.Descriptor( + name='DistributedTensorflowTrainingTask', + full_name='flyteidl.plugins.DistributedTensorflowTrainingTask', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='workers', full_name='flyteidl.plugins.DistributedTensorflowTrainingTask.workers', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='ps_replicas', full_name='flyteidl.plugins.DistributedTensorflowTrainingTask.ps_replicas', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='chief_replicas', full_name='flyteidl.plugins.DistributedTensorflowTrainingTask.chief_replicas', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=55, + serialized_end=152, +) + +DESCRIPTOR.message_types_by_name['DistributedTensorflowTrainingTask'] = _DISTRIBUTEDTENSORFLOWTRAININGTASK +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +DistributedTensorflowTrainingTask = _reflection.GeneratedProtocolMessageType('DistributedTensorflowTrainingTask', (_message.Message,), dict( + DESCRIPTOR = _DISTRIBUTEDTENSORFLOWTRAININGTASK, + __module__ = 'flyteidl.plugins.tensorflow_pb2' + # @@protoc_insertion_point(class_scope:flyteidl.plugins.DistributedTensorflowTrainingTask) + )) +_sym_db.RegisterMessage(DistributedTensorflowTrainingTask) + + +DESCRIPTOR._options = None +# @@protoc_insertion_point(module_scope) diff --git a/gen/pb_python/flyteidl/plugins/tensorflow_pb2_grpc.py b/gen/pb_python/flyteidl/plugins/tensorflow_pb2_grpc.py new file mode 100644 index 000000000..a89435267 --- /dev/null +++ b/gen/pb_python/flyteidl/plugins/tensorflow_pb2_grpc.py @@ -0,0 +1,3 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + diff --git a/package.json b/package.json index 75718171c..708e25737 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@lyft/flyteidl", - "version": "0.17.34", + "version": "0.17.35", "description": "Compiled protocol buffers and gRPC service clients/servers for Flyte IDLs", "repository": { "type": "git", diff --git a/protos/flyteidl/plugins/tensorflow.proto b/protos/flyteidl/plugins/tensorflow.proto new file mode 100644 index 000000000..992eb045a --- /dev/null +++ b/protos/flyteidl/plugins/tensorflow.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package flyteidl.plugins; + +option go_package = "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins"; + +// Custom proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator +message DistributedTensorflowTrainingTask { + // number of worker, ps, chief replicas spawned in the cluster for this job + int32 workers = 1; + // PS -> Parameter server + int32 ps_replicas = 2; + int32 chief_replicas = 3; +} diff --git a/setup.py b/setup.py index 252b18ce4..f021ceb14 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -__version__ = '0.17.34' +__version__ = '0.17.35' setup( name='flyteidl',