Skip to content

Commit

Permalink
Merge pull request #148 from coxley/unique-maps
Browse files Browse the repository at this point in the history
features/unmarshal_unique: fix codegen for keys/values
  • Loading branch information
vmg authored Nov 21, 2024
2 parents 71c992b + bbb5fce commit 79df5c4
Show file tree
Hide file tree
Showing 7 changed files with 664 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
bin
_vendor
conformance/marshal.log
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ gen-testproto: get-grpc-testproto gen-wkt-testproto install
testproto/proto3opt/opt.proto \
testproto/proto2/scalars.proto \
testproto/unsafe/unsafe.proto \
testproto/unique/unique.proto \
|| exit 1;
$(PROTOBUF_ROOT)/src/protoc \
--proto_path=testproto \
Expand Down
10 changes: 5 additions & 5 deletions features/unmarshal/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ func (p *unmarshal) declareMapField(varName string, nullable bool, field *protog
}
}

func (p *unmarshal) mapField(varName string, field *protogen.Field) {
unique := proto.GetExtension(field.Desc.Options(), vtproto.E_Options).(*vtproto.Opts).GetUnique()

func (p *unmarshal) mapField(varName string, field *protogen.Field, unique bool) {
switch field.Desc.Kind() {
case protoreflect.DoubleKind:
p.P(`var `, varName, `temp uint64`)
Expand Down Expand Up @@ -509,6 +507,8 @@ func (p *unmarshal) fieldItem(field *protogen.Field, fieldname string, message *
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
p.P(`}`)
} else if field.Desc.IsMap() {
unique := proto.GetExtension(field.Desc.Options(), vtproto.E_Options).(*vtproto.Opts).GetUnique()

goTyp, _ := p.FieldGoType(field)
goTypK, _ := p.FieldGoType(field.Message.Fields[0])
goTypV, _ := p.FieldGoType(field.Message.Fields[1])
Expand All @@ -527,9 +527,9 @@ func (p *unmarshal) fieldItem(field *protogen.Field, fieldname string, message *
p.P(`fieldNum := int32(wire >> 3)`)

p.P(`if fieldNum == 1 {`)
p.mapField("mapkey", field.Message.Fields[0])
p.mapField("mapkey", field.Message.Fields[0], unique)
p.P(`} else if fieldNum == 2 {`)
p.mapField("mapvalue", field.Message.Fields[1])
p.mapField("mapvalue", field.Message.Fields[1], unique)
p.P(`} else {`)
p.P(`iNdEx = entryPreIndex`)
p.P(`skippy, err := `, p.Helper("Skip"), `(dAtA[iNdEx:])`)
Expand Down
62 changes: 48 additions & 14 deletions testproto/unique/unique.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions testproto/unique/unique.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ import "github.com/planetscale/vtprotobuf/vtproto/ext.proto";

message UniqueFieldExtension {
string foo = 1 [(vtproto.options).unique = true];
map<string,int64> bar = 2 [(vtproto.options).unique = true];
map<int64,string> baz = 3 [(vtproto.options).unique = true];
}
18 changes: 17 additions & 1 deletion testproto/unique/unique_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package unique

import (
"maps"
"slices"
"testing"
"unsafe"

Expand All @@ -10,6 +12,8 @@ import (
func TestUnmarshalSameMemory(t *testing.T) {
m := &UniqueFieldExtension{
Foo: "bar",
Bar: map[string]int64{"key": 100},
Baz: map[int64]string{100: "value"},
}

b, err := m.MarshalVTStrict()
Expand All @@ -21,5 +25,17 @@ func TestUnmarshalSameMemory(t *testing.T) {
m3 := &UniqueFieldExtension{}
require.NoError(t, m3.UnmarshalVT(b))

require.Equal(t, unsafe.StringData(m2.Foo), unsafe.StringData(m3.Foo))
require.Same(t, unsafe.StringData(m2.Foo), unsafe.StringData(m3.Foo), "string field")

keys2 := slices.Collect(maps.Keys(m2.Bar))
keys3 := slices.Collect(maps.Keys(m3.Bar))
require.Len(t, keys2, 1)
require.Len(t, keys3, 1)
require.Same(t, unsafe.StringData(keys2[0]), unsafe.StringData(keys3[0]), "string key")

values2 := slices.Collect(maps.Values(m2.Baz))
values3 := slices.Collect(maps.Values(m3.Baz))
require.Len(t, values2, 1)
require.Len(t, values2, 1)
require.Same(t, unsafe.StringData(values2[0]), unsafe.StringData(values3[0]), "string value")
}
Loading

0 comments on commit 79df5c4

Please sign in to comment.