Skip to content

Commit ebdc8b8

Browse files
committed
chunked: add no-compression option for ZstdWriter
a new function NoCompression() is added to provide a way to create uncompressed zstd:chunked files. Signed-off-by: Giuseppe Scrivano <[email protected]>
1 parent 4aa5450 commit ebdc8b8

File tree

3 files changed

+139
-13
lines changed

3 files changed

+139
-13
lines changed

pkg/chunked/compressor/compressor.go

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,11 @@ type tarSplitData struct {
206206
packer storage.Packer
207207
}
208208

209-
func newTarSplitData(level int) (*tarSplitData, error) {
209+
func newTarSplitData(createZstdWriter minimal.CreateZstdWriterFunc) (*tarSplitData, error) {
210210
compressed := bytes.NewBuffer(nil)
211211
digester := digest.Canonical.Digester()
212212

213-
zstdWriter, err := minimal.ZstdWriterWithLevel(io.MultiWriter(compressed, digester.Hash()), level)
213+
zstdWriter, err := createZstdWriter(io.MultiWriter(compressed, digester.Hash()))
214214
if err != nil {
215215
return nil, err
216216
}
@@ -227,11 +227,11 @@ func newTarSplitData(level int) (*tarSplitData, error) {
227227
}, nil
228228
}
229229

230-
func writeZstdChunkedStream(destFile io.Writer, outMetadata map[string]string, reader io.Reader, level int) error {
230+
func writeZstdChunkedStream(destFile io.Writer, outMetadata map[string]string, reader io.Reader, createZstdWriter minimal.CreateZstdWriterFunc) error {
231231
// total written so far. Used to retrieve partial offsets in the file
232232
dest := ioutils.NewWriteCounter(destFile)
233233

234-
tarSplitData, err := newTarSplitData(level)
234+
tarSplitData, err := newTarSplitData(createZstdWriter)
235235
if err != nil {
236236
return err
237237
}
@@ -251,7 +251,7 @@ func writeZstdChunkedStream(destFile io.Writer, outMetadata map[string]string, r
251251

252252
buf := make([]byte, 4096)
253253

254-
zstdWriter, err := minimal.ZstdWriterWithLevel(dest, level)
254+
zstdWriter, err := createZstdWriter(dest)
255255
if err != nil {
256256
return err
257257
}
@@ -420,7 +420,7 @@ func writeZstdChunkedStream(destFile io.Writer, outMetadata map[string]string, r
420420
UncompressedSize: tarSplitData.uncompressedCounter.Count,
421421
}
422422

423-
return minimal.WriteZstdChunkedManifest(dest, outMetadata, uint64(dest.Count), &ts, metadata, level)
423+
return minimal.WriteZstdChunkedManifest(dest, outMetadata, uint64(dest.Count), &ts, metadata, createZstdWriter)
424424
}
425425

426426
type zstdChunkedWriter struct {
@@ -447,7 +447,7 @@ func (w zstdChunkedWriter) Write(p []byte) (int, error) {
447447
}
448448
}
449449

450-
// zstdChunkedWriterWithLevel writes a zstd compressed tarball where each file is
450+
// makeZstdChunkedWriter writes a zstd compressed tarball where each file is
451451
// compressed separately so it can be addressed separately. Idea based on CRFS:
452452
// https://github.com/google/crfs
453453
// The difference with CRFS is that the zstd compression is used instead of gzip.
@@ -462,12 +462,12 @@ func (w zstdChunkedWriter) Write(p []byte) (int, error) {
462462
// [SKIPPABLE FRAME 1]: [ZSTD SKIPPABLE FRAME, SIZE=MANIFEST LENGTH][MANIFEST]
463463
// [SKIPPABLE FRAME 2]: [ZSTD SKIPPABLE FRAME, SIZE=16][MANIFEST_OFFSET][MANIFEST_LENGTH][MANIFEST_LENGTH_UNCOMPRESSED][MANIFEST_TYPE][CHUNKED_ZSTD_MAGIC_NUMBER]
464464
// MANIFEST_OFFSET, MANIFEST_LENGTH, MANIFEST_LENGTH_UNCOMPRESSED and CHUNKED_ZSTD_MAGIC_NUMBER are 64-bit unsigned integers in little-endian format.
465-
func zstdChunkedWriterWithLevel(out io.Writer, metadata map[string]string, level int) (io.WriteCloser, error) {
465+
func makeZstdChunkedWriter(out io.Writer, metadata map[string]string, createZstdWriter minimal.CreateZstdWriterFunc) (io.WriteCloser, error) {
466466
ch := make(chan error, 1)
467467
r, w := io.Pipe()
468468

469469
go func() {
470-
ch <- writeZstdChunkedStream(out, metadata, r, level)
470+
ch <- writeZstdChunkedStream(out, metadata, r, createZstdWriter)
471471
_, _ = io.Copy(io.Discard, r) // Ordinarily writeZstdChunkedStream consumes all of r. If it fails, ensure the write end never blocks and eventually terminates.
472472
r.Close()
473473
close(ch)
@@ -486,5 +486,40 @@ func ZstdCompressor(r io.Writer, metadata map[string]string, level *int) (io.Wri
486486
level = &l
487487
}
488488

489-
return zstdChunkedWriterWithLevel(r, metadata, *level)
489+
createZstdWriter := func(dest io.Writer) (minimal.ZstdWriter, error) {
490+
return minimal.ZstdWriterWithLevel(dest, *level)
491+
}
492+
493+
return makeZstdChunkedWriter(r, metadata, createZstdWriter)
494+
}
495+
496+
// noCompression is a pass-through writer: bytes handed to it are forwarded to
// dest unchanged. It carries Write, Close, Flush and Reset methods so that it
// can be returned wherever a minimal.ZstdWriter is expected.
type noCompression struct {
497+
// dest receives every Write call unmodified; Reset replaces it.
dest io.Writer
498+
}
499+
500+
// Write hands p straight to the current destination writer and returns that
// writer's byte count and error unchanged.
func (n *noCompression) Write(p []byte) (int, error) {
501+
return n.dest.Write(p)
502+
}
503+
504+
// Close is a no-op that always returns nil. It does not close or otherwise
// touch the destination writer.
func (n *noCompression) Close() error {
505+
return nil
506+
}
507+
508+
// Flush is a no-op that always returns nil: Write passes data directly to the
// destination, so this type never holds anything back.
func (n *noCompression) Flush() error {
509+
return nil
510+
}
511+
512+
// Reset redirects all subsequent writes to dest. The destination is the only
// state this type keeps, so nothing else is cleared.
func (n *noCompression) Reset(dest io.Writer) {
513+
n.dest = dest
514+
}
515+
516+
// NoCompression returns a writer built by makeZstdChunkedWriter whose data is written
// to r as-is, without any compression. metadata is passed through to makeZstdChunkedWriter.
517+
//
518+
// Such an output does not follow the zstd:chunked spec and cannot be generally consumed; this function
519+
// only exists for internal purposes and should not be called from outside c/storage.
520+
func NoCompression(r io.Writer, metadata map[string]string) (io.WriteCloser, error) {
521+
// Wherever the chunked writer would create a zstd encoder, give it a pass-through writer instead.
createZstdWriter := func(dest io.Writer) (minimal.ZstdWriter, error) {
522+
return &noCompression{dest: dest}, nil
523+
}
524+
return makeZstdChunkedWriter(r, metadata, createZstdWriter)
490525
}

pkg/chunked/compressor/compressor_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ package compressor
33
import (
44
"bufio"
55
"bytes"
6+
"errors"
67
"io"
78
"testing"
9+
10+
"github.com/stretchr/testify/assert"
811
)
912

1013
func TestHole(t *testing.T) {
@@ -88,3 +91,82 @@ func TestTwoHoles(t *testing.T) {
8891
t.Error("didn't receive EOF")
8992
}
9093
}
94+
95+
// TestNoCompressionWrite checks that two consecutive writes reach the
// destination buffer verbatim, in order, with the full length reported.
func TestNoCompressionWrite(t *testing.T) {
96+
var buf bytes.Buffer
97+
nc := &noCompression{dest: &buf}
98+
99+
data := []byte("hello world")
100+
n, err := nc.Write(data)
101+
assert.NoError(t, err)
102+
assert.Equal(t, len(data), n)
103+
assert.Equal(t, data, buf.Bytes())
104+
105+
data2 := []byte(" again")
106+
n, err = nc.Write(data2)
107+
assert.NoError(t, err)
108+
assert.Equal(t, len(data2), n)
109+
// The buffer must now hold both payloads back to back.
assert.Equal(t, append(data, data2...), buf.Bytes())
110+
}
111+
112+
// TestNoCompressionClose checks that Close reports no error.
func TestNoCompressionClose(t *testing.T) {
113+
var buf bytes.Buffer
114+
nc := &noCompression{dest: &buf}
115+
err := nc.Close()
116+
assert.NoError(t, err)
117+
}
118+
119+
// TestNoCompressionFlush checks that Flush reports no error.
func TestNoCompressionFlush(t *testing.T) {
120+
var buf bytes.Buffer
121+
nc := &noCompression{dest: &buf}
122+
err := nc.Flush()
123+
assert.NoError(t, err)
124+
}
125+
126+
// TestNoCompressionReset checks that Reset switches the destination: data
// written before Reset stays in the first buffer only, data written after it
// lands in the second buffer only, and the writer remains usable after Close.
func TestNoCompressionReset(t *testing.T) {
127+
var buf1 bytes.Buffer
128+
nc := &noCompression{dest: &buf1}
129+
130+
data1 := []byte("initial data")
131+
_, err := nc.Write(data1)
132+
assert.NoError(t, err)
133+
assert.Equal(t, data1, buf1.Bytes())
134+
135+
err = nc.Close()
136+
assert.NoError(t, err)
137+
138+
var buf2 bytes.Buffer
139+
nc.Reset(&buf2)
140+
141+
data2 := []byte("new data")
142+
_, err = nc.Write(data2)
143+
assert.NoError(t, err)
144+
145+
assert.Equal(t, data1, buf1.Bytes(), "Buffer 1 should remain unchanged")
146+
assert.Equal(t, data2, buf2.Bytes(), "Buffer 2 should contain the new data")
147+
148+
err = nc.Close()
149+
assert.NoError(t, err)
150+
151+
// Reset must accept a nil destination. A later Write would panic, but Reset itself should succeed.
152+
nc.Reset(nil)
153+
assert.Nil(t, nc.dest)
154+
}
155+
156+
// errorWriter is a mock io.Writer whose Write always fails, used to verify
// that write errors are propagated.
157+
type errorWriter struct{}
158+
159+
// Write ignores p, reports zero bytes written and returns a fixed error.
func (ew *errorWriter) Write(p []byte) (n int, err error) {
160+
return 0, errors.New("mock write error")
161+
}
162+
163+
// TestNoCompressionWriteError checks that noCompression.Write returns the
// destination writer's error and byte count unchanged.
func TestNoCompressionWriteError(t *testing.T) {
164+
ew := &errorWriter{}
165+
nc := &noCompression{dest: ew}
166+
167+
data := []byte("hello world")
168+
n, err := nc.Write(data)
169+
assert.Error(t, err)
170+
assert.Equal(t, 0, n)
171+
assert.Equal(t, "mock write error", err.Error())
172+
}

pkg/chunked/internal/minimal/compression.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ import (
2020
"github.com/vbatts/tar-split/archive/tar"
2121
)
2222

23+
// ZstdWriter is the writer type produced by a CreateZstdWriterFunc: an io.WriteCloser
// that also offers Reset, so the same writer can be reused with a new destination.
24+
type ZstdWriter interface {
25+
io.WriteCloser
26+
// Reset points the writer at dest for subsequent writes.
Reset(dest io.Writer)
27+
}
28+
29+
// CreateZstdWriterFunc is a factory that returns a ZstdWriter writing to dest, or an
// error if the writer cannot be created. WriteZstdChunkedManifest takes one as a parameter.
30+
type CreateZstdWriterFunc func(dest io.Writer) (ZstdWriter, error)
31+
2332
// TOC is short for Table of Contents and is used by the zstd:chunked
2433
// file format to effectively add an overall index into the contents
2534
// of a tarball; it also includes file metadata.
@@ -179,7 +188,7 @@ type TarSplitData struct {
179188
UncompressedSize int64
180189
}
181190

182-
func WriteZstdChunkedManifest(dest io.Writer, outMetadata map[string]string, offset uint64, tarSplitData *TarSplitData, metadata []FileMetadata, level int) error {
191+
func WriteZstdChunkedManifest(dest io.Writer, outMetadata map[string]string, offset uint64, tarSplitData *TarSplitData, metadata []FileMetadata, createZstdWriter CreateZstdWriterFunc) error {
183192
// 8 is the size of the zstd skippable frame header + the frame size
184193
const zstdSkippableFrameHeader = 8
185194
manifestOffset := offset + zstdSkippableFrameHeader
@@ -198,7 +207,7 @@ func WriteZstdChunkedManifest(dest io.Writer, outMetadata map[string]string, off
198207
}
199208

200209
var compressedBuffer bytes.Buffer
201-
zstdWriter, err := ZstdWriterWithLevel(&compressedBuffer, level)
210+
zstdWriter, err := createZstdWriter(&compressedBuffer)
202211
if err != nil {
203212
return err
204213
}
@@ -244,7 +253,7 @@ func WriteZstdChunkedManifest(dest io.Writer, outMetadata map[string]string, off
244253
return appendZstdSkippableFrame(dest, manifestDataLE)
245254
}
246255

247-
func ZstdWriterWithLevel(dest io.Writer, level int) (*zstd.Encoder, error) {
256+
// ZstdWriterWithLevel returns a zstd writer that writes to dest. level is mapped
// to an encoder level with zstd.EncoderLevelFromZstd before the writer is created;
// any error from zstd.NewWriter is returned to the caller.
func ZstdWriterWithLevel(dest io.Writer, level int) (ZstdWriter, error) {
248257
el := zstd.EncoderLevelFromZstd(level)
249258
return zstd.NewWriter(dest, zstd.WithEncoderLevel(el))
250259
}

0 commit comments

Comments
 (0)