Skip to content

Commit f8e6f94

Browse files
committed
Optimize ALTS reads
1 parent 6819ed7 commit f8e6f94

File tree

4 files changed

+107
-36
lines changed

4 files changed

+107
-36
lines changed

credentials/alts/internal/conn/common.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,10 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) {
5454
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
5555
// If the size field is not complete, return the provided buffer as
5656
// remaining buffer.
57-
if len(b) < MsgLenFieldSize {
57+
length, sufficientBytes := ParseMessageLength(b)
58+
if !sufficientBytes {
5859
return nil, b, nil
5960
}
60-
msgLenField := b[:MsgLenFieldSize]
61-
length := binary.LittleEndian.Uint32(msgLenField)
6261
if length > maxLen {
6362
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
6463
}
@@ -68,3 +67,15 @@ func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
6867
}
6968
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
7069
}
70+
71+
// ParseMessageLength returns the message length based on frame header. It also
72+
// returns a boolean that indicates if the buffer contains sufficient bytes to
73+
// parse the length header. If there are insufficient bytes, (0, false) is
74+
// returned.
75+
func ParseMessageLength(b []byte) (uint32, bool) {
76+
if len(b) < MsgLenFieldSize {
77+
return 0, false
78+
}
79+
msgLenField := b[:MsgLenFieldSize]
80+
return binary.LittleEndian.Uint32(msgLenField), true
81+
}

credentials/alts/internal/conn/record.go

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ package conn
2323
import (
2424
"encoding/binary"
2525
"fmt"
26+
"io"
2627
"math"
2728
"net"
2829

2930
core "google.golang.org/grpc/credentials/alts/internal"
31+
"google.golang.org/grpc/mem"
3032
)
3133

3234
// ALTSRecordCrypto is the interface for gRPC ALTS record protocol.
@@ -63,6 +65,8 @@ const (
6365
// The maximum write buffer size. This *must* be multiple of
6466
// altsRecordDefaultLength.
6567
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
68+
// The initial buffer used to read from the network.
69+
altsReadBufferInitialSize = 32 * 1024 // 32KiB
6670
)
6771

6872
var (
@@ -84,18 +88,27 @@ type conn struct {
8488
crypto ALTSRecordCrypto
8589
// buf holds data that has been read from the connection and decrypted,
8690
// but has not yet been returned by Read.
87-
buf []byte
91+
buf []byte
92+
// bufPointer holds the entire decrypted record, even bytes that have
93+
// been returned by read. It is used to restore buf to it's initial
94+
// capacity after each frame is decrypted. It is also used to return the
95+
// buffer to the buffer pool.
96+
bufPointer *[]byte
8897
payloadLengthLimit int
8998
// protected holds data read from the network but have not yet been
9099
// decrypted. This data might not compose a complete frame.
91100
protected []byte
101+
// protectedPointer holds a pointer to the protected buffer. It is used to
102+
// return the protected buffer to the buffer pool.
103+
protectedPointer *[]byte
92104
// writeBuf is a buffer used to contain encrypted frames before being
93105
// written to the network.
94106
writeBuf []byte
95107
// nextFrame stores the next frame (in protected buffer) info.
96108
nextFrame []byte
97109
// overhead is the calculated overhead of each frame.
98110
overhead int
111+
isClosed bool
99112
}
100113

101114
// NewConn creates a new secure channel instance given the other party role and
@@ -111,39 +124,43 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
111124
}
112125
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
113126
payloadLengthLimit := altsRecordDefaultLength - overhead
114-
var protectedBuf []byte
115-
if protected == nil {
116-
// We pre-allocate protected to be of size
117-
// 2*altsRecordDefaultLength-1 during initialization. We only
118-
// read from the network into protected when protected does not
119-
// contain a complete frame, which is at most
120-
// altsRecordDefaultLength-1 (bytes). And we read at most
121-
// altsRecordDefaultLength (bytes) data into protected at one
122-
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
123-
// to buffer data read from the network.
124-
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
125-
} else {
126-
protectedBuf = make([]byte, len(protected))
127-
copy(protectedBuf, protected)
128-
}
127+
// We pre-allocate protected to be of size 32KB during initialization.
128+
// We increase the size of the buffer by the required amount if can't hold a
129+
// complete encrypted record.
130+
protectedPointer := mem.DefaultBufferPool().Get(max(altsReadBufferInitialSize, len(protected)))
131+
protectedBuf := (*protectedPointer)[:copy(*protectedPointer, protected)]
129132

130133
altsConn := &conn{
131134
Conn: c,
132135
crypto: crypto,
133136
payloadLengthLimit: payloadLengthLimit,
137+
protectedPointer: protectedPointer,
134138
protected: protectedBuf,
135139
writeBuf: make([]byte, altsWriteBufferInitialSize),
136140
nextFrame: protectedBuf,
137141
overhead: overhead,
142+
bufPointer: mem.DefaultBufferPool().Get(altsReadBufferInitialSize),
138143
}
139144
return altsConn, nil
140145
}
141146

