diff --git a/internal/generator/accessor/doc.go b/internal/generator/accessor/doc.go new file mode 100644 index 0000000..a88e328 --- /dev/null +++ b/internal/generator/accessor/doc.go @@ -0,0 +1,6 @@ +// Package accessor generates an accessor method for each type in "one-of +// groups". While "protoc-gen-go" plugin already generates accessor methods for +// for each type in "one-of groups", the generated methods lack the ability to +// differentiate between the absence of a value and the presence of a zero +// value. +package accessor diff --git a/internal/generator/accessor/generate.go b/internal/generator/accessor/generate.go new file mode 100644 index 0000000..81d8d47 --- /dev/null +++ b/internal/generator/accessor/generate.go @@ -0,0 +1,20 @@ +package accessor + +import ( + "github.com/dave/jennifer/jen" + "github.com/dogmatiq/primo/internal/generator/internal/scope" +) + +// Generate generates accessor methods for each type in one-of group. In +// contrast to already available accessor methods generated by "protoc-gen-go" +// plugin, these accessor methods return a boolean value indicating whether the +// value is present or not. +func Generate(code *jen.File, f *scope.File) error { + for _, m := range f.Messages() { + for _, g := range m.OneOfGroups() { + generateForOneOf(code, g) + } + } + + return nil +} diff --git a/internal/generator/accessor/oneof.go b/internal/generator/accessor/oneof.go new file mode 100644 index 0000000..1e6a6a4 --- /dev/null +++ b/internal/generator/accessor/oneof.go @@ -0,0 +1,74 @@ +package accessor + +import ( + "github.com/dave/jennifer/jen" + "github.com/dogmatiq/primo/internal/generator/internal/scope" +) + +func generateForOneOf(code *jen.File, g *scope.OneOfGroup) { + for _, o := range g.Options { + oneOfAccessorTryFunc(code, o) + } +} + +func oneOfAccessorTryFunc(code *jen.File, o *scope.OneOfOption) { + methodName := "Try" + o.DiscriminatorFieldName + + code. + Commentf( + "%s returns the value of [%s] in one-of field x.%s.", + methodName, + o.DiscriminatorFieldName, + o.Group.GoFieldName, + ) + code.Comment("") + code.Comment("ok returns false if the value of this type is not set.") + + code. + Func(). + Params( + jen. + Id("x"). + Op("*"). + Id(o.Group.Message.GoTypeName), + ). + Id(methodName). + Params(). + Params( + jen. + Id("v"). + Add(o.Field.GoType()), + jen. + Id("ok"). + Add(jen.Bool()), + ). + Block( + jen. + If( + jen.List( + jen.Id("x"), + jen.Id("ok"), + ). + Op(":="). + Id("x"). + Dot("Get"+o.Group.GoFieldName). + Call(). + Assert( + jen.Op("*"). + Id(o.DiscriminatorTypeName), + ). + Op(";"). + Id("ok"), + ).Block( + jen.Return( + jen.Id("x"). + Dot(o.DiscriminatorFieldName), + jen.True(), + ), + ), + jen.Return( + jen.Id("v"), + jen.False(), + ), + ) +} diff --git a/internal/generator/generate.go b/internal/generator/generate.go index 033a07b..0729de0 100644 --- a/internal/generator/generate.go +++ b/internal/generator/generate.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/dave/jennifer/jen" + "github.com/dogmatiq/primo/internal/generator/accessor" "github.com/dogmatiq/primo/internal/generator/builder" "github.com/dogmatiq/primo/internal/generator/exhaustiveswitch" "github.com/dogmatiq/primo/internal/generator/internal/option" @@ -38,6 +39,7 @@ func Generate( builder.Generate, exhaustiveswitch.Generate, mutator.Generate, + accessor.Generate, ); err != nil { res.Error = proto.String(err.Error()) } diff --git a/internal/test/accessor/helper_test.go b/internal/test/accessor/helper_test.go new file mode 100644 index 0000000..f1213ab --- /dev/null +++ b/internal/test/accessor/helper_test.go @@ -0,0 +1,61 @@ +package accessor_test + +import ( + reflect "reflect" + "testing" + + "google.golang.org/protobuf/proto" +) + +// testAccessor calls a mutator method to set the value and verifies that the +// corresponding accessor returns the expected value and boolean ok value. +func testAccessor[M proto.Message, T comparable]( + t *testing.T, + mutator func(M, T), + accessor func(M) (T, bool), + want T, +) { + t.Helper() + + testAccessorFunc( + t, + mutator, + accessor, + want, + func(a, b T) bool { return a == b }, + ) +} + +// testAccessorFunc calls a mutator method to set the value and verifies that the +// corresponding accessor returns the expected value and boolean ok value. +func testAccessorFunc[M proto.Message, T any]( + t *testing.T, + mutate func(M, T), + access func(M) (T, bool), + want T, + eq func(T, T) bool, +) { + t.Helper() + + var m M + m = reflect.New( + reflect.TypeOf(m).Elem(), + ).Interface().(M) + + mutate(m, want) + + got, gotOK := access(m) + if !gotOK { + t.Fatalf( + "accessor did not return expected ok as true", + ) + } + + if !eq(got, want) { + t.Fatalf( + "accessor did not return the expected value: got: %v, want: %v", + got, + want, + ) + } +} diff --git a/internal/test/accessor/oneof.pb.go b/internal/test/accessor/oneof.pb.go new file mode 100644 index 0000000..35a7dcf --- /dev/null +++ b/internal/test/accessor/oneof.pb.go @@ -0,0 +1,205 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.33.0 +// protoc v5.26.1 +// source: github.com/dogmatiq/primo/internal/test/accessor/oneof.proto + +package accessor + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type OneOf struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to Group: + // + // *OneOf_FieldA + // *OneOf_FieldB + // *OneOf_FieldC + Group isOneOf_Group `protobuf_oneof:"group"` +} + +func (x *OneOf) Reset() { + *x = OneOf{} + if protoimpl.UnsafeEnabled { + mi := &file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *OneOf) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OneOf) ProtoMessage() {} + +func (x *OneOf) ProtoReflect() protoreflect.Message { + mi := &file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OneOf.ProtoReflect.Descriptor instead. +func (*OneOf) Descriptor() ([]byte, []int) { + return file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDescGZIP(), []int{0} +} + +func (m *OneOf) GetGroup() isOneOf_Group { + if m != nil { + return m.Group + } + return nil +} + +func (x *OneOf) GetFieldA() int32 { + if x, ok := x.GetGroup().(*OneOf_FieldA); ok { + return x.FieldA + } + return 0 +} + +func (x *OneOf) GetFieldB() int32 { + if x, ok := x.GetGroup().(*OneOf_FieldB); ok { + return x.FieldB + } + return 0 +} + +func (x *OneOf) GetFieldC() string { + if x, ok := x.GetGroup().(*OneOf_FieldC); ok { + return x.FieldC + } + return "" +} + +type isOneOf_Group interface { + isOneOf_Group() +} + +type OneOf_FieldA struct { + FieldA int32 `protobuf:"varint,1,opt,name=field_a,json=fieldA,proto3,oneof"` // note: two fields of the same type +} + +type OneOf_FieldB struct { + FieldB int32 `protobuf:"varint,2,opt,name=field_b,json=fieldB,proto3,oneof"` +} + +type OneOf_FieldC struct { + FieldC string `protobuf:"bytes,3,opt,name=field_c,json=fieldC,proto3,oneof"` +} + +func (*OneOf_FieldA) isOneOf_Group() {} + +func (*OneOf_FieldB) isOneOf_Group() {} + +func (*OneOf_FieldC) isOneOf_Group() {} + +var File_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto protoreflect.FileDescriptor + +var file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDesc = []byte{ + 0x0a, 0x3c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x6f, 0x67, + 0x6d, 0x61, 0x74, 0x69, 0x71, 0x2f, 0x70, 0x72, 0x69, 0x6d, 0x6f, 0x2f, 0x69, 0x6e, 0x74, 0x65, + 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x6f, 0x72, 0x2f, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x13, + 0x70, 0x72, 0x69, 0x6d, 0x6f, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x6d, 0x75, 0x74, 0x61, 0x74, + 0x6f, 0x72, 0x73, 0x22, 0x61, 0x0a, 0x05, 0x4f, 0x6e, 0x65, 0x4f, 0x66, 0x12, 0x19, 0x0a, 0x07, + 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x48, 0x00, 0x52, + 0x06, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x41, 0x12, 0x19, 0x0a, 0x07, 0x66, 0x69, 0x65, 0x6c, 0x64, + 0x5f, 0x62, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x48, 0x00, 0x52, 0x06, 0x66, 0x69, 0x65, 0x6c, + 0x64, 0x42, 0x12, 0x19, 0x0a, 0x07, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x63, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x06, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x43, 0x42, 0x07, 0x0a, + 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x6f, 0x67, 0x6d, 0x61, 0x74, 0x69, 0x71, 0x2f, 0x70, 0x72, + 0x69, 0x6d, 0x6f, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x73, + 0x74, 0x2f, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x6f, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, +} + +var ( + file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDescOnce sync.Once + file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDescData = file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDesc +) + +func file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDescGZIP() []byte { + file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDescOnce.Do(func() { + file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDescData = protoimpl.X.CompressGZIP(file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDescData) + }) + return file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDescData +} + +var file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_goTypes = []interface{}{ + (*OneOf)(nil), // 0: primo.test.mutators.OneOf +} +var file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_init() } +func file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_init() { + if File_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*OneOf); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_msgTypes[0].OneofWrappers = []interface{}{ + (*OneOf_FieldA)(nil), + (*OneOf_FieldB)(nil), + (*OneOf_FieldC)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_goTypes, + DependencyIndexes: file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_depIdxs, + MessageInfos: file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_msgTypes, + }.Build() + File_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto = out.File + file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_rawDesc = nil + file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_goTypes = nil + file_github_com_dogmatiq_primo_internal_test_accessor_oneof_proto_depIdxs = nil +} diff --git a/internal/test/accessor/oneof.proto b/internal/test/accessor/oneof.proto new file mode 100644 index 0000000..1d9fcca --- /dev/null +++ b/internal/test/accessor/oneof.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; +package primo.test.mutators; + +option go_package = "github.com/dogmatiq/primo/internal/test/accessor"; + +message OneOf { + oneof group { + int32 field_a = 1; // note: two fields of the same type + int32 field_b = 2; + string field_c = 3; + } +} diff --git a/internal/test/accessor/oneof_test.go b/internal/test/accessor/oneof_test.go new file mode 100644 index 0000000..acd6b92 --- /dev/null +++ b/internal/test/accessor/oneof_test.go @@ -0,0 +1,71 @@ +package accessor_test + +import ( + "testing" + + . "github.com/dogmatiq/primo/internal/test/accessor" +) + +func TestOneOfAccessor(t *testing.T) { + t.Parallel() + + t.Run("it returns the set value and ok as true to signify the presence of the value", func(t *testing.T) { + testAccessor( + t, + (*OneOf).SetFieldA, + (*OneOf).TryFieldA, + 123, + ) + + testAccessor( + t, + (*OneOf).SetFieldA, + (*OneOf).TryFieldA, + 0, + ) + + testAccessor( + t, + (*OneOf).SetFieldB, + (*OneOf).TryFieldB, + 456, + ) + + testAccessor( + t, + (*OneOf).SetFieldB, + (*OneOf).TryFieldB, + 0, + ) + + testAccessor( + t, + (*OneOf).SetFieldC, + (*OneOf).TryFieldC, + "", + ) + + testAccessor( + t, + (*OneOf).SetFieldC, + (*OneOf).TryFieldC, + "", + ) + }) + + t.Run("it ok as false to signify the absence of the value", func(t *testing.T) { + m := &OneOf{} + + if _, ok := m.TryFieldA(); ok { + t.Fatalf("TryFieldA() returned ok as true, want false") + } + + if _, ok := m.TryFieldB(); ok { + t.Fatalf("TryFieldB() returned ok as true, want false") + } + + if _, ok := m.TryFieldC(); ok { + t.Fatalf("TryFieldC() returned ok as true, want false") + } + }) +}