Skip to content

Commit 79df5c4

Browse files
authored
Merge pull request #148 from coxley/unique-maps
features/unmarshal_unique: fix codegen for keys/values
2 parents 71c992b + bbb5fce commit 79df5c4

File tree

7 files changed

+664
-20
lines changed

7 files changed

+664
-20
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
bin
22
_vendor
3+
conformance/marshal.log

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ gen-testproto: get-grpc-testproto gen-wkt-testproto install
6464
testproto/proto3opt/opt.proto \
6565
testproto/proto2/scalars.proto \
6666
testproto/unsafe/unsafe.proto \
67+
testproto/unique/unique.proto \
6768
|| exit 1;
6869
$(PROTOBUF_ROOT)/src/protoc \
6970
--proto_path=testproto \

features/unmarshal/unmarshal.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,7 @@ func (p *unmarshal) declareMapField(varName string, nullable bool, field *protog
157157
}
158158
}
159159

160-
func (p *unmarshal) mapField(varName string, field *protogen.Field) {
161-
unique := proto.GetExtension(field.Desc.Options(), vtproto.E_Options).(*vtproto.Opts).GetUnique()
162-
160+
func (p *unmarshal) mapField(varName string, field *protogen.Field, unique bool) {
163161
switch field.Desc.Kind() {
164162
case protoreflect.DoubleKind:
165163
p.P(`var `, varName, `temp uint64`)
@@ -509,6 +507,8 @@ func (p *unmarshal) fieldItem(field *protogen.Field, fieldname string, message *
509507
p.P(`m.`, fieldname, ` = &`, field.GoIdent, "{", field.GoName, `: v}`)
510508
p.P(`}`)
511509
} else if field.Desc.IsMap() {
510+
unique := proto.GetExtension(field.Desc.Options(), vtproto.E_Options).(*vtproto.Opts).GetUnique()
511+
512512
goTyp, _ := p.FieldGoType(field)
513513
goTypK, _ := p.FieldGoType(field.Message.Fields[0])
514514
goTypV, _ := p.FieldGoType(field.Message.Fields[1])
@@ -527,9 +527,9 @@ func (p *unmarshal) fieldItem(field *protogen.Field, fieldname string, message *
527527
p.P(`fieldNum := int32(wire >> 3)`)
528528

529529
p.P(`if fieldNum == 1 {`)
530-
p.mapField("mapkey", field.Message.Fields[0])
530+
p.mapField("mapkey", field.Message.Fields[0], unique)
531531
p.P(`} else if fieldNum == 2 {`)
532-
p.mapField("mapvalue", field.Message.Fields[1])
532+
p.mapField("mapvalue", field.Message.Fields[1], unique)
533533
p.P(`} else {`)
534534
p.P(`iNdEx = entryPreIndex`)
535535
p.P(`skippy, err := `, p.Helper("Skip"), `(dAtA[iNdEx:])`)

testproto/unique/unique.pb.go

Lines changed: 48 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

testproto/unique/unique.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ import "github.com/planetscale/vtprotobuf/vtproto/ext.proto";
55

66
message UniqueFieldExtension {
77
string foo = 1 [(vtproto.options).unique = true];
8+
map<string,int64> bar = 2 [(vtproto.options).unique = true];
9+
map<int64,string> baz = 3 [(vtproto.options).unique = true];
810
}

testproto/unique/unique_test.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package unique
22

33
import (
4+
"maps"
5+
"slices"
46
"testing"
57
"unsafe"
68

@@ -10,6 +12,8 @@ import (
1012
func TestUnmarshalSameMemory(t *testing.T) {
1113
m := &UniqueFieldExtension{
1214
Foo: "bar",
15+
Bar: map[string]int64{"key": 100},
16+
Baz: map[int64]string{100: "value"},
1317
}
1418

1519
b, err := m.MarshalVTStrict()
@@ -21,5 +25,17 @@ func TestUnmarshalSameMemory(t *testing.T) {
2125
m3 := &UniqueFieldExtension{}
2226
require.NoError(t, m3.UnmarshalVT(b))
2327

24-
require.Equal(t, unsafe.StringData(m2.Foo), unsafe.StringData(m3.Foo))
28+
require.Same(t, unsafe.StringData(m2.Foo), unsafe.StringData(m3.Foo), "string field")
29+
30+
keys2 := slices.Collect(maps.Keys(m2.Bar))
31+
keys3 := slices.Collect(maps.Keys(m3.Bar))
32+
require.Len(t, keys2, 1)
33+
require.Len(t, keys3, 1)
34+
require.Same(t, unsafe.StringData(keys2[0]), unsafe.StringData(keys3[0]), "string key")
35+
36+
values2 := slices.Collect(maps.Values(m2.Baz))
37+
values3 := slices.Collect(maps.Values(m3.Baz))
38+
require.Len(t, values2, 1)
39+
require.Len(t, values2, 1)
40+
require.Same(t, unsafe.StringData(values2[0]), unsafe.StringData(values3[0]), "string value")
2541
}

0 commit comments

Comments
 (0)