@@ -23,10 +23,12 @@ package conn
23
23
import (
24
24
"encoding/binary"
25
25
"fmt"
26
+ "io"
26
27
"math"
27
28
"net"
28
29
29
30
core "google.golang.org/grpc/credentials/alts/internal"
31
+ "google.golang.org/grpc/mem"
30
32
)
31
33
32
34
// ALTSRecordCrypto is the interface for gRPC ALTS record protocol.
@@ -63,6 +65,8 @@ const (
63
65
// The maximum write buffer size. This *must* be multiple of
64
66
// altsRecordDefaultLength.
65
67
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
68
+ // The initial buffer used to read from the network.
69
+ altsReadBufferInitialSize = 32 * 1024 // 32KiB
66
70
)
67
71
68
72
var (
@@ -84,18 +88,27 @@ type conn struct {
84
88
crypto ALTSRecordCrypto
85
89
// buf holds data that has been read from the connection and decrypted,
86
90
// but has not yet been returned by Read.
87
- buf []byte
91
+ buf []byte
92
+ // bufPointer holds the entire decrypted record, even bytes that have
93
+ // been returned by read. It is used to restore buf to it's initial
94
+ // capacity after each frame is decrypted. It is also used to return the
95
+ // buffer to the buffer pool.
96
+ bufPointer * []byte
88
97
payloadLengthLimit int
89
98
// protected holds data read from the network but have not yet been
90
99
// decrypted. This data might not compose a complete frame.
91
100
protected []byte
101
+ // protectedPointer holds a pointer to the protected buffer. It is used to
102
+ // return the protected buffer to the buffer pool.
103
+ protectedPointer * []byte
92
104
// writeBuf is a buffer used to contain encrypted frames before being
93
105
// written to the network.
94
106
writeBuf []byte
95
107
// nextFrame stores the next frame (in protected buffer) info.
96
108
nextFrame []byte
97
109
// overhead is the calculated overhead of each frame.
98
110
overhead int
111
+ isClosed bool
99
112
}
100
113
101
114
// NewConn creates a new secure channel instance given the other party role and
@@ -111,39 +124,43 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
111
124
}
112
125
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto .EncryptionOverhead ()
113
126
payloadLengthLimit := altsRecordDefaultLength - overhead
114
- var protectedBuf []byte
115
- if protected == nil {
116
- // We pre-allocate protected to be of size
117
- // 2*altsRecordDefaultLength-1 during initialization. We only
118
- // read from the network into protected when protected does not
119
- // contain a complete frame, which is at most
120
- // altsRecordDefaultLength-1 (bytes). And we read at most
121
- // altsRecordDefaultLength (bytes) data into protected at one
122
- // time. Therefore, 2*altsRecordDefaultLength-1 is large enough
123
- // to buffer data read from the network.
124
- protectedBuf = make ([]byte , 0 , 2 * altsRecordDefaultLength - 1 )
125
- } else {
126
- protectedBuf = make ([]byte , len (protected ))
127
- copy (protectedBuf , protected )
128
- }
127
+ // We pre-allocate protected to be of size 32KB during initialization.
128
+ // We increase the size of the buffer by the required amount if can't hold a
129
+ // complete encrypted record.
130
+ protectedPointer := mem .DefaultBufferPool ().Get (max (altsReadBufferInitialSize , len (protected )))
131
+ protectedBuf := (* protectedPointer )[:copy (* protectedPointer , protected )]
129
132
130
133
altsConn := & conn {
131
134
Conn : c ,
132
135
crypto : crypto ,
133
136
payloadLengthLimit : payloadLengthLimit ,
137
+ protectedPointer : protectedPointer ,
134
138
protected : protectedBuf ,
135
139
writeBuf : make ([]byte , altsWriteBufferInitialSize ),
136
140
nextFrame : protectedBuf ,
137
141
overhead : overhead ,
142
+ bufPointer : mem .DefaultBufferPool ().Get (altsReadBufferInitialSize ),
138
143
}
139
144
return altsConn , nil
140
145
}
141
146
147
+ func (p * conn ) Close () error {
148
+ if ! p .isClosed {
149
+ p .isClosed = true
150
+ mem .DefaultBufferPool ().Put (p .protectedPointer )
151
+ mem .DefaultBufferPool ().Put (p .bufPointer )
152
+ }
153
+ return p .Conn .Close ()
154
+ }
155
+
142
156
// Read reads and decrypts a frame from the underlying connection, and copies the
143
157
// decrypted payload into b. If the size of the payload is greater than len(b),
144
158
// Read retains the remaining bytes in an internal buffer, and subsequent calls
145
159
// to Read will read from this buffer until it is exhausted.
146
160
func (p * conn ) Read (b []byte ) (n int , err error ) {
161
+ if p .isClosed {
162
+ return 0 , io .EOF
163
+ }
147
164
if len (p .buf ) == 0 {
148
165
var framedMsg []byte
149
166
framedMsg , p .nextFrame , err = ParseFramedMsg (p .nextFrame , altsRecordLengthLimit )
@@ -153,20 +170,30 @@ func (p *conn) Read(b []byte) (n int, err error) {
153
170
// Check whether the next frame to be decrypted has been
154
171
// completely received yet.
155
172
if len (framedMsg ) == 0 {
156
- copy (p .protected , p .nextFrame )
157
- p .protected = p .protected [:len (p .nextFrame )]
173
+ p .protected = p .protected [:copy (p .protected , p .nextFrame )]
158
174
// Always copy next incomplete frame to the beginning of
159
175
// the protected buffer and reset nextFrame to it.
160
176
p .nextFrame = p .protected
161
177
}
162
178
// Check whether a complete frame has been received yet.
163
179
for len (framedMsg ) == 0 {
164
180
if len (p .protected ) == cap (p .protected ) {
165
- tmp := make ([]byte , len (p .protected ), cap (p .protected )+ altsRecordDefaultLength )
166
- copy (tmp , p .protected )
167
- p .protected = tmp
181
+ // We can parse the length header to know exactly how large
182
+ // the buffer needs to be to hold the entire frame.
183
+ length , didParse := ParseMessageLength (p .protected )
184
+ if ! didParse {
185
+ // The protected buffer is initialized with a capacity of
186
+ // larger than 4B. It should always hold the message length
187
+ // header.
188
+ panic (fmt .Sprintf ("protected buffer length shorter than expected: %d vs %d" , len (p .protected ), MsgLenFieldSize ))
189
+ }
190
+ tmp := mem .DefaultBufferPool ().Get (int (length ))
191
+ copy (* tmp , p .protected )
192
+ p .protected = (* tmp )[:len (p .protected )]
193
+ mem .DefaultBufferPool ().Put (p .protectedPointer )
194
+ p .protectedPointer = tmp
168
195
}
169
- n , err = p .Conn .Read (p .protected [len (p .protected ):min ( cap (p .protected ), len ( p . protected ) + altsRecordDefaultLength )])
196
+ n , err = p .Conn .Read (p .protected [len (p .protected ):cap (p .protected )])
170
197
if err != nil {
171
198
return 0 , err
172
199
}
@@ -185,17 +212,27 @@ func (p *conn) Read(b []byte) (n int, err error) {
185
212
}
186
213
ciphertext := msg [msgTypeFieldSize :]
187
214
188
- // Decrypt requires that if the dst and ciphertext alias, they
189
- // must alias exactly. Code here used to use msg[:0], but msg
190
- // starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
191
- // ciphertext, so they alias inexactly. Using ciphertext[:0]
192
- // arranges the appropriate aliasing without needing to copy
193
- // ciphertext or use a separate destination buffer. For more info
194
- // check: https://golang.org/pkg/crypto/cipher/#AEAD.
195
- p .buf , err = p .crypto .Decrypt (ciphertext [:0 ], ciphertext )
215
+ // Decrypt directly into the buffer, avoiding a copy to p.buf if
216
+ // possible.
217
+ if cap (b ) >= len (msg ) {
218
+ dec , err := p .crypto .Decrypt (b [:0 ], ciphertext )
219
+ if err != nil {
220
+ return 0 , err
221
+ }
222
+ return len (dec ), nil
223
+ }
224
+ // Resize the read buffer if needed to hold the entire decrypted
225
+ // frame.
226
+ if cap (* p .bufPointer ) < len (ciphertext ) {
227
+ mem .DefaultBufferPool ().Put (p .bufPointer )
228
+ p .bufPointer = mem .DefaultBufferPool ().Get (len (ciphertext ))
229
+ }
230
+ p .buf = * p .bufPointer
231
+ dec , err := p .crypto .Decrypt (p .buf [:0 ], ciphertext )
196
232
if err != nil {
197
233
return 0 , err
198
234
}
235
+ p .buf = p .buf [:len (dec )]
199
236
}
200
237
201
238
n = copy (b , p .buf )
0 commit comments