147+
func (p *conn) Close() error {
148+
if !p.isClosed {
149+
p.isClosed = true
150+
mem.DefaultBufferPool().Put(p.protectedPointer)
151+
mem.DefaultBufferPool().Put(p.bufPointer)
152+
}
153+
return p.Conn.Close()
154+
}
155+
142156
// Read reads and decrypts a frame from the underlying connection, and copies the
143157
// decrypted payload into b. If the size of the payload is greater than len(b),
144158
// Read retains the remaining bytes in an internal buffer, and subsequent calls
145159
// to Read will read from this buffer until it is exhausted.
146160
func (p *conn) Read(b []byte) (n int, err error) {
161+
if p.isClosed {
162+
return 0, io.EOF
163+
}
147164
if len(p.buf) == 0 {
148165
var framedMsg []byte
149166
framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit)
@@ -153,20 +170,30 @@ func (p *conn) Read(b []byte) (n int, err error) {
153170
// Check whether the next frame to be decrypted has been
154171
// completely received yet.
155172
if len(framedMsg) == 0 {
156-
copy(p.protected, p.nextFrame)
157-
p.protected = p.protected[:len(p.nextFrame)]
173+
p.protected = p.protected[:copy(p.protected, p.nextFrame)]
158174
// Always copy next incomplete frame to the beginning of
159175
// the protected buffer and reset nextFrame to it.
160176
p.nextFrame = p.protected
161177
}
162178
// Check whether a complete frame has been received yet.
163179
for len(framedMsg) == 0 {
164180
if len(p.protected) == cap(p.protected) {
165-
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
166-
copy(tmp, p.protected)
167-
p.protected = tmp
181+
// We can parse the length header to know exactly how large
182+
// the buffer needs to be to hold the entire frame.
183+
length, didParse := ParseMessageLength(p.protected)
184+
if !didParse {
185+
// The protected buffer is initialized with a capacity of
186+
// larger than 4B. It should always hold the message length
187+
// header.
188+
panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize))
189+
}
190+
tmp := mem.DefaultBufferPool().Get(int(length))
191+
copy(*tmp, p.protected)
192+
p.protected = (*tmp)[:len(p.protected)]
193+
mem.DefaultBufferPool().Put(p.protectedPointer)
194+
p.protectedPointer = tmp
168195
}
169-
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
196+
n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)])
170197
if err != nil {
171198
return 0, err
172199
}
@@ -185,17 +212,27 @@ func (p *conn) Read(b []byte) (n int, err error) {
185212
}
186213
ciphertext := msg[msgTypeFieldSize:]
187214

188-
// Decrypt requires that if the dst and ciphertext alias, they
189-
// must alias exactly. Code here used to use msg[:0], but msg
190-
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
191-
// ciphertext, so they alias inexactly. Using ciphertext[:0]
192-
// arranges the appropriate aliasing without needing to copy
193-
// ciphertext or use a separate destination buffer. For more info
194-
// check: https://golang.org/pkg/crypto/cipher/#AEAD.
195-
p.buf, err = p.crypto.Decrypt(ciphertext[:0], ciphertext)
215+
// Decrypt directly into the buffer, avoiding a copy to p.buf if
216+
// possible.
217+
if cap(b) >= len(msg) {
218+
dec, err := p.crypto.Decrypt(b[:0], ciphertext)
219+
if err != nil {
220+
return 0, err
221+
}
222+
return len(dec), nil
223+
}
224+
// Resize the read buffer if needed to hold the entire decrypted
225+
// frame.
226+
if cap(*p.bufPointer) < len(ciphertext) {
227+
mem.DefaultBufferPool().Put(p.bufPointer)
228+
p.bufPointer = mem.DefaultBufferPool().Get(len(ciphertext))
229+
}
230+
p.buf = *p.bufPointer
231+
dec, err := p.crypto.Decrypt(p.buf[:0], ciphertext)
196232
if err != nil {
197233
return 0, err
198234
}
235+
p.buf = p.buf[:len(dec)]
199236
}
200237

201238
n = copy(b, p.buf)

credentials/alts/internal/conn/record_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,27 @@ func (s) TestLargeMsg(t *testing.T) {
188188
}
189189
}
190190

191+
// BenchmarkLargeMessage measures the performance of ALTS conns for sending and
192+
// receiving a large message.
193+
func BenchmarkLargeMessage(b *testing.B) {
194+
msgLen := 20 * 1024 * 1024 // 20 MiB
195+
msg := make([]byte, msgLen)
196+
rcvMsg := make([]byte, len(msg))
197+
b.ResetTimer()
198+
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
199+
for range b.N {
200+
// Write 20 MiB 5 times to transfer a total of 100 MiB.
201+
for range 5 {
202+
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
203+
b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
204+
}
205+
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
206+
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
207+
}
208+
}
209+
}
210+
}
211+
191212
func testIncorrectMsgType(t *testing.T, rp string) {
192213
// framedMsg is an empty ciphertext with correct framing but wrong
193214
// message type.

credentials/alts/internal/handshaker/handshaker.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
3838
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
3939
"google.golang.org/grpc/internal/envconfig"
40+
"google.golang.org/grpc/mem"
4041
)
4142

4243
const (
@@ -308,6 +309,8 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
308309
// whatever received from the network and send it to the handshaker service.
309310
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
310311
var lastWriteTime time.Time
312+
buf := mem.DefaultBufferPool().Get(frameLimit)
313+
defer mem.DefaultBufferPool().Put(buf)
311314
for {
312315
if len(resp.OutFrames) > 0 {
313316
lastWriteTime = time.Now()
@@ -318,8 +321,7 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
318321
if resp.Result != nil {
319322
return resp.Result, extra, nil
320323
}
321-
buf := make([]byte, frameLimit)
322-
n, err := h.conn.Read(buf)
324+
n, err := h.conn.Read(*buf)
323325
if err != nil && err != io.EOF {
324326
return nil, nil, err
325327
}
@@ -333,7 +335,7 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
333335
}
334336
// Append extra bytes from the previous interaction with the
335337
// handshaker service with the current buffer read from conn.
336-
p := append(extra, buf[:n]...)
338+
p := append(extra, (*buf)[:n]...)
337339
// Compute the time elapsed since the last write to the peer.
338340
timeElapsed := time.Since(lastWriteTime)
339341
timeElapsedMs := uint32(timeElapsed.Milliseconds())

0 commit comments

Comments
 (0)