@@ -20,7 +20,8 @@ import (
20
20
"bytes"
21
21
"errors"
22
22
"io"
23
- "io/ioutil"
23
+ "runtime"
24
+ "sync"
24
25
25
26
"github.com/klauspost/compress/zstd"
26
27
"google.golang.org/grpc/encoding"
@@ -34,9 +35,22 @@ var encoderOptions = []zstd.EOption{
34
35
zstd .WithWindowSize (512 * 1024 ),
35
36
}
36
37
38
+ var decoderOptions = []zstd.DOption {
39
+ // If the decoder concurrency level is not 1, we would need to call
40
+ // Close() to avoid leaking resources when the object is released
41
+ // from compressor.decoderPool.
42
+ zstd .WithDecoderConcurrency (1 ),
43
+ }
44
+
45
+ // We will set a finalizer on these objects, so when the go-grpc code is
46
+ // finished with them, they will be added back to compressor.decoderPool.
47
+ type decoderWrapper struct {
48
+ * zstd.Decoder
49
+ }
50
+
37
51
type compressor struct {
38
- encoder * zstd.Encoder
39
- decoder * zstd.Decoder
52
+ encoder * zstd.Encoder
53
+ decoderPool sync. Pool // To hold *zstd.Decoder's.
40
54
}
41
55
42
56
func PretendInit (clobbering bool ) {
@@ -45,10 +59,8 @@ func PretendInit(clobbering bool) {
45
59
}
46
60
47
61
enc , _ := zstd .NewWriter (nil , encoderOptions ... )
48
- dec , _ := zstd .NewReader (nil )
49
62
c := & compressor {
50
63
encoder : enc ,
51
- decoder : dec ,
52
64
}
53
65
encoding .RegisterCompressor (c )
54
66
}
@@ -97,17 +109,36 @@ func (z *zstdWriteCloser) Close() error {
97
109
}
98
110
99
111
func (c * compressor ) Decompress (r io.Reader ) (io.Reader , error ) {
100
- compressed , err := ioutil .ReadAll (r )
101
- if err != nil {
102
- return nil , err
112
+ var err error
113
+ var found bool
114
+ var decoder * zstd.Decoder
115
+
116
+ // Note: avoid the use of zstd.Decoder.DecodeAll here, since
117
+ // malicious payloads could DoS us with a decompression bomb.
118
+
119
+ decoder , found = c .decoderPool .Get ().(* zstd.Decoder )
120
+ if ! found {
121
+ decoder , err = zstd .NewReader (r , decoderOptions ... )
122
+ if err != nil {
123
+ return nil , err
124
+ }
125
+ } else {
126
+ err = decoder .Reset (r )
127
+ if err != nil {
128
+ c .decoderPool .Put (decoder )
129
+ return nil , err
130
+ }
103
131
}
104
132
105
- uncompressed , err := c .decoder .DecodeAll (compressed , nil )
106
- if err != nil {
107
- return nil , err
108
- }
133
+ wrapper := & decoderWrapper {Decoder : decoder }
134
+ runtime .SetFinalizer (wrapper , func (dw * decoderWrapper ) {
135
+ err := dw .Reset (nil )
136
+ if err == nil {
137
+ c .decoderPool .Put (dw .Decoder )
138
+ }
139
+ })
109
140
110
- return bytes . NewReader ( uncompressed ) , nil
141
+ return wrapper , nil
111
142
}
112
143
113
144
func (c * compressor ) Name () string {
0 commit comments