Skip to content

Commit 51a4066

Browse files
committed
[compression] NewZstdCompressingWriter
1 parent 441f506 commit 51a4066

File tree

3 files changed

+86
-6
lines changed

3 files changed

+86
-6
lines changed

server/util/compression/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ go_library(
77
visibility = ["//visibility:public"],
88
deps = [
99
"//server/metrics",
10+
"//server/util/bytebufferpool",
1011
"//server/util/log",
1112
"@com_github_klauspost_compress//zstd",
1213
"@com_github_prometheus_client_golang//prometheus",

server/util/compression/compression.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@ import (
77
"sync"
88

99
"github.com/buildbuddy-io/buildbuddy/server/metrics"
10+
"github.com/buildbuddy-io/buildbuddy/server/util/bytebufferpool"
1011
"github.com/buildbuddy-io/buildbuddy/server/util/log"
1112
"github.com/klauspost/compress/zstd"
1213
"github.com/prometheus/client_golang/prometheus"
1314
)
1415

16+
// compressChunkSize is the buffer size passed to bytebufferpool.FixedSize
// when creating compressBufPool. compressingWriter.Write caps each chunk it
// compresses at the capacity of the buffer it holds, so this is presumably
// also the largest input chunk compressed in one call (assumes the pool
// returns buffers of this capacity — confirm against bytebufferpool).
const compressChunkSize = 4 * 1024 * 1024 // 4 MiB
17+
1518
var (
1619
// zstdEncoder can be shared across goroutines to compress chunks of data
1720
// using EncodeAll. Streaming functions such as encoder.ReadFrom or io.Copy
@@ -22,6 +25,8 @@ var (
2225
// either for streaming decompression using ReadFrom or batch decompression
2326
// using DecodeAll. The returned decoders *must not* be closed.
2427
zstdDecoderPool = NewZstdDecoderPool()
28+
29+
compressBufPool = bytebufferpool.FixedSize(compressChunkSize)
2530
)
2631

2732
func mustGetZstdEncoder() *zstd.Encoder {
@@ -177,6 +182,50 @@ func NewZstdCompressingReader(reader io.ReadCloser, readBuf []byte, compressBuf
177182
}, nil
178183
}
179184

185+
type compressingWriter struct {
186+
w io.Writer
187+
compressBuf []byte
188+
poolCompressBuf []byte
189+
}
190+
191+
func (c *compressingWriter) Write(p []byte) (int, error) {
192+
var totalWritten int
193+
for len(p) > 0 {
194+
chunkSize := min(len(p), cap(c.compressBuf))
195+
chunk := p[:chunkSize]
196+
c.compressBuf = CompressZstd(c.compressBuf[:0], chunk)
197+
198+
written, err := c.w.Write(c.compressBuf)
199+
if err != nil {
200+
return totalWritten, err
201+
}
202+
if written < len(c.compressBuf) {
203+
return totalWritten, io.ErrShortWrite
204+
}
205+
206+
totalWritten += chunkSize
207+
p = p[chunkSize:]
208+
}
209+
return totalWritten, nil
210+
}
211+
212+
func (c *compressingWriter) Close() error {
213+
compressBufPool.Put(c.poolCompressBuf)
214+
return nil
215+
}
216+
217+
// NewZstdCompressingWriter returns a writer that compresses each chunk of the
218+
// input using zstd and writes the compressed data to the underlying writer.
219+
// The writer uses a fixed-size 4MB buffer for compression.
220+
func NewZstdCompressingWriter(w io.Writer) io.WriteCloser {
221+
compressBuf := compressBufPool.Get()
222+
return &compressingWriter{
223+
w: w,
224+
compressBuf: compressBuf,
225+
poolCompressBuf: compressBuf,
226+
}
227+
}
228+
180229
// NewZstdDecompressingReader reads zstd-compressed data from the input
181230
// reader and makes the decompressed data available on the output reader. The
182231
// output reader is also an io.WriterTo, which can often prevent allocations

server/util/compression/compression_test.go

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"fmt"
66
"io"
7-
"math"
87
"strconv"
98
"testing"
109

@@ -50,18 +49,39 @@ func TestLossless(t *testing.T) {
5049
compress: compressWithNewZstdCompressingReader,
5150
decompress: decompressWithNewZstdDecompressingReader,
5251
},
52+
{
53+
name: "NewZstdCompressingWriter -> DecompressZstd",
54+
compress: compressWithNewZstdCompressingWriter,
55+
decompress: decompressWithDecompressZstd,
56+
},
57+
{
58+
name: "NewZstdCompressingWriter -> NewZstdDecompressor",
59+
compress: compressWithNewZstdCompressingWriter,
60+
decompress: decompressWithNewZstdDecompressor,
61+
},
62+
{
63+
name: "NewZstdCompressingWriter -> NewZstdDecompressingReader",
64+
compress: compressWithNewZstdCompressingWriter,
65+
decompress: decompressWithNewZstdDecompressingReader,
66+
},
5367
} {
54-
for i := 1; i <= 5; i++ {
55-
srclen := int(math.Pow10(i))
68+
for _, srclen := range []int{9, 99, 999, 1_999_999, 5_999_999} {
5669
name := tc.name + "_" + strconv.Itoa(srclen) + "_bytes"
5770
t.Run(name, func(t *testing.T) {
5871
_, r := testdigest.NewReader(t, int64(srclen))
5972
src, err := io.ReadAll(r)
6073
require.NoError(t, err)
74+
require.Len(t, src, srclen)
6175
require.Equal(t, srclen, len(src))
6276
compressed := tc.compress(t, src)
6377

6478
decompressed := tc.decompress(t, len(src), compressed)
79+
require.Len(t, decompressed, srclen)
80+
if srclen > 1000 {
81+
require.Empty(t, cmp.Diff(src[:1000], decompressed[:1000]))
82+
require.Empty(t, cmp.Diff(src[len(src)-1000:], decompressed[len(decompressed)-1000:]))
83+
return
84+
}
6585
require.Empty(t, cmp.Diff(src, decompressed))
6686
})
6787
}
@@ -88,6 +108,17 @@ func compressWithNewZstdCompressingReader(t *testing.T, src []byte) []byte {
88108
return compressed
89109
}
90110

111+
func compressWithNewZstdCompressingWriter(t *testing.T, src []byte) []byte {
112+
compressed := &bytes.Buffer{}
113+
cw := compression.NewZstdCompressingWriter(compressed)
114+
written, err := cw.Write(src)
115+
require.NoError(t, err)
116+
require.Equal(t, len(src), written)
117+
err = cw.Close()
118+
require.NoError(t, err)
119+
return compressed.Bytes()
120+
}
121+
91122
func decompressWithDecompressZstd(t *testing.T, srclen int, compressed []byte) []byte {
92123
decompressed := make([]byte, srclen)
93124
decompressed, err := compression.DecompressZstd(decompressed, compressed)
@@ -111,10 +142,9 @@ func decompressWithNewZstdDecompressingReader(t *testing.T, srclen int, compress
111142
rc := io.NopCloser(bytes.NewReader(compressed))
112143
d, err := compression.NewZstdDecompressingReader(rc)
113144
require.NoError(t, err)
114-
buf := make([]byte, srclen)
115-
n, err := d.Read(buf)
145+
buf, err := io.ReadAll(d)
116146
require.NoError(t, err)
117-
require.Equal(t, srclen, n)
147+
require.Len(t, buf, srclen)
118148
err = d.Close()
119149
require.NoError(t, err)
120150
err = rc.Close()

0 commit comments

Comments
 (0)