Skip to content

Commit

Permalink
server: Refactor wal version to use visitor pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
serathius committed Nov 29, 2021
1 parent bb90f8f commit 2777fd3
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 122 deletions.
168 changes: 125 additions & 43 deletions server/storage/wal/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,63 @@ func (w *walVersion) MinimalEtcdVersion() *semver.Version {
func MinimalEtcdVersion(ents []raftpb.Entry) *semver.Version {
var maxVer *semver.Version
for _, ent := range ents {
maxVer = maxVersion(maxVer, etcdVersionFromEntry(ent))
err := visitEntry(ent, func(path protoreflect.FullName, ver *semver.Version) error {
maxVer = maxVersion(maxVer, ver)
return nil
})
if err != nil {
panic(err)
}
}
return maxVer
}

func etcdVersionFromEntry(ent raftpb.Entry) *semver.Version {
msgVer := etcdVersionFromMessage(proto.MessageReflect(&ent))
dataVer := etcdVersionFromData(ent.Type, ent.Data)
return maxVersion(msgVer, dataVer)
type Visitor func(path protoreflect.FullName, ver *semver.Version) error

func VisitFileDescriptor(file protoreflect.FileDescriptor, visitor Visitor) error {
msgs := file.Messages()
for i := 0; i < msgs.Len(); i++ {
err := visitMessageDescriptor(msgs.Get(i), visitor)
if err != nil {
return err
}
}
enums := file.Enums()
for i := 0; i < enums.Len(); i++ {
err := visitEnumDescriptor(enums.Get(i), visitor)
if err != nil {
return err
}
}
return nil
}

func visitEntry(ent raftpb.Entry, visitor Visitor) error {
err := visitMessage(proto.MessageReflect(&ent), visitor)
if err != nil {
return err
}
return visitEntryData(ent.Type, ent.Data, visitor)
}

func etcdVersionFromData(entryType raftpb.EntryType, data []byte) *semver.Version {
func visitEntryData(entryType raftpb.EntryType, data []byte, visitor Visitor) error {
var msg protoreflect.Message
var ver *semver.Version
switch entryType {
case raftpb.EntryNormal:
var raftReq etcdserverpb.InternalRaftRequest
err := pbutil.Unmarshaler(&raftReq).Unmarshal(data)
if err != nil {
return nil
return err
}
msg = proto.MessageReflect(&raftReq)
if raftReq.ClusterVersionSet != nil {
ver, err = semver.NewVersion(raftReq.ClusterVersionSet.Ver)
ver, err := semver.NewVersion(raftReq.ClusterVersionSet.Ver)
if err != nil {
panic(err)
return err
}
err = visitor(msg.Descriptor().FullName(), ver)
if err != nil {
return err
}
}
case raftpb.EntryConfChange:
Expand All @@ -98,46 +129,106 @@ func etcdVersionFromData(entryType raftpb.EntryType, data []byte) *semver.Versio
default:
panic("unhandled")
}
return maxVersion(etcdVersionFromMessage(msg), ver)
return visitMessage(msg, visitor)
}

func etcdVersionFromMessage(m protoreflect.Message) *semver.Version {
var maxVer *semver.Version
md := m.Descriptor()
opts := md.Options().(*descriptorpb.MessageOptions)
if opts != nil {
ver, _ := EtcdVersionFromOptionsString(opts.String())
maxVer = maxVersion(maxVer, ver)
func visitMessageDescriptor(md protoreflect.MessageDescriptor, visitor Visitor) error {
err := visitDescriptor(md, visitor)
if err != nil {
return err
}
fields := md.Fields()
for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
err = visitDescriptor(fd, visitor)
if err != nil {
return err
}
}

enums := md.Enums()
for i := 0; i < enums.Len(); i++ {
err := visitEnumDescriptor(enums.Get(i), visitor)
if err != nil {
return err
}
}
return err
}

func visitMessage(m protoreflect.Message, visitor Visitor) error {
md := m.Descriptor()
err := visitDescriptor(md, visitor)
if err != nil {
return err
}
m.Range(func(field protoreflect.FieldDescriptor, value protoreflect.Value) bool {
fd := md.Fields().Get(field.Index())
maxVer = maxVersion(maxVer, etcdVersionFromField(fd))
err = visitDescriptor(fd, visitor)
if err != nil {
return false
}

switch m := value.Interface().(type) {
case protoreflect.Message:
maxVer = maxVersion(maxVer, etcdVersionFromMessage(m))
err = visitMessage(m, visitor)
case protoreflect.EnumNumber:
maxVer = maxVersion(maxVer, etcdVersionFromEnum(field.Enum(), m))
err = visitEnumNumber(fd.Enum(), m, visitor)
}
if err != nil {
return false
}
return true
})
return maxVer
return err
}

func etcdVersionFromEnum(enum protoreflect.EnumDescriptor, value protoreflect.EnumNumber) *semver.Version {
var maxVer *semver.Version
enumOpts := enum.Options().(*descriptorpb.EnumOptions)
if enumOpts != nil {
ver, _ := EtcdVersionFromOptionsString(enumOpts.String())
maxVer = maxVersion(maxVer, ver)
func visitEnumDescriptor(enum protoreflect.EnumDescriptor, visitor Visitor) error {
err := visitDescriptor(enum, visitor)
if err != nil {
return err
}
fields := enum.Values()
for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
err = visitDescriptor(fd, visitor)
if err != nil {
return err
}
}
return err
}

func visitEnumNumber(enum protoreflect.EnumDescriptor, number protoreflect.EnumNumber, visitor Visitor) error {
err := visitDescriptor(enum, visitor)
if err != nil {
return err
}
valueDesc := enum.Values().Get(int(value))
valueOpts := valueDesc.Options().(*descriptorpb.EnumValueOptions)
return visitEnumValue(enum.Values().Get(int(number)), visitor)
}

func visitEnumValue(enum protoreflect.EnumValueDescriptor, visitor Visitor) error {
valueOpts := enum.Options().(*descriptorpb.EnumValueOptions)
if valueOpts != nil {
ver, _ := EtcdVersionFromOptionsString(valueOpts.String())
maxVer = maxVersion(maxVer, ver)
ver, _ := etcdVersionFromOptionsString(valueOpts.String())
err := visitor(enum.FullName(), ver)
if err != nil {
return err
}
}
return maxVer
return nil
}

func visitDescriptor(md protoreflect.Descriptor, visitor Visitor) error {
opts, ok := md.Options().(fmt.Stringer)
if !ok {
return nil
}
ver, err := etcdVersionFromOptionsString(opts.String())
if err != nil {
return fmt.Errorf("%s: %s", md.FullName(), err)
}
return visitor(md.FullName(), ver)
}

func maxVersion(a *semver.Version, b *semver.Version) *semver.Version {
Expand All @@ -147,16 +238,7 @@ func maxVersion(a *semver.Version, b *semver.Version) *semver.Version {
return b
}

func etcdVersionFromField(fd protoreflect.FieldDescriptor) *semver.Version {
opts := fd.Options().(*descriptorpb.FieldOptions)
if opts == nil {
return nil
}
ver, _ := EtcdVersionFromOptionsString(opts.String())
return ver
}

func EtcdVersionFromOptionsString(opts string) (*semver.Version, error) {
func etcdVersionFromOptionsString(opts string) (*semver.Version, error) {
// TODO: Use proto.GetExtention when gogo/protobuf is usable with protoreflect
msgs := []string{"[versionpb.etcd_version_msg]:", "[versionpb.etcd_version_field]:", "[versionpb.etcd_version_enum]:", "[versionpb.etcd_version_enum_value]:"}
var end, index int
Expand Down
21 changes: 16 additions & 5 deletions server/storage/wal/version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"go.etcd.io/etcd/api/v3/membershippb"
"go.etcd.io/etcd/pkg/v3/pbutil"
"go.etcd.io/etcd/raft/v3/raftpb"
"google.golang.org/protobuf/reflect/protoreflect"
)

var (
Expand Down Expand Up @@ -97,8 +98,13 @@ func TestEtcdVersionFromEntry(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
ver := etcdVersionFromEntry(tc.input)
assert.Equal(t, tc.expect, ver)
var maxVer *semver.Version
err := visitEntry(tc.input, func(path protoreflect.FullName, ver *semver.Version) error {
maxVer = maxVersion(maxVer, ver)
return nil
})
assert.NoError(t, err)
assert.Equal(t, tc.expect, maxVer)
})
}
}
Expand Down Expand Up @@ -162,8 +168,13 @@ func TestEtcdVersionFromMessage(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
ver := etcdVersionFromMessage(proto.MessageReflect(tc.input))
assert.Equal(t, tc.expect, ver)
var maxVer *semver.Version
err := visitMessage(proto.MessageReflect(tc.input), func(path protoreflect.FullName, ver *semver.Version) error {
maxVer = maxVersion(maxVer, ver)
return nil
})
assert.NoError(t, err)
assert.Equal(t, tc.expect, maxVer)
})
}
}
Expand Down Expand Up @@ -237,7 +248,7 @@ func TestEtcdVersionFromFieldOptionsString(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.input, func(t *testing.T) {
ver, err := EtcdVersionFromOptionsString(tc.input)
ver, err := etcdVersionFromOptionsString(tc.input)
assert.NoError(t, err)
assert.Equal(t, ver, tc.expect)
})
Expand Down
76 changes: 2 additions & 74 deletions tools/proto-annotations/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,88 +101,16 @@ func allEtcdVersionAnnotations() (annotations []etcdVersionAnnotation, err error
}

func fileEtcdVersionAnnotations(file protoreflect.FileDescriptor) (annotations []etcdVersionAnnotation, err error) {
err = visitFileDescriptor(file, func(path string, ver *semver.Version) error {
err = wal.VisitFileDescriptor(file, func(path protoreflect.FullName, ver *semver.Version) error {
a := etcdVersionAnnotation{fullName: path, version: ver}
annotations = append(annotations, a)
return nil
})
return annotations, err
}

type Visitor func(path string, ver *semver.Version) error

func visitFileDescriptor(file protoreflect.FileDescriptor, visitor Visitor) error {
msgs := file.Messages()
for i := 0; i < msgs.Len(); i++ {
err := visitMessageDescriptor(msgs.Get(i), visitor)
if err != nil {
return err
}
}
enums := file.Enums()
for i := 0; i < enums.Len(); i++ {
err := visitEnumDescriptor(enums.Get(i), visitor)
if err != nil {
return err
}
}
return nil
}

func visitMessageDescriptor(md protoreflect.MessageDescriptor, visitor Visitor) error {
err := VisitDescriptor(md, visitor)
if err != nil {
return err
}
fields := md.Fields()
for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
err = VisitDescriptor(fd, visitor)
if err != nil {
return err
}
}

enums := md.Enums()
for i := 0; i < enums.Len(); i++ {
err := visitEnumDescriptor(enums.Get(i), visitor)
if err != nil {
return err
}
}
return err
}

func visitEnumDescriptor(enum protoreflect.EnumDescriptor, visitor Visitor) error {
err := VisitDescriptor(enum, visitor)
if err != nil {
return err
}
fields := enum.Values()
for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
err = VisitDescriptor(fd, visitor)
if err != nil {
return err
}
}
return err
}

func VisitDescriptor(md protoreflect.Descriptor, visitor Visitor) error {
s, ok := md.Options().(fmt.Stringer)
if !ok {
return nil
}
ver, err := wal.EtcdVersionFromOptionsString(s.String())
if err != nil {
return fmt.Errorf("%s: %s", md.FullName(), err)
}
return visitor(string(md.FullName()), ver)
}

type etcdVersionAnnotation struct {
fullName string
fullName protoreflect.FullName
version *semver.Version
}

Expand Down

0 comments on commit 2777fd3

Please sign in to comment.