fix(decoder): use configurable limit for max number of records in a record batch #3120

Merged
merged 3 commits
Mar 11, 2025
Changes from 1 commit
1 change: 1 addition & 0 deletions packet_decoder.go
@@ -15,6 +15,7 @@ type packetDecoder interface {
	getUVarint() (uint64, error)
	getFloat64() (float64, error)
	getArrayLength() (int, error)
	getArrayLengthNoLimit() (int, error)
	getCompactArrayLength() (int, error)
	getBool() (bool, error)
	getEmptyTaggedFieldArray() (int, error)
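For context, the array-length getters on packetDecoder are consumed by the per-field decode routines. The sketch below is a hypothetical helper (not part of this PR, and assuming the package's own interface methods such as getInt32) showing the usual calling pattern; swapping getArrayLength for getArrayLengthNoLimit changes only how the length prefix itself is validated, not the loop.

```go
// decodeInt32Array is an illustrative, hypothetical helper inside the sarama
// package: read a length prefix, then read that many elements.
func decodeInt32Array(pd packetDecoder) ([]int32, error) {
	n, err := pd.getArrayLength() // the capped variant; NoLimit skips the cap
	if err != nil {
		return nil, err
	}
	out := make([]int32, 0, n)
	for i := 0; i < n; i++ {
		v, err := pd.getInt32()
		if err != nil {
			return nil, err
		}
		out = append(out, v)
	}
	return out, nil
}
```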
14 changes: 14 additions & 0 deletions real_decoder.go
@@ -121,6 +121,20 @@ func (rd *realDecoder) getArrayLength() (int, error) {
	return tmp, nil
}

func (rd *realDecoder) getArrayLengthNoLimit() (int, error) {
	if rd.remaining() < 4 {
		rd.off = len(rd.raw)
		return -1, ErrInsufficientData
	}
	tmp := int(int32(binary.BigEndian.Uint32(rd.raw[rd.off:])))
	rd.off += 4
	if tmp > rd.remaining() {
		rd.off = len(rd.raw)
		return -1, ErrInsufficientData
	}
	return tmp, nil
}

func (rd *realDecoder) getCompactArrayLength() (int, error) {
	n, err := rd.getUVarint()
	if err != nil {
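The new method mirrors the existing getArrayLength, minus the hard cap on the decoded length. For comparison, the capped variant presumably looks roughly like the sketch below; the exact cap and error value are assumptions inferred from the comment in record_batch.go and this PR's title, not copied from this diff.

```go
// Sketch of the pre-existing, capped variant (not part of this diff): same
// bounds checks as getArrayLengthNoLimit, plus a rejection of any length
// above 2*math.MaxUint16.
func (rd *realDecoder) getArrayLength() (int, error) {
	if rd.remaining() < 4 {
		rd.off = len(rd.raw)
		return -1, ErrInsufficientData
	}
	tmp := int(int32(binary.BigEndian.Uint32(rd.raw[rd.off:])))
	rd.off += 4
	if tmp > rd.remaining() {
		rd.off = len(rd.raw)
		return -1, ErrInsufficientData
	} else if tmp > 2*math.MaxUint16 {
		return -1, errInvalidArrayLength // assumed sentinel; the exact error value may differ
	}
	return tmp, nil
}
```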
6 changes: 5 additions & 1 deletion record_batch.go
@@ -157,7 +157,11 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
		return err
	}

-	numRecs, err := pd.getArrayLength()
+	// Using NoLimit because a single record batch could contain
+	// more than 2*math.MaxUint16 records. The packet decoder will
+	// check to make sure the array is not greater than the
+	// remaining bytes.
+	numRecs, err := pd.getArrayLengthNoLimit()
	if err != nil {
		return err
	}
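The reasoning in the new comment can be made concrete with made-up numbers: every encoded record occupies at least one byte, so a claimed record count larger than the bytes left in the buffer cannot be genuine, and the remaining-bytes check in the decoder rejects it before any records slice is allocated. A standalone sketch, with illustrative values only:

```go
package main

import "fmt"

func main() {
	// Illustrative values, not from the PR: a truncated buffer with 100 bytes
	// left after the record-count field, claiming one million records.
	remaining := 100
	numRecs := 1_000_000

	// This mirrors the guard applied by getArrayLengthNoLimit (tmp > rd.remaining()):
	// since each record needs at least one byte, the claim is rejected.
	if numRecs > remaining {
		fmt.Println("reject with ErrInsufficientData; no oversized slice is allocated")
	}
}
```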
38 changes: 38 additions & 0 deletions record_test.go
@@ -3,6 +3,7 @@
package sarama

import (
	"fmt"
	"reflect"
	"testing"
	"time"
@@ -254,3 +255,40 @@ func TestRecordBatchDecoding(t *testing.T) {
		}
	}
}

func TestRecordBatchInvalidNumRecords(t *testing.T) {
	encodedBatch := []byte{
		0, 0, 0, 0, 0, 0, 0, 0, // First Offset
		0, 0, 0, 70, // Length
		0, 0, 0, 0, // Partition Leader Epoch
		2, // Version
		91, 48, 202, 99, // CRC
		0, 0, // Attributes
		0, 0, 0, 0, // Last Offset Delta
		0, 0, 1, 88, 141, 205, 89, 56, // First Timestamp
		0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
		0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
		0, 0, // Producer Epoch
		0, 0, 0, 0, // First Sequence
		0, 1, 255, 255, // Number of Records: 1 + 2*math.MaxUint16
		40, // Record Length
		0, // Attributes
		10, // Timestamp Delta
		0, // Offset Delta
		8, // Key Length
		1, 2, 3, 4,
		6, // Value Length
		5, 6, 7,
		2, // Number of Headers
		6, // Header Key Length
		8, 9, 10, // Header Key
		4, // Header Value Length
		11, 12, // Header Value
	}

	batch := RecordBatch{}
	err := decode(encodedBatch, &batch, nil)
	if err != ErrInsufficientData {
		t.Fatal(fmt.Errorf("was supposed to get ErrInsufficientData, instead got: %w", err))
	}
}
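As a sanity check on the fixture, the four record-count bytes 0, 1, 255, 255 decode to 131071, which is 1 + 2*math.MaxUint16 and therefore one past the old cap, while the declared batch length of 70 bytes cannot possibly hold that many records. A standalone verification (not part of the test file):

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

func main() {
	// The record-count field from the test fixture above.
	count := binary.BigEndian.Uint32([]byte{0, 1, 255, 255})
	fmt.Println(count)                // 131071
	fmt.Println(1 + 2*math.MaxUint16) // 131071, one past the old 2*math.MaxUint16 cap
}
```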