Skip to content

Commit 67aa0e5

Browse files
authored
Merge pull request #2085 from nolandseigler/rows-snake-case
RowToStructByName Snake Case Collision
2 parents 96791c8 + 71a8e53 commit 67aa0e5

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

rows.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ func computeNamedStructFields(
797797
if !dbTagPresent {
798798
colName = sf.Name
799799
}
800-
fpos := fieldPosByName(fldDescs, colName)
800+
fpos := fieldPosByName(fldDescs, colName, !dbTagPresent)
801801
if fpos == -1 {
802802
if missingField == "" {
803803
missingField = colName
@@ -816,16 +816,21 @@ func computeNamedStructFields(
816816

817817
const structTagKey = "db"
818818

819-
func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
819+
func fieldPosByName(fldDescs []pgconn.FieldDescription, field string, normalize bool) (i int) {
820820
i = -1
821-
for i, desc := range fldDescs {
822821

823-
// Snake case support.
822+
if normalize {
824823
field = strings.ReplaceAll(field, "_", "")
825-
descName := strings.ReplaceAll(desc.Name, "_", "")
826-
827-
if strings.EqualFold(descName, field) {
828-
return i
824+
}
825+
for i, desc := range fldDescs {
826+
if normalize {
827+
if strings.EqualFold(strings.ReplaceAll(desc.Name, "_", ""), field) {
828+
return i
829+
}
830+
} else {
831+
if desc.Name == field {
832+
return i
833+
}
829834
}
830835
}
831836
return

rows_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,41 @@ func TestRowToStructByName(t *testing.T) {
667667
})
668668
}
669669

670+
func TestRowToStructByNameDbTags(t *testing.T) {
671+
type person struct {
672+
Last string `db:"last_name"`
673+
First string `db:"first_name"`
674+
Age int32 `db:"age"`
675+
AccountID string `db:"account_id"`
676+
AnotherAccountID string `db:"account__id"`
677+
}
678+
679+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
680+
rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, 'd5e49d3f' as account_id, '5e49d321' as account__id from generate_series(0, 9) n`)
681+
slice, err := pgx.CollectRows(rows, pgx.RowToStructByName[person])
682+
assert.NoError(t, err)
683+
684+
assert.Len(t, slice, 10)
685+
for i := range slice {
686+
assert.Equal(t, "Smith", slice[i].Last)
687+
assert.Equal(t, "John", slice[i].First)
688+
assert.EqualValues(t, i, slice[i].Age)
689+
assert.Equal(t, "d5e49d3f", slice[i].AccountID)
690+
assert.Equal(t, "5e49d321", slice[i].AnotherAccountID)
691+
}
692+
693+
// check missing fields in a returned row
694+
rows, _ = conn.Query(ctx, `select 'Smith' as last_name, n as age from generate_series(0, 9) n`)
695+
_, err = pgx.CollectRows(rows, pgx.RowToStructByName[person])
696+
assert.ErrorContains(t, err, "cannot find field first_name in returned row")
697+
698+
// check missing field in a destination struct
699+
rows, _ = conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, 'd5e49d3f' as account_id, '5e49d321' as account__id, null as ignore from generate_series(0, 9) n`)
700+
_, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[person])
701+
assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore")
702+
})
703+
}
704+
670705
func TestRowToStructByNameEmbeddedStruct(t *testing.T) {
671706
type Name struct {
672707
Last string `db:"last_name"`

0 commit comments

Comments
 (0)