Skip to content

azblob: Random write in DownloadFile #22459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/storage/azblob/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
### Bugs Fixed

* Re-enabled `SharedKeyCredential` authentication mode for non-TLS-protected endpoints.
* Use random write in `DownloadFile` method. Fixes [#22426](https://github.com/Azure/azure-sdk-for-go/issues/22426).

### Other Changes

Expand Down
177 changes: 2 additions & 175 deletions sdk/storage/azblob/blob/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package blob
import (
"context"
"io"
"math"
"os"
"sync"
"time"
Expand Down Expand Up @@ -359,7 +358,7 @@ func (b *Client) downloadBuffer(ctx context.Context, writer io.WriterAt, o downl
OperationName: "downloadBlobToWriterAt",
TransferSize: count,
ChunkSize: o.BlockSize,
NumChunks: uint16(((count - 1) / o.BlockSize) + 1),
NumChunks: uint64(((count - 1) / o.BlockSize) + 1),
Concurrency: o.Concurrency,
Operation: func(ctx context.Context, chunkStart int64, count int64) error {
downloadBlobOptions := o.getDownloadBlobOptions(HTTPRange{
Expand Down Expand Up @@ -398,165 +397,6 @@ func (b *Client) downloadBuffer(ctx context.Context, writer io.WriterAt, o downl
return count, nil
}

// downloadFile downloads an Azure blob to an io.Writer. Blocks are downloaded
// from the service in parallel, but serialized and written to the writer in
// order. It returns the number of bytes written to the writer.
func (b *Client) downloadFile(ctx context.Context, writer io.Writer, o downloadOptions) (int64, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	if o.BlockSize == 0 {
		o.BlockSize = DefaultDownloadBlockSize
	}

	if o.Concurrency == 0 {
		o.Concurrency = DefaultConcurrency
	}

	count := o.Range.Count
	if count == CountToEnd { // calculate size if not specified
		gr, err := b.GetProperties(ctx, o.getBlobPropertiesOptions())
		if err != nil {
			return 0, err
		}
		count = *gr.ContentLength - o.Range.Offset
	}

	if count <= 0 {
		// The blob is empty; there is nothing to download.
		return 0, nil
	}

	progress := int64(0)
	progressLock := &sync.Mutex{}

	// getBodyForRange issues a ranged download and wraps the body in a retry
	// reader and, when requested, a progress reporter.
	getBodyForRange := func(ctx context.Context, chunkStart, size int64) (io.ReadCloser, error) {
		downloadBlobOptions := o.getDownloadBlobOptions(HTTPRange{
			Offset: chunkStart + o.Range.Offset,
			Count:  size,
		}, nil)
		dr, err := b.DownloadStream(ctx, downloadBlobOptions)
		if err != nil {
			return nil, err
		}

		var body io.ReadCloser = dr.NewRetryReader(ctx, &o.RetryReaderOptionsPerBlock)
		if o.Progress != nil {
			rangeProgress := int64(0)
			body = streaming.NewResponseProgress(
				body,
				func(bytesTransferred int64) {
					diff := bytesTransferred - rangeProgress
					rangeProgress = bytesTransferred
					progressLock.Lock()
					progress += diff
					o.Progress(progress)
					progressLock.Unlock()
				})
		}

		return body, nil
	}

	// If the blob fits in a single block, download it directly.
	if count <= o.BlockSize {
		body, err := getBodyForRange(ctx, 0, count)
		if err != nil {
			return 0, err
		}
		defer body.Close()

		return io.Copy(writer, body)
	}

	buffers := shared.NewMMBPool(int(o.Concurrency), o.BlockSize)
	defer buffers.Free()

	// uint64 avoids overflow for transfers with more than 65535 chunks.
	numChunks := uint64((count-1)/o.BlockSize + 1)

	// Pre-grow the pool to min(numChunks, Concurrency) buffers.
	maxBuffers := int(math.Min(float64(numChunks), float64(o.Concurrency)))
	for i := 0; i < maxBuffers; i++ {
		if _, err := buffers.Grow(); err != nil {
			return 0, err
		}
	}

	acquireBuffer := func() ([]byte, error) {
		return <-buffers.Acquire(), nil
	}

	blocks := make([]chan []byte, numChunks)
	for i := range blocks {
		blocks[i] = make(chan []byte)
	}

	/*
	 * We have created as many channels as the number of chunks we have.
	 * Each downloaded block is sent to the channel matching its sequence
	 * number, i.e. the 0th block goes to the 0th channel, the 1st block to
	 * the 1st channel, and so on. The blocks are then read and written to
	 * the writer serially by the goroutine below. The blocks are still
	 * downloaded from the network in parallel; only the writes are
	 * serialized here.
	 */
	writerError := make(chan error, 1) // buffered so the goroutine never blocks on send
	writeSize := int64(0)
	go func(ch chan error) {
		for _, block := range blocks {
			select {
			case <-ctx.Done():
				return
			case chunk := <-block:
				n, err := writer.Write(chunk)
				writeSize += int64(n)
				buffers.Release(chunk[:cap(chunk)])
				if err != nil {
					// Cancel so producers blocked sending on blocks channels
					// unblock instead of deadlocking DoBatchTransfer.
					cancel()
					ch <- err
					return
				}
			}
		}
		ch <- nil
	}(writerError)

	// Prepare and do parallel download.
	// NOTE(review): OperationName mirrors the buffer-based path; confirm the
	// telemetry name is intentional for the writer path as well.
	err := shared.DoBatchTransfer(ctx, &shared.BatchTransferOptions{
		OperationName: "downloadBlobToWriterAt",
		TransferSize:  count,
		ChunkSize:     o.BlockSize,
		NumChunks:     numChunks,
		Concurrency:   o.Concurrency,
		Operation: func(ctx context.Context, chunkStart int64, count int64) error {
			buff, err := acquireBuffer()
			if err != nil {
				return err
			}

			body, err := getBodyForRange(ctx, chunkStart, count)
			if err != nil {
				buffers.Release(buff)
				return err // was "return nil", which silently dropped the error
			}

			_, err = io.ReadFull(body, buff[:count])
			body.Close()
			if err != nil {
				buffers.Release(buff) // don't leak the buffer on a short read
				return err
			}

			blockIndex := chunkStart / o.BlockSize
			select {
			case blocks[blockIndex] <- buff[:count]:
				return nil
			case <-ctx.Done():
				// Writer goroutine failed (or caller cancelled); release the
				// buffer instead of blocking forever on the send.
				buffers.Release(buff)
				return ctx.Err()
			}
		},
	})

	if err != nil {
		// A write failure cancels ctx, which surfaces from DoBatchTransfer as
		// a cancellation error; prefer the writer's own error when available.
		select {
		case werr := <-writerError:
			if werr != nil {
				return 0, werr
			}
		default:
		}
		return 0, err
	}
	// error from writer goroutine.
	if err = <-writerError; err != nil {
		return 0, err
	}
	return writeSize, nil
}

// DownloadStream reads a range of bytes from a blob. The response also includes the blob's properties and metadata.
// For more information, see https://docs.microsoft.com/rest/api/storageservices/get-blob.
func (b *Client) DownloadStream(ctx context.Context, o *DownloadStreamOptions) (DownloadStreamResponse, error) {
Expand Down Expand Up @@ -596,11 +436,6 @@ func (b *Client) DownloadFile(ctx context.Context, file *os.File, o *DownloadFil
}
do := (*downloadOptions)(o)

filePointer, err := file.Seek(0, io.SeekCurrent)
if err != nil {
return 0, err
}

// 1. Calculate the size of the destination file
var size int64

Expand Down Expand Up @@ -629,15 +464,7 @@ func (b *Client) DownloadFile(ctx context.Context, file *os.File, o *DownloadFil
}

if size > 0 {
writeSize, err := b.downloadFile(ctx, file, *do)
if err != nil {
return 0, err
}
_, err = file.Seek(filePointer, io.SeekStart)
if err != nil {
return 0, err
}
return writeSize, nil
return b.downloadBuffer(ctx, file, *do)
} else { // if the blob's size is 0, there is no need in downloading it
return 0, nil
}
Expand Down
60 changes: 60 additions & 0 deletions sdk/storage/azblob/blob/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ import (
"crypto/rand"
"errors"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"io"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -3745,3 +3748,60 @@ func (s *BlobRecordedTestsSuite) TestBlobClientCustomAudience() {
_, err = blobClientAudience.GetProperties(context.Background(), nil)
_require.NoError(err)
}

type fakeDownloadBlob struct {
contentSize int64
numChunks uint64
}

// nolint
func (f *fakeDownloadBlob) Do(req *http.Request) (*http.Response, error) {
// check how many times range based get blob is called
if _, ok := req.Header["x-ms-range"]; ok {
atomic.AddUint64(&f.numChunks, 1)
}
return &http.Response{
Request: req,
Status: "Created",
StatusCode: http.StatusOK,
Header: http.Header{"Content-Length": []string{fmt.Sprintf("%v", f.contentSize)}},
Body: http.NoBody,
}, nil
}

// TestDownloadSmallBlockSize verifies that DownloadFile and DownloadBuffer
// issue the expected number of range requests when the block size is much
// smaller than the blob, using a fake transport that counts ranged calls.
func TestDownloadSmallBlockSize(t *testing.T) {
	_require := require.New(t)

	fileSize := int64(100 * 1024 * 1024)
	blockSize := int64(1024)
	numChunks := uint64(((fileSize - 1) / blockSize) + 1)
	fbb := &fakeDownloadBlob{
		contentSize: fileSize,
	}
	blobClient, err := blockblob.NewClientWithNoCredential("https://fake/blob/path", &blockblob.ClientOptions{
		ClientOptions: policy.ClientOptions{
			Transport: fbb,
		},
	})
	_require.NoError(err)
	_require.NotNil(blobClient)

	// download to a temp file and verify the request count
	tmp, err := os.CreateTemp("", "")
	_require.NoError(err)
	defer os.Remove(tmp.Name()) // CreateTemp files are not cleaned up automatically
	defer tmp.Close()

	_, err = blobClient.DownloadFile(context.Background(), tmp, &blob.DownloadFileOptions{BlockSize: blockSize})
	_require.NoError(err)

	// require.Equal takes (expected, actual)
	_require.Equal(numChunks, atomic.LoadUint64(&fbb.numChunks))

	// reset the counter and repeat with a buffer destination
	atomic.StoreUint64(&fbb.numChunks, 0)

	buff := make([]byte, fileSize)
	_, err = blobClient.DownloadBuffer(context.Background(), buff, &blob.DownloadBufferOptions{BlockSize: blockSize})
	_require.NoError(err)

	_require.Equal(numChunks, atomic.LoadUint64(&fbb.numChunks))
}
2 changes: 1 addition & 1 deletion sdk/storage/azblob/blockblob/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ func (bb *Client) uploadFromReader(ctx context.Context, reader io.ReaderAt, actu
OperationName: "uploadFromReader",
TransferSize: actualSize,
ChunkSize: o.BlockSize,
NumChunks: uint16(((actualSize - 1) / o.BlockSize) + 1),
NumChunks: uint64(((actualSize - 1) / o.BlockSize) + 1),
Concurrency: o.Concurrency,
Operation: func(ctx context.Context, offset int64, chunkSize int64) error {
// This function is called once per block.
Expand Down
70 changes: 70 additions & 0 deletions sdk/storage/azblob/blockblob/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5945,3 +5945,73 @@ func (s *BlockBlobRecordedTestsSuite) TestBlockBlobClientCustomAudience() {
_, err = bbClientAudience.GetProperties(context.Background(), nil)
_require.NoError(err)
}

// TestBlockBlobClientUploadDownloadFile uploads a large local file to a block
// blob, downloads it to a temp file, and verifies size and MD5 round-trip.
func (s *BlockBlobUnrecordedTestsSuite) TestBlockBlobClientUploadDownloadFile() {
	_require := require.New(s.T())
	testName := s.T().Name()

	svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil)
	_require.NoError(err)

	containerClient := testcommon.CreateNewContainer(context.Background(), _require, testcommon.GenerateContainerName(testName), svcClient)
	defer testcommon.DeleteContainer(context.Background(), _require, containerClient)

	bbClient := containerClient.NewBlockBlobClient(testcommon.GenerateBlobName(testName))

	// create a local file with random content
	var fileSize int64 = 401 * 1024 * 1024
	content := make([]byte, fileSize)
	_, err = rand.Read(content)
	_require.NoError(err)
	err = os.WriteFile("testFile", content, 0644)
	_require.NoError(err)

	defer func() {
		err = os.Remove("testFile")
		_require.NoError(err)
	}()

	fh, err := os.Open("testFile")
	_require.NoError(err)

	defer func(fh *os.File) {
		err := fh.Close()
		_require.NoError(err)
	}(fh)

	srcHash := md5.New()
	_, err = io.Copy(srcHash, fh)
	_require.NoError(err)
	contentMD5 := srcHash.Sum(nil)

	// NOTE: hashing above leaves fh's offset at EOF; UploadFile reads the
	// file via its size/ReadAt, so the current offset does not affect it —
	// confirm if UploadFile's implementation changes.
	_, err = bbClient.UploadFile(context.Background(), fh, &blockblob.UploadFileOptions{
		Concurrency: 5,
		BlockSize:   4 * 1024 * 1024,
	})
	_require.NoError(err)

	// download to a temp file and verify contents
	tmp, err := os.CreateTemp("", "")
	_require.NoError(err)
	defer os.Remove(tmp.Name()) // avoid leaving the temp file behind
	defer tmp.Close()

	n, err := bbClient.DownloadFile(context.Background(), tmp, &blob.DownloadFileOptions{BlockSize: 4 * 1024 * 1024})
	_require.NoError(err)
	_require.Equal(fileSize, n)

	stat, err := tmp.Stat()
	_require.NoError(err)
	_require.Equal(fileSize, stat.Size())

	// rewind explicitly before hashing rather than relying on DownloadFile
	// leaving the offset at the start of the file
	_, err = tmp.Seek(0, io.SeekStart)
	_require.NoError(err)

	destHash := md5.New()
	_, err = io.Copy(destHash, tmp)
	_require.NoError(err)
	downloadedContentMD5 := destHash.Sum(nil)

	_require.EqualValues(contentMD5, downloadedContentMD5)

	gResp, err := bbClient.GetProperties(context.Background(), nil)
	_require.NoError(err)
	_require.NotNil(gResp.ContentLength)
	_require.Equal(fileSize, *gResp.ContentLength)
}
5 changes: 2 additions & 3 deletions sdk/storage/azblob/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,6 @@ func performUploadAndDownloadFileTest(t *testing.T, _require *require.Assertions
destBuffer = make([]byte, downloadCount)
}

_require.NoError(err)
n, err := destFile.Read(destBuffer)
_require.NoError(err)

Expand Down Expand Up @@ -708,9 +707,9 @@ func (s *AZBlobUnrecordedTestsSuite) TestBasicDoBatchTransfer() {
totalSizeCount := int64(0)
runCount := int64(0)

numChunks := uint16(0)
numChunks := uint64(0)
if test.chunkSize != 0 {
numChunks = uint16(((test.transferSize - 1) / test.chunkSize) + 1)
numChunks = uint64(((test.transferSize - 1) / test.chunkSize) + 1)
}

err := shared.DoBatchTransfer(ctx, &shared.BatchTransferOptions{
Expand Down
Loading