diff --git a/sdk/storage/azfile/assets.json b/sdk/storage/azfile/assets.json index 079ab96fc36d..47d08f9c3faa 100644 --- a/sdk/storage/azfile/assets.json +++ b/sdk/storage/azfile/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "go", "TagPrefix": "go/storage/azfile", - "Tag": "go/storage/azfile_f16684c6c5" + "Tag": "go/storage/azfile_f1e8c5b99b" } diff --git a/sdk/storage/azfile/directory/client.go b/sdk/storage/azfile/directory/client.go index f4d5c878c8dd..5c451013c1f7 100644 --- a/sdk/storage/azfile/directory/client.go +++ b/sdk/storage/azfile/directory/client.go @@ -12,16 +12,13 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/file" - "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/fileerror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/base" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/generated" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/shared" - "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/sas" "net/http" "net/url" "strings" - "time" ) // ClientOptions contains the optional parameters when creating a Client. @@ -209,34 +206,3 @@ func (d *Client) NewListFilesAndDirectoriesPager(options *ListFilesAndDirectorie }, }) } - -// GetSASURL is a convenience method for generating a SAS token for the currently pointed at directory. -// It can only be used if the credential supplied during creation was a SharedKeyCredential. -func (d *Client) GetSASURL(permissions sas.FilePermissions, expiry time.Time, o *GetSASURLOptions) (string, error) { - if d.sharedKey() == nil { - return "", fileerror.MissingSharedKeyCredential - } - st := o.format() - - urlParts, err := sas.ParseURL(d.URL()) - if err != nil { - return "", err - } - - qps, err := sas.SignatureValues{ - Version: sas.Version, - Protocol: sas.ProtocolHTTPS, - ShareName: urlParts.ShareName, - DirectoryOrFilePath: urlParts.DirectoryOrFilePath, - Permissions: permissions.String(), - StartTime: st, - ExpiryTime: expiry.UTC(), - }.SignWithSharedKey(d.sharedKey()) - if err != nil { - return "", err - } - - endpoint := d.URL() + "?" + qps.Encode() - - return endpoint, nil -} diff --git a/sdk/storage/azfile/directory/client_test.go b/sdk/storage/azfile/directory/client_test.go index 2fed9c43c119..96fe27da6399 100644 --- a/sdk/storage/azfile/directory/client_test.go +++ b/sdk/storage/azfile/directory/client_test.go @@ -15,8 +15,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/file" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/fileerror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/testcommon" - "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/sas" - "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/service" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/share" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "testing" @@ -219,63 +218,6 @@ func (d *DirectoryRecordedTestsSuite) TestDirectoryCreateNegativeMultiLevel() { testcommon.ValidateFileErrorCode(_require, err, fileerror.ParentNotFound) } -func (d *DirectoryUnrecordedTestsSuite) TestDirectoryClientUsingSAS() { - _require := require.New(d.T()) - testName := d.T().Name() - - svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) - _require.NoError(err) - - shareName := testcommon.GenerateShareName(testName) - shareClient := testcommon.CreateNewShare(context.Background(), _require, shareName, svcClient) - defer testcommon.DeleteShare(context.Background(), _require, shareClient) - - dirClient := testcommon.CreateNewDirectory(context.Background(), _require, testcommon.GenerateDirectoryName(testName), shareClient) - - permissions := sas.FilePermissions{ - Read: true, - Write: true, - Delete: true, - Create: true, - } - expiry := time.Now().Add(time.Hour) - - dirSASURL, err := dirClient.GetSASURL(permissions, expiry, nil) - _require.NoError(err) - - dirSASClient, err := directory.NewClientWithNoCredential(dirSASURL, nil) - _require.NoError(err) - - _, err = dirSASClient.GetProperties(context.Background(), nil) - _require.Error(err) - testcommon.ValidateFileErrorCode(_require, err, fileerror.AuthenticationFailed) - - subDirSASClient := dirSASClient.NewSubdirectoryClient("subdir") - _, err = subDirSASClient.Create(context.Background(), nil) - _require.Error(err) - testcommon.ValidateFileErrorCode(_require, err, fileerror.AuthenticationFailed) - - // TODO: directory SAS client unable to do create and get properties on directories. - // Also unable to do create or get properties on files. Validate this behaviour. - fileSASClient := dirSASClient.NewFileClient(testcommon.GenerateFileName(testName)) - _, err = fileSASClient.Create(context.Background(), 1024, nil) - _require.Error(err) - testcommon.ValidateFileErrorCode(_require, err, fileerror.AuthenticationFailed) - - _, err = fileSASClient.GetProperties(context.Background(), nil) - _require.Error(err) - testcommon.ValidateFileErrorCode(_require, err, fileerror.AuthenticationFailed) - - // create file using shared key client - _, err = dirClient.NewFileClient(testcommon.GenerateFileName(testName)).Create(context.Background(), 1024, nil) - _require.NoError(err) - - // get properties using SAS client - _, err = fileSASClient.GetProperties(context.Background(), nil) - _require.Error(err) - testcommon.ValidateFileErrorCode(_require, err, fileerror.AuthenticationFailed) -} - func (d *DirectoryRecordedTestsSuite) TestDirCreateDeleteDefault() { _require := require.New(d.T()) testName := d.T().Name() @@ -311,7 +253,43 @@ func (d *DirectoryRecordedTestsSuite) TestDirCreateDeleteDefault() { _require.Equal(gResp.FileChangeTime.IsZero(), false) } -func (d *DirectoryUnrecordedTestsSuite) TestDirSetPropertiesNonDefault() { +func (d *DirectoryRecordedTestsSuite) TestDirSetPropertiesDefault() { + _require := require.New(d.T()) + testName := d.T().Name() + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareName := testcommon.GenerateShareName(testName) + shareClient := testcommon.CreateNewShare(context.Background(), _require, shareName, svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + dirName := testcommon.GenerateDirectoryName(testName) + dirClient := testcommon.GetDirectoryClient(dirName, shareClient) + + cResp, err := dirClient.Create(context.Background(), nil) + _require.NoError(err) + _require.NotNil(cResp.FilePermissionKey) + + // Set the custom permissions + sResp, err := dirClient.SetProperties(context.Background(), nil) + _require.NoError(err) + _require.NotNil(sResp.FileCreationTime) + _require.NotNil(sResp.FileLastWriteTime) + _require.NotNil(sResp.FilePermissionKey) + _require.Equal(*sResp.FilePermissionKey, *cResp.FilePermissionKey) + + gResp, err := dirClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.NotNil(gResp.FileCreationTime) + _require.NotNil(gResp.FileLastWriteTime) + _require.NotNil(gResp.FilePermissionKey) + _require.Equal(*gResp.FilePermissionKey, *sResp.FilePermissionKey) + _require.Equal(*gResp.FileCreationTime, *sResp.FileCreationTime) + _require.Equal(*gResp.FileLastWriteTime, *sResp.FileLastWriteTime) + _require.Equal(*gResp.FileAttributes, *sResp.FileAttributes) +} + +func (d *DirectoryRecordedTestsSuite) TestDirSetPropertiesNonDefault() { _require := require.New(d.T()) testName := d.T().Name() svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) @@ -328,8 +306,10 @@ func (d *DirectoryUnrecordedTestsSuite) TestDirSetPropertiesNonDefault() { _require.NoError(err) _require.NotNil(cResp.FilePermissionKey) - creationTime := time.Now().Add(5 * time.Minute).Round(time.Microsecond) - lastWriteTime := time.Now().Add(10 * time.Minute).Round(time.Millisecond) + currTime, err := time.Parse(time.UnixDate, "Fri Mar 31 21:00:00 GMT 2023") + _require.NoError(err) + creationTime := currTime.Add(5 * time.Minute).Round(time.Microsecond) + lastWriteTime := currTime.Add(10 * time.Minute).Round(time.Millisecond) // Set the custom permissions sResp, err := dirClient.SetProperties(context.Background(), &directory.SetPropertiesOptions{ @@ -739,54 +719,401 @@ func (d *DirectoryRecordedTestsSuite) TestDirGetSetMetadataMergeAndReplace() { _require.EqualValues(gResp.Metadata, md2) } -func (d *DirectoryRecordedTestsSuite) TestSASDirectoryClientNoKey() { +func (d *DirectoryRecordedTestsSuite) TestDirListFilesAndDirsDefault() { _require := require.New(d.T()) - accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDefault) - _require.Greater(len(accountName), 0) + testName := d.T().Name() + + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + dirName := testcommon.GenerateDirectoryName(testName) + fileName := testcommon.GenerateFileName(testName) + + for i := 0; i < 10; i++ { + _ = testcommon.CreateNewDirectory(context.Background(), _require, dirName+fmt.Sprintf("%v", i), shareClient) + } + for i := 0; i < 5; i++ { + _ = testcommon.CreateNewFileFromShare(context.Background(), _require, fileName+fmt.Sprintf("%v", i), 2048, shareClient) + } + + dirCtr, fileCtr := 0, 0 + pager := shareClient.NewRootDirectoryClient().NewListFilesAndDirectoriesPager(nil) + for pager.More() { + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + dirCtr += len(resp.Segment.Directories) + fileCtr += len(resp.Segment.Files) + for _, dir := range resp.Segment.Directories { + _require.NotNil(dir.Name) + _require.NotNil(dir.ID) + _require.Nil(dir.Attributes) + _require.Nil(dir.PermissionKey) + _require.Nil(dir.Properties.ETag) + _require.Nil(dir.Properties.ChangeTime) + _require.Nil(dir.Properties.CreationTime) + _require.Nil(dir.Properties.ContentLength) + } + for _, f := range resp.Segment.Files { + _require.NotNil(f.Name) + _require.NotNil(f.ID) + _require.Nil(f.Attributes) + _require.Nil(f.PermissionKey) + _require.Nil(f.Properties.ETag) + _require.Nil(f.Properties.ChangeTime) + _require.Nil(f.Properties.CreationTime) + _require.NotNil(f.Properties.ContentLength) + _require.Equal(*f.Properties.ContentLength, int64(2048)) + } + } + _require.Equal(dirCtr, 10) + _require.Equal(fileCtr, 5) +} + +func (d *DirectoryRecordedTestsSuite) TestDirListFilesAndDirsInclude() { + _require := require.New(d.T()) testName := d.T().Name() - shareName := testcommon.GenerateShareName(testName) + + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + dirName := testcommon.GenerateDirectoryName(testName) - dirClient, err := directory.NewClientWithNoCredential(fmt.Sprintf("https://%s.file.core.windows.net/%v/%v", accountName, shareName, dirName), nil) + fileName := testcommon.GenerateFileName(testName) + + for i := 0; i < 10; i++ { + _ = testcommon.CreateNewDirectory(context.Background(), _require, dirName+fmt.Sprintf("%v", i), shareClient) + } + + for i := 0; i < 5; i++ { + _ = testcommon.CreateNewFileFromShare(context.Background(), _require, fileName+fmt.Sprintf("%v", i), 2048, shareClient) + } + + dirCtr, fileCtr := 0, 0 + pager := shareClient.NewRootDirectoryClient().NewListFilesAndDirectoriesPager(&directory.ListFilesAndDirectoriesOptions{ + Include: directory.ListFilesInclude{Timestamps: true, ETag: true, Attributes: true, PermissionKey: true}, + IncludeExtendedInfo: to.Ptr(true), + }) + for pager.More() { + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + dirCtr += len(resp.Segment.Directories) + fileCtr += len(resp.Segment.Files) + for _, dir := range resp.Segment.Directories { + _require.NotNil(dir.Name) + _require.NotNil(dir.ID) + _require.NotNil(dir.Attributes) + _require.NotNil(dir.PermissionKey) + _require.NotNil(dir.Properties.ETag) + _require.NotNil(dir.Properties.ChangeTime) + _require.NotNil(dir.Properties.CreationTime) + _require.Nil(dir.Properties.ContentLength) + } + for _, f := range resp.Segment.Files { + _require.NotNil(f.Name) + _require.NotNil(f.ID) + _require.NotNil(f.Attributes) + _require.NotNil(f.PermissionKey) + _require.NotNil(f.Properties.ETag) + _require.NotNil(f.Properties.ChangeTime) + _require.NotNil(f.Properties.CreationTime) + _require.NotNil(f.Properties.ContentLength) + _require.Equal(*f.Properties.ContentLength, int64(2048)) + } + } + _require.Equal(dirCtr, 10) + _require.Equal(fileCtr, 5) +} + +func (d *DirectoryRecordedTestsSuite) TestDirListFilesAndDirsMaxResultsAndMarker() { + _require := require.New(d.T()) + testName := d.T().Name() + + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) _require.NoError(err) - permissions := sas.FilePermissions{ - Read: true, - Write: true, - Delete: true, - Create: true, + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + dirName := testcommon.GenerateDirectoryName(testName) + fileName := testcommon.GenerateFileName(testName) + + for i := 0; i < 10; i++ { + _ = testcommon.CreateNewDirectory(context.Background(), _require, dirName+fmt.Sprintf("%v", i), shareClient) + } + + for i := 0; i < 5; i++ { + _ = testcommon.CreateNewFileFromShare(context.Background(), _require, fileName+fmt.Sprintf("%v", i), 2048, shareClient) } - expiry := time.Now().Add(time.Hour) - _, err = dirClient.GetSASURL(permissions, expiry, nil) - _require.Equal(err, fileerror.MissingSharedKeyCredential) + dirCtr, fileCtr := 0, 0 + pager := shareClient.NewRootDirectoryClient().NewListFilesAndDirectoriesPager(&directory.ListFilesAndDirectoriesOptions{ + MaxResults: to.Ptr(int32(2)), + }) + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + dirCtr += len(resp.Segment.Directories) + fileCtr += len(resp.Segment.Files) + _require.Equal(dirCtr+fileCtr, 2) + + pager = shareClient.NewRootDirectoryClient().NewListFilesAndDirectoriesPager(&directory.ListFilesAndDirectoriesOptions{ + Marker: resp.NextMarker, + MaxResults: to.Ptr(int32(5)), + }) + for pager.More() { + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + dirCtr += len(resp.Segment.Directories) + fileCtr += len(resp.Segment.Files) + } + _require.Equal(dirCtr, 10) + _require.Equal(fileCtr, 5) } -func (d *DirectoryRecordedTestsSuite) TestSASDirectoryClientSignNegative() { +func (d *DirectoryRecordedTestsSuite) TestDirListFilesAndDirsWithPrefix() { _require := require.New(d.T()) - accountName, accountKey := testcommon.GetGenericAccountInfo(testcommon.TestAccountDefault) - _require.Greater(len(accountName), 0) - _require.Greater(len(accountKey), 0) + testName := d.T().Name() - cred, err := service.NewSharedKeyCredential(accountName, accountKey) + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) _require.NoError(err) + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + dirName := testcommon.GenerateDirectoryName(testName) + fileName := testcommon.GenerateFileName(testName) + + for i := 0; i < 10; i++ { + _ = testcommon.CreateNewDirectory(context.Background(), _require, fmt.Sprintf("%v", i)+dirName, shareClient) + } + + for i := 0; i < 5; i++ { + _ = testcommon.CreateNewFileFromShare(context.Background(), _require, fmt.Sprintf("%v", i)+fileName, 2048, shareClient) + } + + dirCtr, fileCtr := 0, 0 + pager := shareClient.NewRootDirectoryClient().NewListFilesAndDirectoriesPager(&directory.ListFilesAndDirectoriesOptions{ + Prefix: to.Ptr("1"), + }) + for pager.More() { + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + dirCtr += len(resp.Segment.Directories) + fileCtr += len(resp.Segment.Files) + if len(resp.Segment.Directories) > 0 { + _require.NotNil(resp.Segment.Directories[0].Name) + _require.Equal(*resp.Segment.Directories[0].Name, "1"+dirName) + } + if len(resp.Segment.Files) > 0 { + _require.NotNil(resp.Segment.Files[0].Name) + _require.Equal(*resp.Segment.Files[0].Name, "1"+fileName) + } + } + _require.Equal(dirCtr, 1) + _require.Equal(fileCtr, 1) +} + +func (d *DirectoryRecordedTestsSuite) TestDirListFilesAndDirsMaxResultsNegative() { + _require := require.New(d.T()) testName := d.T().Name() - shareName := testcommon.GenerateShareName(testName) + + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + dirName := testcommon.GenerateDirectoryName(testName) - dirClient, err := directory.NewClientWithSharedKeyCredential(fmt.Sprintf("https://%s.file.core.windows.net/%v%v", accountName, shareName, dirName), cred, nil) + fileName := testcommon.GenerateFileName(testName) + + for i := 0; i < 2; i++ { + _ = testcommon.CreateNewDirectory(context.Background(), _require, dirName+fmt.Sprintf("%v", i), shareClient) + } + + for i := 0; i < 2; i++ { + _ = testcommon.CreateNewFileFromShare(context.Background(), _require, fileName+fmt.Sprintf("%v", i), 2048, shareClient) + } + + pager := shareClient.NewRootDirectoryClient().NewListFilesAndDirectoriesPager(&directory.ListFilesAndDirectoriesOptions{ + MaxResults: to.Ptr(int32(-1)), + }) + _, err = pager.NextPage(context.Background()) + _require.Error(err) + testcommon.ValidateFileErrorCode(_require, err, fileerror.OutOfRangeQueryParameterValue) +} + +func (d *DirectoryRecordedTestsSuite) TestDirListFilesAndDirsSnapshot() { + _require := require.New(d.T()) + testName := d.T().Name() + + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer func() { + _, err := shareClient.Delete(context.Background(), &share.DeleteOptions{DeleteSnapshots: to.Ptr(share.DeleteSnapshotsOptionTypeInclude)}) + _require.NoError(err) + }() + + dirName := testcommon.GenerateDirectoryName(testName) + fileName := testcommon.GenerateFileName(testName) + + for i := 0; i < 10; i++ { + _ = testcommon.CreateNewDirectory(context.Background(), _require, dirName+fmt.Sprintf("%v", i), shareClient) + } + + for i := 0; i < 5; i++ { + _ = testcommon.CreateNewFileFromShare(context.Background(), _require, fileName+fmt.Sprintf("%v", i), 2048, shareClient) + } + + snapResp, err := shareClient.CreateSnapshot(context.Background(), nil) + _require.NoError(err) + _require.NotNil(snapResp.Snapshot) + + _, err = shareClient.NewRootDirectoryClient().GetProperties(context.Background(), &directory.GetPropertiesOptions{ShareSnapshot: snapResp.Snapshot}) + _require.NoError(err) + + dirCtr, fileCtr := 0, 0 + pager := shareClient.NewRootDirectoryClient().NewListFilesAndDirectoriesPager(&directory.ListFilesAndDirectoriesOptions{ + ShareSnapshot: snapResp.Snapshot, + }) + for pager.More() { + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + dirCtr += len(resp.Segment.Directories) + fileCtr += len(resp.Segment.Files) + } + _require.Equal(dirCtr, 10) + _require.Equal(fileCtr, 5) +} + +func (d *DirectoryRecordedTestsSuite) TestDirListFilesAndDirsInsideDir() { + _require := require.New(d.T()) + testName := d.T().Name() + + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) _require.NoError(err) - permissions := sas.FilePermissions{ - Read: true, - Write: true, - Delete: true, - Create: true, + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + dirName := testcommon.GenerateDirectoryName(testName) + fileName := testcommon.GenerateFileName(testName) + + dirClient := testcommon.CreateNewDirectory(context.Background(), _require, dirName, shareClient) + + for i := 0; i < 5; i++ { + _, err = dirClient.NewSubdirectoryClient("subdir"+fmt.Sprintf("%v", i)).Create(context.Background(), nil) + _require.NoError(err) } - expiry := time.Time{} - _, err = dirClient.GetSASURL(permissions, expiry, nil) - _require.Equal(err.Error(), "service SAS is missing at least one of these: ExpiryTime or Permissions") + for i := 0; i < 5; i++ { + _, err = dirClient.NewFileClient(fileName+fmt.Sprintf("%v", i)).Create(context.Background(), 0, nil) + _require.NoError(err) + } + + dirCtr, fileCtr := 0, 0 + pager := dirClient.NewListFilesAndDirectoriesPager(&directory.ListFilesAndDirectoriesOptions{ + Include: directory.ListFilesInclude{Timestamps: true, ETag: true, Attributes: true, PermissionKey: true}, + }) + for pager.More() { + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + dirCtr += len(resp.Segment.Directories) + fileCtr += len(resp.Segment.Files) + for _, dir := range resp.Segment.Directories { + _require.NotNil(dir.Name) + _require.NotNil(dir.ID) + _require.NotNil(dir.Attributes) + _require.NotNil(dir.PermissionKey) + _require.NotNil(dir.Properties.ETag) + _require.NotNil(dir.Properties.ChangeTime) + _require.NotNil(dir.Properties.CreationTime) + _require.Nil(dir.Properties.ContentLength) + } + for _, f := range resp.Segment.Files { + _require.NotNil(f.Name) + _require.NotNil(f.ID) + _require.NotNil(f.Attributes) + _require.NotNil(f.PermissionKey) + _require.NotNil(f.Properties.ETag) + _require.NotNil(f.Properties.ChangeTime) + _require.NotNil(f.Properties.CreationTime) + _require.NotNil(f.Properties.ContentLength) + _require.Equal(*f.Properties.ContentLength, int64(0)) + } + } + _require.Equal(dirCtr, 5) + _require.Equal(fileCtr, 5) +} + +func (d *DirectoryRecordedTestsSuite) TestDirListHandlesDefault() { + _require := require.New(d.T()) + testName := d.T().Name() + + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + dirClient := testcommon.CreateNewDirectory(context.Background(), _require, testcommon.GenerateDirectoryName(testName), shareClient) + + resp, err := dirClient.ListHandles(context.Background(), nil) + _require.NoError(err) + _require.Len(resp.Handles, 0) + _require.NotNil(resp.NextMarker) + _require.Equal(*resp.NextMarker, "") +} + +func (d *DirectoryRecordedTestsSuite) TestDirForceCloseHandlesDefault() { + _require := require.New(d.T()) + testName := d.T().Name() + + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + dirClient := testcommon.CreateNewDirectory(context.Background(), _require, testcommon.GenerateDirectoryName(testName), shareClient) + + resp, err := dirClient.ForceCloseHandles(context.Background(), "*", nil) + _require.NoError(err) + _require.EqualValues(*resp.NumberOfHandlesClosed, 0) + _require.EqualValues(*resp.NumberOfHandlesFailedToClose, 0) + _require.Nil(resp.Marker) +} + +func (d *DirectoryRecordedTestsSuite) TestDirectoryCreateNegativeWithoutSAS() { + _require := require.New(d.T()) + testName := d.T().Name() + + accountName, _ := testcommon.GetGenericAccountInfo(testcommon.TestAccountDefault) + _require.Greater(len(accountName), 0) + + svcClient, err := testcommon.GetServiceClient(d.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareName := testcommon.GenerateShareName(testName) + shareClient := testcommon.CreateNewShare(context.Background(), _require, shareName, svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + dirName := testcommon.GenerateDirectoryName(testName) + dirURL := "https://" + accountName + ".file.core.windows.net/" + shareName + "/" + dirName + + options := &directory.ClientOptions{} + testcommon.SetClientOptions(d.T(), &options.ClientOptions) + dirClient, err := directory.NewClientWithNoCredential(dirURL, nil) + _require.NoError(err) + + _, err = dirClient.Create(context.Background(), nil) + _require.Error(err) } // TODO: add tests for listing files and directories after file client is completed diff --git a/sdk/storage/azfile/directory/models.go b/sdk/storage/azfile/directory/models.go index 7aac2ba2587c..8bbe023709b4 100644 --- a/sdk/storage/azfile/directory/models.go +++ b/sdk/storage/azfile/directory/models.go @@ -13,7 +13,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/generated" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/shared" "reflect" - "time" ) // SharedKeyCredential contains an account's name and its primary or secondary key. @@ -158,7 +157,7 @@ func (l ListFilesInclude) format() []generated.ListFilesIncludeType { return nil } - include := []generated.ListFilesIncludeType{} + var include []generated.ListFilesIncludeType if l.Timestamps { include = append(include, ListFilesIncludeTypeTimestamps) @@ -190,27 +189,6 @@ type FileProperty = generated.FileProperty // --------------------------------------------------------------------------------------------------------------------- -// GetSASURLOptions contains the optional parameters for the Client.GetSASURL method. -type GetSASURLOptions struct { - StartTime *time.Time -} - -func (o *GetSASURLOptions) format() time.Time { - if o == nil { - return time.Time{} - } - - var st time.Time - if o.StartTime != nil { - st = o.StartTime.UTC() - } else { - st = time.Time{} - } - return st -} - -// --------------------------------------------------------------------------------------------------------------------- - // ListHandlesOptions contains the optional parameters for the Client.ListHandles method. type ListHandlesOptions struct { // A string value that identifies the portion of the list to be returned with the next list operation. The operation returns diff --git a/sdk/storage/azfile/file/client.go b/sdk/storage/azfile/file/client.go index 6b060d589da2..449822d84ff3 100644 --- a/sdk/storage/azfile/file/client.go +++ b/sdk/storage/azfile/file/client.go @@ -255,13 +255,13 @@ func (f *Client) GetSASURL(permissions sas.FilePermissions, expiry time.Time, o } qps, err := sas.SignatureValues{ - Version: sas.Version, - Protocol: sas.ProtocolHTTPS, - ShareName: urlParts.ShareName, - DirectoryOrFilePath: urlParts.DirectoryOrFilePath, - Permissions: permissions.String(), - StartTime: st, - ExpiryTime: expiry.UTC(), + Version: sas.Version, + Protocol: sas.ProtocolHTTPS, + ShareName: urlParts.ShareName, + FilePath: urlParts.DirectoryOrFilePath, + Permissions: permissions.String(), + StartTime: st, + ExpiryTime: expiry.UTC(), }.SignWithSharedKey(f.sharedKey()) if err != nil { return "", err @@ -360,6 +360,74 @@ func (f *Client) UploadStream(ctx context.Context, body io.Reader, options *Uplo // Concurrent Download Functions ----------------------------------------------------------------------------------------- +// download method downloads an Azure file to a WriterAt in parallel. +func (f *Client) download(ctx context.Context, writer io.WriterAt, o downloadOptions) (int64, error) { + if o.ChunkSize == 0 { + o.ChunkSize = DefaultDownloadChunkSize + } + + count := o.Range.Count + if count == CountToEnd { // If size not specified, calculate it + // If we don't have the length at all, get it + getFilePropertiesOptions := o.getFilePropertiesOptions() + gr, err := f.GetProperties(ctx, getFilePropertiesOptions) + if err != nil { + return 0, err + } + count = *gr.ContentLength - o.Range.Offset + } + + if count <= 0 { + // The file is empty, there is nothing to download. + return 0, nil + } + + // Prepare and do parallel download. + progress := int64(0) + progressLock := &sync.Mutex{} + + err := shared.DoBatchTransfer(ctx, &shared.BatchTransferOptions{ + OperationName: "downloadFileToWriterAt", + TransferSize: count, + ChunkSize: o.ChunkSize, + Concurrency: o.Concurrency, + Operation: func(ctx context.Context, chunkStart int64, count int64) error { + downloadFileOptions := o.getDownloadFileOptions(HTTPRange{ + Offset: chunkStart + o.Range.Offset, + Count: count, + }) + dr, err := f.DownloadStream(ctx, downloadFileOptions) + if err != nil { + return err + } + var body io.ReadCloser = dr.NewRetryReader(ctx, &o.RetryReaderOptionsPerChunk) + if o.Progress != nil { + rangeProgress := int64(0) + body = streaming.NewResponseProgress( + body, + func(bytesTransferred int64) { + diff := bytesTransferred - rangeProgress + rangeProgress = bytesTransferred + progressLock.Lock() + progress += diff + o.Progress(progress) + progressLock.Unlock() + }) + } + _, err = io.Copy(shared.NewSectionWriter(writer, chunkStart, count), body) + if err != nil { + return err + } + err = body.Close() + return err + }, + }) + if err != nil { + return 0, err + } + return count, nil +} + // DownloadStream operation reads or downloads a file from the system, including its metadata and properties. // For more information, see https://learn.microsoft.com/en-us/rest/api/storageservices/get-file. func (f *Client) DownloadStream(ctx context.Context, options *DownloadStreamOptions) (DownloadStreamResponse, error) { @@ -369,20 +437,65 @@ func (f *Client) DownloadStream(ctx context.Context, options *DownloadStreamOpti } resp, err := f.generated().Download(ctx, opts, leaseAccessConditions) + if err != nil { + return DownloadStreamResponse{}, err + } + return DownloadStreamResponse{ - DownloadResponse: resp, - client: f, - getInfo: httpGetterInfo{Range: options.Range, ETag: resp.ETag}, + DownloadResponse: resp, + client: f, + getInfo: httpGetterInfo{Range: options.Range}, + leaseAccessConditions: options.LeaseAccessConditions, }, err } // DownloadBuffer downloads an Azure file to a buffer with parallel. func (f *Client) DownloadBuffer(ctx context.Context, buffer []byte, o *DownloadBufferOptions) (int64, error) { - return 0, nil + if o == nil { + o = &DownloadBufferOptions{} + } + + return f.download(ctx, shared.NewBytesWriter(buffer), (downloadOptions)(*o)) } // DownloadFile downloads an Azure file to a local file. // The file would be truncated if the size doesn't match. func (f *Client) DownloadFile(ctx context.Context, file *os.File, o *DownloadFileOptions) (int64, error) { - return 0, nil + if o == nil { + o = &DownloadFileOptions{} + } + do := (*downloadOptions)(o) + + // 1. Calculate the size of the destination file + var size int64 + + count := do.Range.Count + if count == CountToEnd { + // Try to get Azure file's size + getFilePropertiesOptions := do.getFilePropertiesOptions() + props, err := f.GetProperties(ctx, getFilePropertiesOptions) + if err != nil { + return 0, err + } + size = *props.ContentLength - do.Range.Offset + } else { + size = count + } + + // 2. Compare and try to resize local file's size if it doesn't match Azure file's size. + stat, err := file.Stat() + if err != nil { + return 0, err + } + if stat.Size() != size { + if err = file.Truncate(size); err != nil { + return 0, err + } + } + + if size > 0 { + return f.download(ctx, file, *do) + } else { // if the file's size is 0, there is no need in downloading it + return 0, nil + } } diff --git a/sdk/storage/azfile/file/client_test.go b/sdk/storage/azfile/file/client_test.go index 28735c7b63a7..9a822c6d2572 100644 --- a/sdk/storage/azfile/file/client_test.go +++ b/sdk/storage/azfile/file/client_test.go @@ -28,6 +28,7 @@ import ( "hash/crc64" "io" "io/ioutil" + "net/http" "os" "strings" "testing" @@ -441,6 +442,58 @@ func (f *FileUnrecordedTestsSuite) TestFileGetSetPropertiesNonDefault() { _require.NotNil(getResp.IsServerEncrypted) } +func (f *FileRecordedTestsSuite) TestFileGetSetPropertiesDefault() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 0, shareClient) + + setResp, err := fClient.SetHTTPHeaders(context.Background(), nil) + _require.NoError(err) + _require.NotEqual(*setResp.ETag, "") + _require.Equal(setResp.LastModified.IsZero(), false) + _require.NotEqual(setResp.RequestID, "") + _require.NotEqual(setResp.Version, "") + _require.Equal(setResp.Date.IsZero(), false) + _require.NotNil(setResp.IsServerEncrypted) + + metadata := map[string]*string{ + "Foo": to.Ptr("Foovalue"), + "Bar": to.Ptr("Barvalue"), + } + _, err = fClient.SetMetadata(context.Background(), &file.SetMetadataOptions{ + Metadata: metadata, + }) + _require.NoError(err) + + // get properties on the share snapshot + getResp, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(setResp.LastModified.IsZero(), false) + _require.Equal(*getResp.FileType, "File") + + _require.Nil(getResp.ContentType) + _require.Nil(getResp.ContentEncoding) + _require.Nil(getResp.ContentLanguage) + _require.Nil(getResp.ContentMD5) + _require.Nil(getResp.CacheControl) + _require.Nil(getResp.ContentDisposition) + _require.Equal(*getResp.ContentLength, int64(0)) + + _require.NotNil(getResp.ETag) + _require.NotNil(getResp.RequestID) + _require.NotNil(getResp.Version) + _require.Equal(getResp.Date.IsZero(), false) + _require.NotNil(getResp.IsServerEncrypted) + _require.EqualValues(getResp.Metadata, metadata) +} + func (f *FileRecordedTestsSuite) TestFilePreservePermissions() { _require := require.New(f.T()) testName := f.T().Name() @@ -721,6 +774,98 @@ func (f *FileRecordedTestsSuite) TestFileSetMetadataInvalidField() { _require.Error(err) } +func (f *FileRecordedTestsSuite) TestStartCopyDefault() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + srcFile := shareClient.NewRootDirectoryClient().NewFileClient("src" + testcommon.GenerateFileName(testName)) + destFile := shareClient.NewRootDirectoryClient().NewFileClient("dest" + testcommon.GenerateFileName(testName)) + + fileSize := int64(2048) + _, err = srcFile.Create(context.Background(), fileSize, nil) + _require.NoError(err) + + contentR, srcContent := testcommon.GenerateData(int(fileSize)) + srcContentMD5 := md5.Sum(srcContent) + + _, err = srcFile.UploadRange(context.Background(), 0, contentR, nil) + _require.NoError(err) + + copyResp, err := destFile.StartCopyFromURL(context.Background(), srcFile.URL(), nil) + _require.NoError(err) + _require.NotNil(copyResp.ETag) + _require.Equal(copyResp.LastModified.IsZero(), false) + _require.NotNil(copyResp.RequestID) + _require.NotNil(copyResp.Version) + _require.Equal(copyResp.Date.IsZero(), false) + _require.NotEqual(copyResp.CopyStatus, "") + + time.Sleep(time.Duration(5) * time.Second) + + getResp, err := destFile.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.EqualValues(getResp.CopyID, copyResp.CopyID) + _require.NotEqual(*getResp.CopyStatus, "") + _require.Equal(*getResp.CopySource, srcFile.URL()) + _require.Equal(*getResp.CopyStatus, file.CopyStatusTypeSuccess) + + // Abort will fail after copy finished + _, err = destFile.AbortCopy(context.Background(), *copyResp.CopyID, nil) + _require.Error(err) + testcommon.ValidateHTTPErrorCode(_require, err, http.StatusConflict) + + // validate data copied + dResp, err := destFile.DownloadStream(context.Background(), &file.DownloadStreamOptions{ + Range: file.HTTPRange{Offset: 0, Count: fileSize}, + RangeGetContentMD5: to.Ptr(true), + }) + _require.NoError(err) + + destContent, err := io.ReadAll(dResp.Body) + _require.NoError(err) + _require.EqualValues(srcContent, destContent) + _require.Equal(dResp.ContentMD5, srcContentMD5[:]) +} + +func (f *FileRecordedTestsSuite) TestFileStartCopyDestEmpty() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShareWithData(context.Background(), _require, "src"+testcommon.GenerateFileName(testName), shareClient) + copyFClient := testcommon.GetFileClientFromShare("dest"+testcommon.GenerateFileName(testName), shareClient) + + _, err = copyFClient.StartCopyFromURL(context.Background(), fClient.URL(), nil) + _require.NoError(err) + + time.Sleep(4 * time.Second) + + resp, err := copyFClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + // Read the file data to verify the copy + data, err := ioutil.ReadAll(resp.Body) + defer func() { + err = resp.Body.Close() + _require.NoError(err) + }() + + _require.NoError(err) + _require.Equal(*resp.ContentLength, int64(len(testcommon.FileDefaultData))) + _require.Equal(string(data), testcommon.FileDefaultData) +} + func (f *FileRecordedTestsSuite) TestFileStartCopyMetadata() { _require := require.New(f.T()) testName := f.T().Name() @@ -1239,6 +1384,48 @@ func (f *FileRecordedTestsSuite) TestFileStartCopySourceNonExistent() { testcommon.ValidateFileErrorCode(_require, err, fileerror.ResourceNotFound) } +func (f *FileUnrecordedTestsSuite) TestFileStartCopyUsingSASSrc() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareName := testcommon.GenerateShareName(testName) + shareClient := testcommon.CreateNewShare(context.Background(), _require, "src"+shareName, svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fileName := testcommon.GenerateFileName(testName) + fClient := testcommon.CreateNewFileFromShareWithData(context.Background(), _require, "src"+fileName, shareClient) + + fileURLWithSAS, err := fClient.GetSASURL(sas.FilePermissions{Read: true, Write: true, Create: true, Delete: true}, time.Now().Add(5*time.Minute).UTC(), nil) + _require.NoError(err) + + // Create a new share for the destination + copyShareClient := testcommon.CreateNewShare(context.Background(), _require, "dest"+shareName, svcClient) + defer testcommon.DeleteShare(context.Background(), _require, copyShareClient) + + copyFileClient := testcommon.GetFileClientFromShare("dst"+fileName, copyShareClient) + + _, err = copyFileClient.StartCopyFromURL(context.Background(), fileURLWithSAS, nil) + _require.NoError(err) + + time.Sleep(4 * time.Second) + + dResp, err := copyFileClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := ioutil.ReadAll(dResp.Body) + defer func() { + err = dResp.Body.Close() + _require.NoError(err) + }() + + _require.NoError(err) + _require.Equal(*dResp.ContentLength, int64(len(testcommon.FileDefaultData))) + _require.Equal(string(data), testcommon.FileDefaultData) +} + func (f *FileRecordedTestsSuite) TestFileAbortCopyNoCopyStarted() { _require := require.New(f.T()) testName := f.T().Name() @@ -1414,11 +1601,20 @@ func (f *FileRecordedTestsSuite) TestSASFileClientSignNegative() { } expiry := time.Time{} - _, err = fileClient.GetSASURL(permissions, expiry, nil) + // zero expiry time + _, err = fileClient.GetSASURL(permissions, expiry, &file.GetSASURLOptions{StartTime: to.Ptr(time.Now())}) + _require.Equal(err.Error(), "service SAS is missing at least one of these: ExpiryTime or Permissions") + + // zero start and expiry time + _, err = fileClient.GetSASURL(permissions, expiry, &file.GetSASURLOptions{}) + _require.Equal(err.Error(), "service SAS is missing at least one of these: ExpiryTime or Permissions") + + // empty permissions + _, err = fileClient.GetSASURL(sas.FilePermissions{}, expiry, nil) _require.Equal(err.Error(), "service SAS is missing at least one of these: ExpiryTime or Permissions") } -func (f *FileUnrecordedTestsSuite) TestFileUploadClearListRange() { +func (f *FileRecordedTestsSuite) TestFileUploadClearListRange() { _require := require.New(f.T()) testName := f.T().Name() @@ -1428,7 +1624,7 @@ func (f *FileUnrecordedTestsSuite) TestFileUploadClearListRange() { shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) defer testcommon.DeleteShare(context.Background(), _require, shareClient) - var fileSize int64 = 1024 * 1024 * 10 + var fileSize int64 = 1024 * 10 fClient := shareClient.NewRootDirectoryClient().NewFileClient(testcommon.GenerateFileName(testName)) _, err = fClient.Create(context.Background(), fileSize, nil) _require.NoError(err) @@ -1437,14 +1633,12 @@ func (f *FileUnrecordedTestsSuite) TestFileUploadClearListRange() { _require.NoError(err) _require.Equal(*gResp.ContentLength, fileSize) - contentSize := 1024 * 8 // 8KB - content := make([]byte, contentSize) - body := bytes.NewReader(content) - rsc := streaming.NopCloser(body) - md5Value := md5.Sum(content) + contentSize := 1024 * 2 // 2KB + contentR, contentD := testcommon.GenerateData(contentSize) + md5Value := md5.Sum(contentD) contentMD5 := md5Value[:] - uResp, err := fClient.UploadRange(context.Background(), 0, rsc, &file.UploadRangeOptions{ + uResp, err := fClient.UploadRange(context.Background(), 0, contentR, &file.UploadRangeOptions{ TransactionalValidation: file.TransferValidationTypeMD5(contentMD5), }) _require.NoError(err) @@ -1453,7 +1647,8 @@ func (f *FileUnrecordedTestsSuite) TestFileUploadClearListRange() { rangeList, err := fClient.GetRangeList(context.Background(), nil) _require.NoError(err) - _require.NotNil(rangeList.RequestID) + _require.Len(rangeList.Ranges, 1) + _require.EqualValues(*rangeList.Ranges[0], file.ShareFileRange{Start: to.Ptr(int64(0)), End: to.Ptr(int64(contentSize - 1))}) cResp, err := fClient.ClearRange(context.Background(), file.HTTPRange{Offset: 0, Count: int64(contentSize)}, nil) _require.NoError(err) @@ -1461,7 +1656,7 @@ func (f *FileUnrecordedTestsSuite) TestFileUploadClearListRange() { rangeList2, err := fClient.GetRangeList(context.Background(), nil) _require.NoError(err) - _require.NotNil(rangeList2.RequestID) + _require.Len(rangeList2.Ranges, 0) } func (f *FileUnrecordedTestsSuite) TestFileUploadRangeFromURL() { @@ -1499,11 +1694,11 @@ func (f *FileUnrecordedTestsSuite) TestFileUploadRangeFromURL() { perms := sas.FilePermissions{Read: true, Write: true} sasQueryParams, err := sas.SignatureValues{ - Protocol: sas.ProtocolHTTPS, // Users MUST use HTTPS (not HTTP) - ExpiryTime: time.Now().UTC().Add(48 * time.Hour), // 48-hours before expiration - ShareName: shareName, - DirectoryOrFilePath: srcFileName, - Permissions: perms.String(), + Protocol: sas.ProtocolHTTPS, // Users MUST use HTTPS (not HTTP) + ExpiryTime: time.Now().UTC().Add(48 * time.Hour), // 48-hours before expiration + ShareName: shareName, + FilePath: srcFileName, + Permissions: perms.String(), }.SignWithSharedKey(cred) _require.NoError(err) @@ -1522,7 +1717,9 @@ func (f *FileUnrecordedTestsSuite) TestFileUploadRangeFromURL() { rangeList, err := destFClient.GetRangeList(context.Background(), nil) _require.NoError(err) - _require.NotNil(rangeList.RequestID) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, int64(contentSize-1)) cResp, err := destFClient.ClearRange(context.Background(), file.HTTPRange{Offset: 0, Count: int64(contentSize)}, nil) _require.NoError(err) @@ -1530,7 +1727,69 @@ func (f *FileUnrecordedTestsSuite) TestFileUploadRangeFromURL() { rangeList2, err := destFClient.GetRangeList(context.Background(), nil) _require.NoError(err) - _require.NotNil(rangeList2.RequestID) + _require.Len(rangeList2.Ranges, 0) +} + +func (f *FileRecordedTestsSuite) TestFileUploadRangeFromURLNegative() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareName := testcommon.GenerateShareName(testName) + shareClient := testcommon.CreateNewShare(context.Background(), _require, shareName, svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + var fileSize int64 = 1024 * 20 + srcFileName := "src" + testcommon.GenerateFileName(testName) + srcFClient := testcommon.CreateNewFileFromShare(context.Background(), _require, srcFileName, fileSize, shareClient) + + gResp, err := srcFClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp.ContentLength, fileSize) + + contentSize := 1024 * 8 // 8KB + rsc, content := testcommon.GenerateData(contentSize) + contentCRC64 := crc64.Checksum(content, shared.CRC64Table) + + _, err = srcFClient.UploadRange(context.Background(), 0, rsc, nil) + _require.NoError(err) + + destFClient := testcommon.CreateNewFileFromShare(context.Background(), _require, "dest"+testcommon.GenerateFileName(testName), fileSize, shareClient) + + _, err = destFClient.UploadRangeFromURL(context.Background(), srcFClient.URL(), 0, 0, int64(contentSize), &file.UploadRangeFromURLOptions{ + SourceContentCRC64: contentCRC64, + }) + _require.Error(err) +} + +func (f *FileRecordedTestsSuite) TestFileUploadRangeFromURLOffsetNegative() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareName := testcommon.GenerateShareName(testName) + shareClient := testcommon.CreateNewShare(context.Background(), _require, shareName, svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + var fileSize int64 = 1024 * 20 + srcFileName := "src" + testcommon.GenerateFileName(testName) + srcFClient := testcommon.CreateNewFileFromShare(context.Background(), _require, srcFileName, fileSize, shareClient) + + gResp, err := srcFClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp.ContentLength, fileSize) + + contentSize := 1024 * 8 // 8KB + destFClient := testcommon.CreateNewFileFromShare(context.Background(), _require, "dest"+testcommon.GenerateFileName(testName), fileSize, shareClient) + + // error is returned when source offset is negative + _, err = destFClient.UploadRangeFromURL(context.Background(), srcFClient.URL(), -1, 0, int64(contentSize), nil) + _require.Error(err) + _require.Equal(err.Error(), "invalid argument: source and destination offsets must be >= 0") } // TODO: check why this is failing @@ -1644,7 +1903,9 @@ func (f *FileUnrecordedTestsSuite) TestFileUploadBuffer() { rangeList, err := fClient.GetRangeList(context.Background(), nil) _require.NoError(err) - _require.NotNil(rangeList.RequestID) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, fileSize-1) } func (f *FileUnrecordedTestsSuite) TestFileUploadFile() { @@ -1714,7 +1975,9 @@ func (f *FileUnrecordedTestsSuite) TestFileUploadFile() { rangeList, err := fClient.GetRangeList(context.Background(), nil) _require.NoError(err) - _require.NotNil(rangeList.RequestID) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, fileSize-1) } func (f *FileUnrecordedTestsSuite) TestFileUploadStream() { @@ -1765,11 +2028,1086 @@ func (f *FileUnrecordedTestsSuite) TestFileUploadStream() { rangeList, err := fClient.GetRangeList(context.Background(), nil) _require.NoError(err) - _require.NotNil(rangeList.RequestID) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, fileSize-1) +} + +func (f *FileUnrecordedTestsSuite) TestFileDownloadBuffer() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + var fileSize int64 = 100 * 1024 * 1024 + fClient := shareClient.NewRootDirectoryClient().NewFileClient(testcommon.GenerateFileName(testName)) + _, err = fClient.Create(context.Background(), fileSize, nil) + _require.NoError(err) + + gResp, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp.ContentLength, fileSize) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + + destBuffer := make([]byte, fileSize) + cnt, err := fClient.DownloadBuffer(context.Background(), destBuffer, &file.DownloadBufferOptions{ + ChunkSize: 10 * 1024 * 1024, + Concurrency: 5, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + + downloadedMD5Value := md5.Sum(destBuffer) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + rangeList, err := fClient.GetRangeList(context.Background(), nil) + _require.NoError(err) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, fileSize-1) } -// TODO: Content validation in StartCopyFromURL() after adding upload and download methods. +func (f *FileUnrecordedTestsSuite) TestFileDownloadFile() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) -// TODO: Add tests for upload and download methods + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) -// TODO: Add tests for GetRangeList, ListHandles and ForceCloseHandles + var fileSize int64 = 100 * 1024 * 1024 + fClient := shareClient.NewRootDirectoryClient().NewFileClient(testcommon.GenerateFileName(testName)) + _, err = fClient.Create(context.Background(), fileSize, nil) + _require.NoError(err) + + gResp, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp.ContentLength, fileSize) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + + destFileName := "BigFile-downloaded.bin" + destFile, err := os.Create(destFileName) + _require.NoError(err) + defer func(name string) { + err = os.Remove(name) + _require.NoError(err) + }(destFileName) + defer func(destFile *os.File) { + err = destFile.Close() + _require.NoError(err) + }(destFile) + + cnt, err := fClient.DownloadFile(context.Background(), destFile, &file.DownloadFileOptions{ + ChunkSize: 10 * 1024 * 1024, + Concurrency: 5, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + + hash := md5.New() + _, err = io.Copy(hash, destFile) + _require.NoError(err) + downloadedContentMD5 := hash.Sum(nil) + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + rangeList, err := fClient.GetRangeList(context.Background(), nil) + _require.NoError(err) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, fileSize-1) +} + +func (f *FileRecordedTestsSuite) TestUploadDownloadDefaultNonDefaultMD5() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, "src"+testcommon.GenerateFileName(testName), 2048, shareClient) + defer testcommon.DeleteFile(context.Background(), _require, fClient) + + contentR, contentD := testcommon.GenerateData(2048) + + pResp, err := fClient.UploadRange(context.Background(), 0, contentR, nil) + _require.NoError(err) + _require.NotNil(pResp.ContentMD5) + _require.NotNil(pResp.IsServerEncrypted) + _require.NotNil(pResp.ETag) + _require.Equal(pResp.LastModified.IsZero(), false) + _require.NotNil(pResp.RequestID) + _require.NotNil(pResp.Version) + _require.Equal(pResp.Date.IsZero(), false) + + // Get with rangeGetContentMD5 enabled. + // Partial data, check status code 206. + resp, err := fClient.DownloadStream(context.Background(), &file.DownloadStreamOptions{ + Range: file.HTTPRange{Offset: 0, Count: 1024}, + RangeGetContentMD5: to.Ptr(true), + }) + _require.NoError(err) + _require.Equal(*resp.ContentLength, int64(1024)) + _require.NotNil(resp.ContentMD5) + _require.Equal(*resp.ContentType, "application/octet-stream") + + downloadedData, err := ioutil.ReadAll(resp.Body) + _require.NoError(err) + _require.EqualValues(downloadedData, contentD[:1024]) + + // Set ContentMD5 for the entire file. + _, err = fClient.SetHTTPHeaders(context.Background(), &file.SetHTTPHeadersOptions{ + HTTPHeaders: &file.HTTPHeaders{ + ContentMD5: pResp.ContentMD5, + ContentLanguage: to.Ptr("test")}, + }) + _require.NoError(err) + + // Test get with another type of range index, and validate if FileContentMD5 can be got correct. + resp, err = fClient.DownloadStream(context.Background(), &file.DownloadStreamOptions{ + Range: file.HTTPRange{Offset: 1024, Count: file.CountToEnd}, + }) + _require.NoError(err) + _require.Equal(*resp.ContentLength, int64(1024)) + _require.Nil(resp.ContentMD5) + _require.EqualValues(resp.FileContentMD5, pResp.ContentMD5) + _require.Equal(*resp.ContentLanguage, "test") + // Note: when it's downloading range, range's MD5 is returned, when set rangeGetContentMD5=true, currently set it to false, so should be empty + + downloadedData, err = ioutil.ReadAll(resp.Body) + _require.NoError(err) + _require.EqualValues(downloadedData, contentD[1024:]) + + _require.Equal(*resp.AcceptRanges, "bytes") + _require.Nil(resp.CacheControl) + _require.Nil(resp.ContentDisposition) + _require.Nil(resp.ContentEncoding) + _require.Equal(*resp.ContentRange, "bytes 1024-2047/2048") + _require.Nil(resp.ContentType) // Note ContentType is set to empty during SetHTTPHeaders + _require.Nil(resp.CopyID) + _require.Nil(resp.CopyProgress) + _require.Nil(resp.CopySource) + _require.Nil(resp.CopyStatus) + _require.Nil(resp.CopyStatusDescription) + _require.Equal(resp.Date.IsZero(), false) + _require.NotEqual(*resp.ETag, "") + _require.Equal(resp.LastModified.IsZero(), false) + _require.Nil(resp.Metadata) + _require.NotEqual(*resp.RequestID, "") + _require.NotEqual(*resp.Version, "") + _require.NotNil(resp.IsServerEncrypted) + + // Get entire fClient, check status code 200. + resp, err = fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + _require.Equal(*resp.ContentLength, int64(2048)) + _require.EqualValues(resp.ContentMD5, pResp.ContentMD5) // Note: This case is inted to get entire fClient, entire file's MD5 will be returned. + _require.Nil(resp.FileContentMD5) // Note: FileContentMD5 is returned, only when range is specified explicitly. + + downloadedData, err = ioutil.ReadAll(resp.Body) + _require.NoError(err) + _require.EqualValues(downloadedData, contentD[:]) + + _require.Equal(*resp.AcceptRanges, "bytes") + _require.Nil(resp.CacheControl) + _require.Nil(resp.ContentDisposition) + _require.Nil(resp.ContentEncoding) + _require.Nil(resp.ContentRange) // Note: ContentRange is returned, only when range is specified explicitly. + _require.Nil(resp.ContentType) + _require.Nil(resp.CopyCompletionTime) + _require.Nil(resp.CopyID) + _require.Nil(resp.CopyProgress) + _require.Nil(resp.CopySource) + _require.Nil(resp.CopyStatus) + _require.Nil(resp.CopyStatusDescription) + _require.Equal(resp.Date.IsZero(), false) + _require.NotEqual(*resp.ETag, "") + _require.Equal(resp.LastModified.IsZero(), false) + _require.Nil(resp.Metadata) + _require.NotEqual(*resp.RequestID, "") + _require.NotEqual(*resp.Version, "") + _require.NotNil(resp.IsServerEncrypted) +} + +func (f *FileRecordedTestsSuite) TestFileDownloadDataNonExistentFile() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.GetFileClientFromShare(testcommon.GenerateFileName(testName), shareClient) + + _, err = fClient.DownloadStream(context.Background(), nil) + testcommon.ValidateFileErrorCode(_require, err, fileerror.ResourceNotFound) +} + +func (f *FileRecordedTestsSuite) TestFileDownloadDataOffsetOutOfRange() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 0, shareClient) + + _, err = fClient.DownloadStream(context.Background(), &file.DownloadStreamOptions{ + Range: file.HTTPRange{ + Offset: int64(len(testcommon.FileDefaultData)), + Count: file.CountToEnd, + }, + }) + testcommon.ValidateFileErrorCode(_require, err, fileerror.InvalidRange) +} + +func (f *FileRecordedTestsSuite) TestFileDownloadDataEntireFile() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShareWithData(context.Background(), _require, testcommon.GenerateFileName(testName), shareClient) + + resp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + // Specifying a count of 0 results in the value being ignored + data, err := ioutil.ReadAll(resp.Body) + _require.NoError(err) + _require.EqualValues(string(data), testcommon.FileDefaultData) +} + +func (f *FileRecordedTestsSuite) TestFileDownloadDataCountExact() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShareWithData(context.Background(), _require, testcommon.GenerateFileName(testName), shareClient) + + resp, err := fClient.DownloadStream(context.Background(), &file.DownloadStreamOptions{ + Range: file.HTTPRange{ + Offset: 0, + Count: int64(len(testcommon.FileDefaultData)), + }, + }) + _require.NoError(err) + + data, err := ioutil.ReadAll(resp.Body) + _require.NoError(err) + _require.EqualValues(string(data), testcommon.FileDefaultData) +} + +func (f *FileRecordedTestsSuite) TestFileDownloadDataCountOutOfRange() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShareWithData(context.Background(), _require, testcommon.GenerateFileName(testName), shareClient) + + resp, err := fClient.DownloadStream(context.Background(), &file.DownloadStreamOptions{ + Range: file.HTTPRange{ + Offset: 0, + Count: int64(len(testcommon.FileDefaultData)) * 2, + }, + }) + _require.NoError(err) + + data, err := ioutil.ReadAll(resp.Body) + _require.NoError(err) + _require.EqualValues(string(data), testcommon.FileDefaultData) +} + +func (f *FileRecordedTestsSuite) TestFileUploadRangeNilBody() { + _require := require.New(f.T()) + testName := f.T().Name() + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, "src"+testcommon.GenerateFileName(testName), 0, shareClient) + + _, err = fClient.UploadRange(context.Background(), 0, nil, nil) + _require.Error(err) + _require.Contains(err.Error(), "body must not be nil") +} + +func (f *FileRecordedTestsSuite) TestFileUploadRangeEmptyBody() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 0, shareClient) + + _, err = fClient.UploadRange(context.Background(), 0, streaming.NopCloser(bytes.NewReader([]byte{})), nil) + _require.Error(err) + _require.Contains(err.Error(), "body must contain readable data whose size is > 0") +} + +func (f *FileRecordedTestsSuite) TestFileUploadRangeNonExistentFile() { + _require := require.New(f.T()) + testName := f.T().Name() + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.GetFileClientFromShare(testcommon.GenerateFileName(testName), shareClient) + + rsc, _ := testcommon.GenerateData(12) + _, err = fClient.UploadRange(context.Background(), 0, rsc, nil) + _require.Error(err) + testcommon.ValidateFileErrorCode(_require, err, fileerror.ResourceNotFound) +} + +func (f *FileRecordedTestsSuite) TestFileUploadRangeTransactionalMD5() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 2048, shareClient) + + contentR, contentD := testcommon.GenerateData(2048) + _md5 := md5.Sum(contentD) + + // Upload range with correct transactional MD5 + pResp, err := fClient.UploadRange(context.Background(), 0, contentR, &file.UploadRangeOptions{ + TransactionalValidation: file.TransferValidationTypeMD5(_md5[:]), + }) + _require.NoError(err) + _require.NotNil(pResp.ContentMD5) + _require.NotNil(pResp.ETag) + _require.Equal(pResp.LastModified.IsZero(), false) + _require.NotNil(pResp.RequestID) + _require.NotNil(pResp.Version) + _require.Equal(pResp.Date.IsZero(), false) + _require.EqualValues(pResp.ContentMD5, _md5[:]) + + // Upload range with empty MD5, nil MD5 is covered by other cases. + pResp, err = fClient.UploadRange(context.Background(), 1024, streaming.NopCloser(bytes.NewReader(contentD[1024:])), nil) + _require.NoError(err) + _require.NotNil(pResp.ContentMD5) + + resp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + _require.Equal(*resp.ContentLength, int64(2048)) + + downloadedData, err := ioutil.ReadAll(resp.Body) + _require.NoError(err) + _require.EqualValues(downloadedData, contentD[:]) +} + +func (f *FileRecordedTestsSuite) TestFileUploadRangeIncorrectTransactionalMD5() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 2048, shareClient) + + contentR, _ := testcommon.GenerateData(2048) + _, incorrectMD5 := testcommon.GenerateData(16) + + // Upload range with incorrect transactional MD5 + _, err = fClient.UploadRange(context.Background(), 0, contentR, &file.UploadRangeOptions{ + TransactionalValidation: file.TransferValidationTypeMD5(incorrectMD5[:]), + }) + _require.Error(err) + testcommon.ValidateFileErrorCode(_require, err, fileerror.MD5Mismatch) +} + +// Testings for GetRangeList and ClearRange +func (f *FileRecordedTestsSuite) TestGetRangeListNonDefaultExact() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.GetFileClientFromShare(testcommon.GenerateFileName(testName), shareClient) + + fileSize := int64(5 * 1024) + _, err = fClient.Create(context.Background(), fileSize, &file.CreateOptions{HTTPHeaders: &file.HTTPHeaders{}}) + _require.NoError(err) + defer testcommon.DeleteFile(context.Background(), _require, fClient) + + rsc, _ := testcommon.GenerateData(1024) + putResp, err := fClient.UploadRange(context.Background(), 0, rsc, nil) + _require.NoError(err) + _require.Equal(putResp.LastModified.IsZero(), false) + _require.NotNil(putResp.ETag) + _require.NotNil(putResp.ContentMD5) + _require.NotNil(putResp.RequestID) + _require.NotNil(putResp.Version) + _require.Equal(putResp.Date.IsZero(), false) + + rangeList, err := fClient.GetRangeList(context.Background(), &file.GetRangeListOptions{ + Range: file.HTTPRange{ + Offset: 0, + Count: fileSize, + }, + }) + _require.NoError(err) + _require.Equal(rangeList.LastModified.IsZero(), false) + _require.NotNil(rangeList.ETag) + _require.Equal(*rangeList.FileContentLength, fileSize) + _require.NotNil(rangeList.RequestID) + _require.NotNil(rangeList.Version) + _require.Equal(rangeList.Date.IsZero(), false) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, int64(1023)) +} + +// Default means clear the entire file's range +func (f *FileRecordedTestsSuite) TestClearRangeDefault() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 2048, shareClient) + defer testcommon.DeleteFile(context.Background(), _require, fClient) + + rsc, _ := testcommon.GenerateData(2048) + _, err = fClient.UploadRange(context.Background(), 0, rsc, nil) + _require.NoError(err) + + _, err = fClient.ClearRange(context.Background(), file.HTTPRange{Offset: 0, Count: 2048}, nil) + _require.NoError(err) + + rangeList, err := fClient.GetRangeList(context.Background(), &file.GetRangeListOptions{ + Range: file.HTTPRange{Offset: 0, Count: file.CountToEnd}, + }) + _require.NoError(err) + _require.Len(rangeList.Ranges, 0) +} + +func (f *FileRecordedTestsSuite) TestClearRangeNonDefault() { + _require := require.New(f.T()) + testName := f.T().Name() + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 4096, shareClient) + defer testcommon.DeleteFile(context.Background(), _require, fClient) + + rsc, _ := testcommon.GenerateData(2048) + _, err = fClient.UploadRange(context.Background(), 2048, rsc, nil) + _require.NoError(err) + + _, err = fClient.ClearRange(context.Background(), file.HTTPRange{Offset: 2048, Count: 2048}, nil) + _require.NoError(err) + + rangeList, err := fClient.GetRangeList(context.Background(), nil) + _require.NoError(err) + _require.Len(rangeList.Ranges, 0) +} + +func (f *FileRecordedTestsSuite) TestClearRangeMultipleRanges() { + _require := require.New(f.T()) + testName := f.T().Name() + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 2048, shareClient) + defer testcommon.DeleteFile(context.Background(), _require, fClient) + + rsc, _ := testcommon.GenerateData(2048) + _, err = fClient.UploadRange(context.Background(), 0, rsc, nil) + _require.NoError(err) + + _, err = fClient.ClearRange(context.Background(), file.HTTPRange{Offset: 1024, Count: 1024}, nil) + _require.NoError(err) + + rangeList, err := fClient.GetRangeList(context.Background(), nil) + _require.NoError(err) + _require.Len(rangeList.Ranges, 1) + _require.EqualValues(*rangeList.Ranges[0], file.ShareFileRange{Start: to.Ptr(int64(0)), End: to.Ptr(int64(1023))}) +} + +// When not 512 aligned, clear range will set 0 the non-512 aligned range, and will not eliminate the range. +func (f *FileRecordedTestsSuite) TestClearRangeNonDefaultCount() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 1, shareClient) + defer testcommon.DeleteFile(context.Background(), _require, fClient) + + d := []byte{65} + _, err = fClient.UploadRange(context.Background(), 0, streaming.NopCloser(bytes.NewReader(d)), nil) + _require.NoError(err) + + rangeList, err := fClient.GetRangeList(context.Background(), &file.GetRangeListOptions{ + Range: file.HTTPRange{Offset: 0, Count: file.CountToEnd}, + }) + _require.NoError(err) + _require.Len(rangeList.Ranges, 1) + _require.EqualValues(*rangeList.Ranges[0], file.ShareFileRange{Start: to.Ptr(int64(0)), End: to.Ptr(int64(0))}) + + _, err = fClient.ClearRange(context.Background(), file.HTTPRange{Offset: 0, Count: 1}, nil) + _require.NoError(err) + + rangeList, err = fClient.GetRangeList(context.Background(), &file.GetRangeListOptions{ + Range: file.HTTPRange{Offset: 0, Count: file.CountToEnd}, + }) + _require.NoError(err) + _require.Len(rangeList.Ranges, 1) + _require.EqualValues(*rangeList.Ranges[0], file.ShareFileRange{Start: to.Ptr(int64(0)), End: to.Ptr(int64(0))}) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + _bytes, err := ioutil.ReadAll(dResp.Body) + _require.NoError(err) + _require.EqualValues(_bytes, []byte{0}) +} + +func (f *FileRecordedTestsSuite) TestFileClearRangeNegativeInvalidCount() { + _require := require.New(f.T()) + testName := f.T().Name() + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.GetShareClient(testcommon.GenerateShareName(testName), svcClient) + fClient := testcommon.GetFileClientFromShare(testcommon.GenerateFileName(testName), shareClient) + + _, err = fClient.ClearRange(context.Background(), file.HTTPRange{Offset: 0, Count: 0}, nil) + _require.Error(err) + _require.Contains(err.Error(), "invalid argument: either offset is < 0 or count <= 0") +} + +func (f *FileRecordedTestsSuite) TestFileGetRangeListDefaultEmptyFile() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 0, shareClient) + + resp, err := fClient.GetRangeList(context.Background(), nil) + _require.NoError(err) + _require.Len(resp.Ranges, 0) +} + +func setupGetRangeListTest(_require *require.Assertions, testName string, fileSize int64, shareClient *share.Client) *file.Client { + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), fileSize, shareClient) + rsc, _ := testcommon.GenerateData(int(fileSize)) + _, err := fClient.UploadRange(context.Background(), 0, rsc, nil) + _require.NoError(err) + return fClient +} + +func (f *FileRecordedTestsSuite) TestFileGetRangeListDefaultRange() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fileSize := int64(512) + fClient := setupGetRangeListTest(_require, testName, fileSize, shareClient) + + resp, err := fClient.GetRangeList(context.Background(), &file.GetRangeListOptions{ + Range: file.HTTPRange{Offset: 0, Count: file.CountToEnd}, + }) + _require.NoError(err) + _require.Len(resp.Ranges, 1) + _require.EqualValues(*resp.Ranges[0], file.ShareFileRange{Start: to.Ptr(int64(0)), End: to.Ptr(fileSize - 1)}) +} + +func (f *FileRecordedTestsSuite) TestFileGetRangeListNonContiguousRanges() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fileSize := int64(512) + fClient := setupGetRangeListTest(_require, testName, fileSize, shareClient) + + _, err = fClient.Resize(context.Background(), fileSize*3, nil) + _require.NoError(err) + + rsc, _ := testcommon.GenerateData(int(fileSize)) + _, err = fClient.UploadRange(context.Background(), fileSize*2, rsc, nil) + _require.NoError(err) + + resp, err := fClient.GetRangeList(context.Background(), nil) + _require.NoError(err) + _require.Len(resp.Ranges, 2) + _require.EqualValues(*resp.Ranges[0], file.ShareFileRange{Start: to.Ptr(int64(0)), End: to.Ptr(fileSize - 1)}) + _require.EqualValues(*resp.Ranges[1], file.ShareFileRange{Start: to.Ptr(fileSize * 2), End: to.Ptr((fileSize * 3) - 1)}) +} + +func (f *FileRecordedTestsSuite) TestFileGetRangeListNonContiguousRangesCountLess() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fileSize := int64(512) + fClient := setupGetRangeListTest(_require, testName, fileSize, shareClient) + + resp, err := fClient.GetRangeList(context.Background(), &file.GetRangeListOptions{ + Range: file.HTTPRange{Offset: 0, Count: fileSize}, + }) + _require.NoError(err) + _require.Len(resp.Ranges, 1) + _require.EqualValues(int64(0), *(resp.Ranges[0].Start)) + _require.EqualValues(fileSize-1, *(resp.Ranges[0].End)) +} + +func (f *FileRecordedTestsSuite) TestFileGetRangeListNonContiguousRangesCountExceed() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fileSize := int64(512) + fClient := setupGetRangeListTest(_require, testName, fileSize, shareClient) + + resp, err := fClient.GetRangeList(context.Background(), &file.GetRangeListOptions{ + Range: file.HTTPRange{Offset: 0, Count: fileSize + 1}, + }) + _require.NoError(err) + _require.NoError(err) + _require.Len(resp.Ranges, 1) + _require.EqualValues(*resp.Ranges[0], file.ShareFileRange{Start: to.Ptr(int64(0)), End: to.Ptr(fileSize - 1)}) +} + +func (f *FileRecordedTestsSuite) TestFileGetRangeListSnapshot() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer func() { + _, err := shareClient.Delete(context.Background(), &share.DeleteOptions{DeleteSnapshots: to.Ptr(share.DeleteSnapshotsOptionTypeInclude)}) + _require.NoError(err) + }() + + fileSize := int64(512) + fClient := setupGetRangeListTest(_require, testName, fileSize, shareClient) + + resp, _ := shareClient.CreateSnapshot(context.Background(), nil) + _require.NotNil(resp.Snapshot) + + resp2, err := fClient.GetRangeList(context.Background(), &file.GetRangeListOptions{ + Range: file.HTTPRange{Offset: 0, Count: file.CountToEnd}, + ShareSnapshot: resp.Snapshot, + }) + _require.NoError(err) + _require.Len(resp2.Ranges, 1) + _require.EqualValues(*resp2.Ranges[0], file.ShareFileRange{Start: to.Ptr(int64(0)), End: to.Ptr(fileSize - 1)}) +} + +func (f *FileRecordedTestsSuite) TestFileUploadDownloadSmallBuffer() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + var fileSize int64 = 10 * 1024 + fClient := shareClient.NewRootDirectoryClient().NewFileClient(testcommon.GenerateFileName(testName)) + _, err = fClient.Create(context.Background(), fileSize, nil) + _require.NoError(err) + + gResp, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp.ContentLength, fileSize) + + _, content := testcommon.GenerateData(int(fileSize)) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + destBuffer := make([]byte, fileSize) + cnt, err := fClient.DownloadBuffer(context.Background(), destBuffer, &file.DownloadBufferOptions{ + ChunkSize: 2 * 1024, + Concurrency: 5, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + + downloadedMD5Value := md5.Sum(destBuffer) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + rangeList, err := fClient.GetRangeList(context.Background(), nil) + _require.NoError(err) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, fileSize-1) +} + +func (f *FileRecordedTestsSuite) TestFileUploadDownloadSmallFile() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + var fileSize int64 = 10 * 1024 + fClient := shareClient.NewRootDirectoryClient().NewFileClient(testcommon.GenerateFileName(testName)) + _, err = fClient.Create(context.Background(), fileSize, nil) + _require.NoError(err) + + gResp, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp.ContentLength, fileSize) + + // create local file + _, content := testcommon.GenerateData(int(fileSize)) + srcFileName := "testFileUpload" + err = ioutil.WriteFile(srcFileName, content, 0644) + _require.NoError(err) + defer func() { + err = os.Remove(srcFileName) + _require.NoError(err) + }() + fh, err := os.Open(srcFileName) + _require.NoError(err) + defer func(fh *os.File) { + err := fh.Close() + _require.NoError(err) + }(fh) + + srcHash := md5.New() + _, err = io.Copy(srcHash, fh) + _require.NoError(err) + contentMD5 := srcHash.Sum(nil) + + err = fClient.UploadFile(context.Background(), fh, &file.UploadFileOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + destFileName := "SmallFile-downloaded.bin" + destFile, err := os.Create(destFileName) + _require.NoError(err) + defer func(name string) { + err = os.Remove(name) + _require.NoError(err) + }(destFileName) + defer func(destFile *os.File) { + err = destFile.Close() + _require.NoError(err) + }(destFile) + + cnt, err := fClient.DownloadFile(context.Background(), destFile, &file.DownloadFileOptions{ + ChunkSize: 2 * 1024, + Concurrency: 5, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + + destHash := md5.New() + _, err = io.Copy(destHash, destFile) + _require.NoError(err) + downloadedContentMD5 := destHash.Sum(nil) + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + rangeList, err := fClient.GetRangeList(context.Background(), nil) + _require.NoError(err) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, fileSize-1) +} + +func (f *FileRecordedTestsSuite) TestFileUploadDownloadSmallStream() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + var fileSize int64 = 10 * 1024 + fClient := shareClient.NewRootDirectoryClient().NewFileClient(testcommon.GenerateFileName(testName)) + _, err = fClient.Create(context.Background(), fileSize, nil) + _require.NoError(err) + + gResp, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp.ContentLength, fileSize) + + _, content := testcommon.GenerateData(int(fileSize)) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadStream(context.Background(), streaming.NopCloser(bytes.NewReader(content)), &file.UploadStreamOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + rangeList, err := fClient.GetRangeList(context.Background(), nil) + _require.NoError(err) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, fileSize-1) +} + +func (f *FileRecordedTestsSuite) TestFileUploadDownloadWithProgress() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + var fileSize int64 = 10 * 1024 + fClient := shareClient.NewRootDirectoryClient().NewFileClient(testcommon.GenerateFileName(testName)) + _, err = fClient.Create(context.Background(), fileSize, nil) + _require.NoError(err) + + gResp, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp.ContentLength, fileSize) + + _, content := testcommon.GenerateData(int(fileSize)) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + bytesUploaded := int64(0) + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + Progress: func(bytesTransferred int64) { + _require.GreaterOrEqual(bytesTransferred, bytesUploaded) + bytesUploaded = bytesTransferred + }, + }) + _require.NoError(err) + _require.Equal(bytesUploaded, fileSize) + + destBuffer := make([]byte, fileSize) + bytesDownloaded := int64(0) + cnt, err := fClient.DownloadBuffer(context.Background(), destBuffer, &file.DownloadBufferOptions{ + ChunkSize: 2 * 1024, + Concurrency: 5, + Progress: func(bytesTransferred int64) { + _require.GreaterOrEqual(bytesTransferred, bytesDownloaded) + bytesDownloaded = bytesTransferred + }, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + _require.Equal(bytesDownloaded, fileSize) + + downloadedMD5Value := md5.Sum(destBuffer) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + rangeList, err := fClient.GetRangeList(context.Background(), nil) + _require.NoError(err) + _require.Len(rangeList.Ranges, 1) + _require.Equal(*rangeList.Ranges[0].Start, int64(0)) + _require.Equal(*rangeList.Ranges[0].End, fileSize-1) +} + +func (f *FileRecordedTestsSuite) TestFileListHandlesDefault() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 2048, shareClient) + + resp, err := fClient.ListHandles(context.Background(), nil) + _require.NoError(err) + _require.Len(resp.Handles, 0) + _require.NotNil(resp.NextMarker) + _require.Equal(*resp.NextMarker, "") +} + +func (f *FileRecordedTestsSuite) TestFileForceCloseHandlesDefault() { + _require := require.New(f.T()) + testName := f.T().Name() + + svcClient, err := testcommon.GetServiceClient(f.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareClient := testcommon.CreateNewShare(context.Background(), _require, testcommon.GenerateShareName(testName), svcClient) + defer testcommon.DeleteShare(context.Background(), _require, shareClient) + + fClient := testcommon.CreateNewFileFromShare(context.Background(), _require, testcommon.GenerateFileName(testName), 2048, shareClient) + + resp, err := fClient.ForceCloseHandles(context.Background(), "*", nil) + _require.NoError(err) + _require.EqualValues(*resp.NumberOfHandlesClosed, 0) + _require.EqualValues(*resp.NumberOfHandlesFailedToClose, 0) + _require.Nil(resp.Marker) +} + +// TODO: Add tests for GetRangeList, ListHandles and ForceCloseHandles + +// TODO: Add tests for retry header options diff --git a/sdk/storage/azfile/file/constants.go b/sdk/storage/azfile/file/constants.go index 0498935ed4c3..c5687bd1b3b5 100644 --- a/sdk/storage/azfile/file/constants.go +++ b/sdk/storage/azfile/file/constants.go @@ -12,13 +12,17 @@ import ( ) const ( - _1MiB = 1024 * 1024 + _1MiB = 1024 * 1024 + CountToEnd = 0 // MaxUpdateRangeBytes indicates the maximum number of bytes that can be updated in a call to Client.UploadRange. MaxUpdateRangeBytes = 4 * 1024 * 1024 // 4MiB // MaxFileSize indicates the maximum size of the file allowed. MaxFileSize = 4 * 1024 * 1024 * 1024 * 1024 // 4 TiB + + // DefaultDownloadChunkSize is default chunk size + DefaultDownloadChunkSize = int64(4 * 1024 * 1024) // 4MiB ) // CopyStatusType defines the states of the copy operation. diff --git a/sdk/storage/azfile/file/mmf_linux.go b/sdk/storage/azfile/file/mmf_linux.go index 93d718c8487f..dc17528e6516 100644 --- a/sdk/storage/azfile/file/mmf_linux.go +++ b/sdk/storage/azfile/file/mmf_linux.go @@ -2,6 +2,9 @@ // +build go1.18 // +build linux darwin freebsd openbsd netbsd solaris +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + package file import ( diff --git a/sdk/storage/azfile/file/models.go b/sdk/storage/azfile/file/models.go index 11f36beda260..ea5b7e20c333 100644 --- a/sdk/storage/azfile/file/models.go +++ b/sdk/storage/azfile/file/models.go @@ -339,15 +339,51 @@ func (o *DownloadStreamOptions) format() (*generated.FileClientDownloadOptions, // --------------------------------------------------------------------------------------------------------------------- +// downloadOptions contains common options used by the Client.DownloadBuffer and Client.DownloadFile methods. +type downloadOptions struct { + // Range specifies a range of bytes. The default value is all bytes. + Range HTTPRange + + // ChunkSize specifies the chunk size to use for each parallel download; the default size is 4MB. + ChunkSize int64 + + // Progress is a function that is invoked periodically as bytes are received. + Progress func(bytesTransferred int64) + + // LeaseAccessConditions contains optional parameters to access leased entity. + LeaseAccessConditions *LeaseAccessConditions + + // Concurrency indicates the maximum number of chunks to download in parallel (0=default). + Concurrency uint16 + + // RetryReaderOptionsPerChunk is used when downloading each chunk. + RetryReaderOptionsPerChunk RetryReaderOptions +} + +func (o *downloadOptions) getFilePropertiesOptions() *GetPropertiesOptions { + if o == nil { + return nil + } + return &GetPropertiesOptions{ + LeaseAccessConditions: o.LeaseAccessConditions, + } +} + +func (o *downloadOptions) getDownloadFileOptions(rng HTTPRange) *DownloadStreamOptions { + downloadFileOptions := &DownloadStreamOptions{ + Range: rng, + } + if o != nil { + downloadFileOptions.LeaseAccessConditions = o.LeaseAccessConditions + } + return downloadFileOptions +} + // DownloadBufferOptions contains the optional parameters for the Client.DownloadBuffer method. type DownloadBufferOptions struct { // Range specifies a range of bytes. The default value is all bytes. Range HTTPRange - // When this header is set to true and specified together with the Range header, the service returns the MD5 hash for the - // range, as long as the range is less than or equal to 4 MB in size. - RangeGetContentMD5 *bool - // ChunkSize specifies the chunk size to use for each parallel download; the default size is 4MB. ChunkSize int64 @@ -360,8 +396,8 @@ type DownloadBufferOptions struct { // Concurrency indicates the maximum number of chunks to download in parallel (0=default). Concurrency uint16 - // RetryReaderOptionsPerRange is used when downloading each chunk. - RetryReaderOptionsPerRange RetryReaderOptions + // RetryReaderOptionsPerChunk is used when downloading each chunk. + RetryReaderOptionsPerChunk RetryReaderOptions } // --------------------------------------------------------------------------------------------------------------------- @@ -371,10 +407,6 @@ type DownloadFileOptions struct { // Range specifies a range of bytes. The default value is all bytes. Range HTTPRange - // When this header is set to true and specified together with the Range header, the service returns the MD5 hash for the - // range, as long as the range is less than or equal to 4 MB in size. - RangeGetContentMD5 *bool - // ChunkSize specifies the chunk size to use for each parallel download; the default size is 4MB. ChunkSize int64 @@ -387,8 +419,8 @@ type DownloadFileOptions struct { // Concurrency indicates the maximum number of chunks to download in parallel (0=default). Concurrency uint16 - // RetryReaderOptionsPerRange is used when downloading each chunk. - RetryReaderOptionsPerRange RetryReaderOptions + // RetryReaderOptionsPerChunk is used when downloading each chunk. + RetryReaderOptionsPerChunk RetryReaderOptions } // --------------------------------------------------------------------------------------------------------------------- @@ -456,7 +488,7 @@ func (o *UploadRangeOptions) format(offset int64, body io.ReadSeekCloser) (strin leaseAccessConditions = o.LeaseAccessConditions } if o != nil && o.TransactionalValidation != nil { - body, err = o.TransactionalValidation.Apply(body, uploadRangeOptions) + _, err = o.TransactionalValidation.Apply(body, uploadRangeOptions) if err != nil { return "", 0, nil, nil, err } @@ -667,6 +699,7 @@ func (o *uploadFromReaderOptions) getUploadRangeOptions() *UploadRangeOptions { // UploadStreamOptions provides set of configurations for Client.UploadStream operation. type UploadStreamOptions struct { // ChunkSize defines the size of the buffer used during upload. The default and minimum value is 1 MiB. + // Maximum size of a chunk is MaxUpdateRangeBytes. ChunkSize int64 // Concurrency defines the max number of concurrent uploads to be performed to upload the file. diff --git a/sdk/storage/azfile/file/responses.go b/sdk/storage/azfile/file/responses.go index 208c38666b13..e47d87741861 100644 --- a/sdk/storage/azfile/file/responses.go +++ b/sdk/storage/azfile/file/responses.go @@ -7,7 +7,9 @@ package file import ( + "context" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/generated" + "io" ) // CreateResponse contains the response from method Client.Create. @@ -38,8 +40,32 @@ type DownloadResponse = generated.FileClientDownloadResponse // To read from the stream, read from the Body field, or call the NewRetryReader method. type DownloadStreamResponse struct { DownloadResponse - client *Client - getInfo httpGetterInfo + + client *Client + getInfo httpGetterInfo + leaseAccessConditions *LeaseAccessConditions +} + +// NewRetryReader constructs new RetryReader stream for reading data. If a connection fails while +// reading, it will make additional requests to reestablish a connection and continue reading. +// Pass nil for options to accept the default options. +// Callers of this method should not access the DownloadStreamResponse.Body field. +func (r *DownloadStreamResponse) NewRetryReader(ctx context.Context, options *RetryReaderOptions) *RetryReader { + if options == nil { + options = &RetryReaderOptions{} + } + + return newRetryReader(ctx, r.Body, r.getInfo, func(ctx context.Context, getInfo httpGetterInfo) (io.ReadCloser, error) { + options := DownloadStreamOptions{ + Range: getInfo.Range, + LeaseAccessConditions: r.leaseAccessConditions, + } + resp, err := r.client.DownloadStream(ctx, &options) + if err != nil { + return nil, err + } + return resp.Body, err + }, *options) } // ResizeResponse contains the response from method Client.Resize. diff --git a/sdk/storage/azfile/file/retry_reader.go b/sdk/storage/azfile/file/retry_reader.go index 723f31fe72f4..2e76a91f3169 100644 --- a/sdk/storage/azfile/file/retry_reader.go +++ b/sdk/storage/azfile/file/retry_reader.go @@ -6,16 +6,21 @@ package file -import "github.com/Azure/azure-sdk-for-go/sdk/azcore" +import ( + "context" + "io" + "net" + "strings" + "sync" +) + +// HTTPGetter is a function type that refers to a method that performs an HTTP GET operation. +type httpGetter func(ctx context.Context, i httpGetterInfo) (io.ReadCloser, error) // httpGetterInfo is passed to an HTTPGetter function passing it parameters // that should be used to make an HTTP GET request. type httpGetterInfo struct { Range HTTPRange - - // ETag specifies the resource's etag that should be used when creating - // the HTTP GET request's If-Match header - ETag *azcore.ETag } // RetryReaderOptions configures the retry reader's behavior. @@ -46,3 +51,136 @@ type RetryReaderOptions struct { doInjectErrorRound int32 injectedError error } + +// RetryReader attempts to read from response, and if there is a retry-able network error +// returned during reading, it will retry according to retry reader option through executing +// user defined action with provided data to get a new response, and continue the overall reading process +// through reading from the new response. +// RetryReader implements the io.ReadCloser interface. +type RetryReader struct { + ctx context.Context + info httpGetterInfo + retryReaderOptions RetryReaderOptions + getter httpGetter + countWasBounded bool + + // we support Close-ing during Reads (from other goroutines), so we protect the shared state, which is response + responseMu *sync.Mutex + response io.ReadCloser +} + +// newRetryReader creates a retry reader. +func newRetryReader(ctx context.Context, initialResponse io.ReadCloser, info httpGetterInfo, getter httpGetter, o RetryReaderOptions) *RetryReader { + if o.MaxRetries < 1 { + o.MaxRetries = 3 + } + return &RetryReader{ + ctx: ctx, + getter: getter, + info: info, + countWasBounded: info.Range.Count != CountToEnd, + response: initialResponse, + responseMu: &sync.Mutex{}, + retryReaderOptions: o, + } +} + +// setResponse function +func (s *RetryReader) setResponse(r io.ReadCloser) { + s.responseMu.Lock() + defer s.responseMu.Unlock() + s.response = r +} + +// Read from retry reader +func (s *RetryReader) Read(p []byte) (n int, err error) { + for try := int32(0); ; try++ { + //fmt.Println(try) // Comment out for debugging. + if s.countWasBounded && s.info.Range.Count == CountToEnd { + // User specified an original count and the remaining bytes are 0, return 0, EOF + return 0, io.EOF + } + + s.responseMu.Lock() + resp := s.response + s.responseMu.Unlock() + if resp == nil { // We don't have a response stream to read from, try to get one. + newResponse, err := s.getter(s.ctx, s.info) + if err != nil { + return 0, err + } + // Successful GET; this is the network stream we'll read from. + s.setResponse(newResponse) + resp = newResponse + } + n, err := resp.Read(p) // Read from the stream (this will return non-nil err if forceRetry is called, from another goroutine, while it is running) + + // Injection mechanism for testing. + if s.retryReaderOptions.doInjectError && try == s.retryReaderOptions.doInjectErrorRound { + if s.retryReaderOptions.injectedError != nil { + err = s.retryReaderOptions.injectedError + } else { + err = &net.DNSError{IsTemporary: true} + } + } + + // We successfully read data or end EOF. + if err == nil || err == io.EOF { + s.info.Range.Offset += int64(n) // Increments the start offset in case we need to make a new HTTP request in the future + if s.info.Range.Count != CountToEnd { + s.info.Range.Count -= int64(n) // Decrement the count in case we need to make a new HTTP request in the future + } + return n, err // Return the return to the caller + } + _ = s.Close() + + s.setResponse(nil) // Our stream is no longer good + + // Check the retry count and error code, and decide whether to retry. + retriesExhausted := try >= s.retryReaderOptions.MaxRetries + _, isNetError := err.(net.Error) + isUnexpectedEOF := err == io.ErrUnexpectedEOF + willRetry := (isNetError || isUnexpectedEOF || s.wasRetryableEarlyClose(err)) && !retriesExhausted + + // Notify, for logging purposes, of any failures + if s.retryReaderOptions.OnFailedRead != nil { + failureCount := try + 1 // because try is zero-based + s.retryReaderOptions.OnFailedRead(failureCount, err, s.info.Range, willRetry) + } + + if willRetry { + continue + // Loop around and try to get and read from new stream. + } + return n, err // Not retryable, or retries exhausted, so just return + } +} + +// By default, we allow early Closing, from another concurrent goroutine, to be used to force a retry +// Is this safe, to close early from another goroutine? Early close ultimately ends up calling +// net.Conn.Close, and that is documented as "Any blocked Read or Write operations will be unblocked and return errors" +// which is exactly the behaviour we want. +// NOTE: that if caller has forced an early Close from a separate goroutine (separate from the Read) +// then there are two different types of error that may happen - either the one we check for here, +// or a net.Error (due to closure of connection). Which one happens depends on timing. We only need this routine +// to check for one, since the other is a net.Error, which our main Read retry loop is already handing. +func (s *RetryReader) wasRetryableEarlyClose(err error) bool { + if s.retryReaderOptions.EarlyCloseAsError { + return false // user wants all early closes to be errors, and so not retryable + } + // unfortunately, http.errReadOnClosedResBody is private, so the best we can do here is to check for its text + return strings.HasSuffix(err.Error(), ReadOnClosedBodyMessage) +} + +// ReadOnClosedBodyMessage of retry reader +const ReadOnClosedBodyMessage = "read on closed response body" + +// Close retry reader +func (s *RetryReader) Close() error { + s.responseMu.Lock() + defer s.responseMu.Unlock() + if s.response != nil { + return s.response.Close() + } + return nil +} diff --git a/sdk/storage/azfile/internal/shared/bytes_writer.go b/sdk/storage/azfile/internal/shared/bytes_writer.go new file mode 100644 index 000000000000..8d4d35bdeffd --- /dev/null +++ b/sdk/storage/azfile/internal/shared/bytes_writer.go @@ -0,0 +1,30 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package shared + +import ( + "errors" +) + +type bytesWriter []byte + +func NewBytesWriter(b []byte) bytesWriter { + return b +} + +func (c bytesWriter) WriteAt(b []byte, off int64) (int, error) { + if off >= int64(len(c)) || off < 0 { + return 0, errors.New("offset value is out of range") + } + + n := copy(c[int(off):], b) + if n < len(b) { + return n, errors.New("not enough space for all bytes") + } + + return n, nil +} diff --git a/sdk/storage/azfile/internal/shared/bytes_writer_test.go b/sdk/storage/azfile/internal/shared/bytes_writer_test.go new file mode 100644 index 000000000000..5f1bc53c29ca --- /dev/null +++ b/sdk/storage/azfile/internal/shared/bytes_writer_test.go @@ -0,0 +1,37 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package shared + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBytesWriter(t *testing.T) { + b := make([]byte, 10) + buffer := NewBytesWriter(b) + + count, err := buffer.WriteAt([]byte{1, 2}, 10) + require.Contains(t, err.Error(), "offset value is out of range") + require.Equal(t, count, 0) + + count, err = buffer.WriteAt([]byte{1, 2}, -1) + require.Contains(t, err.Error(), "offset value is out of range") + require.Equal(t, count, 0) + + count, err = buffer.WriteAt([]byte{1, 2}, 9) + require.Contains(t, err.Error(), "not enough space for all bytes") + require.Equal(t, count, 1) + require.Equal(t, bytes.Compare(b, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 1}), 0) + + count, err = buffer.WriteAt([]byte{1, 2}, 8) + require.NoError(t, err) + require.Equal(t, count, 2) + require.Equal(t, bytes.Compare(b, []byte{0, 0, 0, 0, 0, 0, 0, 0, 1, 2}), 0) +} diff --git a/sdk/storage/azfile/internal/shared/section_writer.go b/sdk/storage/azfile/internal/shared/section_writer.go new file mode 100644 index 000000000000..c8528a2e3ed2 --- /dev/null +++ b/sdk/storage/azfile/internal/shared/section_writer.go @@ -0,0 +1,53 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package shared + +import ( + "errors" + "io" +) + +type SectionWriter struct { + Count int64 + Offset int64 + Position int64 + WriterAt io.WriterAt +} + +func NewSectionWriter(c io.WriterAt, off int64, count int64) *SectionWriter { + return &SectionWriter{ + Count: count, + Offset: off, + WriterAt: c, + } +} + +func (c *SectionWriter) Write(p []byte) (int, error) { + remaining := c.Count - c.Position + + if remaining <= 0 { + return 0, errors.New("end of section reached") + } + + slice := p + + if int64(len(slice)) > remaining { + slice = slice[:remaining] + } + + n, err := c.WriterAt.WriteAt(slice, c.Offset+c.Position) + c.Position += int64(n) + if err != nil { + return n, err + } + + if len(p) > n { + return n, errors.New("not enough space for all bytes") + } + + return n, nil +} diff --git a/sdk/storage/azfile/internal/shared/section_writer_test.go b/sdk/storage/azfile/internal/shared/section_writer_test.go new file mode 100644 index 000000000000..a1cf22da410a --- /dev/null +++ b/sdk/storage/azfile/internal/shared/section_writer_test.go @@ -0,0 +1,98 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package shared + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSectionWriter(t *testing.T) { + b := [10]byte{} + buffer := NewBytesWriter(b[:]) + + section := NewSectionWriter(buffer, 0, 5) + require.Equal(t, section.Count, int64(5)) + require.Equal(t, section.Offset, int64(0)) + require.Equal(t, section.Position, int64(0)) + + count, err := section.Write([]byte{1, 2, 3}) + require.NoError(t, err) + require.Equal(t, count, 3) + require.Equal(t, section.Position, int64(3)) + require.Equal(t, b, [10]byte{1, 2, 3, 0, 0, 0, 0, 0, 0, 0}) + + count, err = section.Write([]byte{4, 5, 6}) + require.Contains(t, err.Error(), "not enough space for all bytes") + require.Equal(t, count, 2) + require.Equal(t, section.Position, int64(5)) + require.Equal(t, b, [10]byte{1, 2, 3, 4, 5, 0, 0, 0, 0, 0}) + + count, err = section.Write([]byte{6, 7, 8}) + require.Contains(t, err.Error(), "end of section reached") + require.Equal(t, count, 0) + require.Equal(t, section.Position, int64(5)) + require.Equal(t, b, [10]byte{1, 2, 3, 4, 5, 0, 0, 0, 0, 0}) + + // Intentionally create a section writer which will attempt to write + // outside the bounds of the buffer. + section = NewSectionWriter(buffer, 5, 6) + require.Equal(t, section.Count, int64(6)) + require.Equal(t, section.Offset, int64(5)) + require.Equal(t, section.Position, int64(0)) + + count, err = section.Write([]byte{6, 7, 8}) + require.NoError(t, err) + require.Equal(t, count, 3) + require.Equal(t, section.Position, int64(3)) + require.Equal(t, b, [10]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0}) + + // Attempt to write past the end of the section. Since the underlying + // buffer rejects the write it gives the same error as in the normal case. + count, err = section.Write([]byte{9, 10, 11}) + require.Contains(t, err.Error(), "not enough space for all bytes") + require.Equal(t, count, 2) + require.Equal(t, section.Position, int64(5)) + require.Equal(t, b, [10]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + + // Attempt to write past the end of the buffer. In this case the buffer + // rejects the write completely since it falls completely out of bounds. + count, err = section.Write([]byte{11, 12, 13}) + require.Contains(t, err.Error(), "offset value is out of range") + require.Equal(t, count, 0) + require.Equal(t, section.Position, int64(5)) + require.Equal(t, b, [10]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) +} + +func TestSectionWriterCopySrcDestEmpty(t *testing.T) { + input := make([]byte, 0) + reader := bytes.NewReader(input) + + output := make([]byte, 0) + buffer := NewBytesWriter(output) + section := NewSectionWriter(buffer, 0, 0) + + count, err := io.Copy(section, reader) + require.NoError(t, err) + require.Equal(t, count, int64(0)) +} + +func TestSectionWriterCopyDestEmpty(t *testing.T) { + input := make([]byte, 10) + reader := bytes.NewReader(input) + + output := make([]byte, 0) + buffer := NewBytesWriter(output) + section := NewSectionWriter(buffer, 0, 0) + + count, err := io.Copy(section, reader) + require.Contains(t, err.Error(), "end of section reached") + require.Equal(t, count, int64(0)) +} diff --git a/sdk/storage/azfile/internal/shared/shared_test.go b/sdk/storage/azfile/internal/shared/shared_test.go new file mode 100644 index 000000000000..1cd5da99469d --- /dev/null +++ b/sdk/storage/azfile/internal/shared/shared_test.go @@ -0,0 +1,95 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package shared + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseConnectionStringInvalid(t *testing.T) { + badConnectionStrings := []string{ + "", + "foobar", + "foo;bar;baz", + "foo=;bar=;", + "=", + ";", + "=;==", + "foobar=baz=foo", + } + + for _, badConnStr := range badConnectionStrings { + parsed, err := ParseConnectionString(badConnStr) + require.Error(t, err) + require.Zero(t, parsed) + } +} + +func TestParseConnectionString(t *testing.T) { + connStr := "DefaultEndpointsProtocol=https;AccountName=dummyaccount;AccountKey=secretkeykey;EndpointSuffix=core.windows.net" + parsed, err := ParseConnectionString(connStr) + require.NoError(t, err) + require.Equal(t, "https://dummyaccount.file.core.windows.net", parsed.ServiceURL) + require.Equal(t, "dummyaccount", parsed.AccountName) + require.Equal(t, "secretkeykey", parsed.AccountKey) +} + +func TestParseConnectionStringHTTP(t *testing.T) { + connStr := "DefaultEndpointsProtocol=http;AccountName=dummyaccount;AccountKey=secretkeykey;EndpointSuffix=core.windows.net" + parsed, err := ParseConnectionString(connStr) + require.NoError(t, err) + require.Equal(t, "http://dummyaccount.file.core.windows.net", parsed.ServiceURL) + require.Equal(t, "dummyaccount", parsed.AccountName) + require.Equal(t, "secretkeykey", parsed.AccountKey) +} + +func TestParseConnectionStringBasic(t *testing.T) { + connStr := "AccountName=dummyaccount;AccountKey=secretkeykey" + parsed, err := ParseConnectionString(connStr) + require.NoError(t, err) + require.Equal(t, "https://dummyaccount.file.core.windows.net", parsed.ServiceURL) + require.Equal(t, "dummyaccount", parsed.AccountName) + require.Equal(t, "secretkeykey", parsed.AccountKey) +} + +func TestParseConnectionStringCustomDomain(t *testing.T) { + connStr := "AccountName=dummyaccount;AccountKey=secretkeykey;FileEndpoint=www.mydomain.com;" + parsed, err := ParseConnectionString(connStr) + require.NoError(t, err) + require.Equal(t, "www.mydomain.com", parsed.ServiceURL) + require.Equal(t, "dummyaccount", parsed.AccountName) + require.Equal(t, "secretkeykey", parsed.AccountKey) +} + +func TestParseConnectionStringSAS(t *testing.T) { + connStr := "AccountName=dummyaccount;SharedAccessSignature=fakesharedaccesssignature;" + parsed, err := ParseConnectionString(connStr) + require.NoError(t, err) + require.Equal(t, "https://dummyaccount.file.core.windows.net/?fakesharedaccesssignature", parsed.ServiceURL) + require.Empty(t, parsed.AccountName) + require.Empty(t, parsed.AccountKey) +} + +func TestParseConnectionStringChinaCloud(t *testing.T) { + connStr := "AccountName=dummyaccountname;AccountKey=secretkeykey;DefaultEndpointsProtocol=http;EndpointSuffix=core.chinacloudapi.cn;" + parsed, err := ParseConnectionString(connStr) + require.NoError(t, err) + require.Equal(t, "http://dummyaccountname.file.core.chinacloudapi.cn", parsed.ServiceURL) + require.Equal(t, "dummyaccountname", parsed.AccountName) + require.Equal(t, "secretkeykey", parsed.AccountKey) +} + +func TestCParseConnectionStringAzurite(t *testing.T) { + connStr := "DefaultEndpointsProtocol=http;AccountName=dummyaccountname;AccountKey=secretkeykey;FileEndpoint=http://local-machine:11002/custom/account/path/faketokensignature;" + parsed, err := ParseConnectionString(connStr) + require.NoError(t, err) + require.Equal(t, "http://local-machine:11002/custom/account/path/faketokensignature", parsed.ServiceURL) + require.Equal(t, "dummyaccountname", parsed.AccountName) + require.Equal(t, "secretkeykey", parsed.AccountKey) +} diff --git a/sdk/storage/azfile/internal/testcommon/clients_auth.go b/sdk/storage/azfile/internal/testcommon/clients_auth.go index 1e1bad7bda4c..b7804577c2fb 100644 --- a/sdk/storage/azfile/internal/testcommon/clients_auth.go +++ b/sdk/storage/azfile/internal/testcommon/clients_auth.go @@ -12,12 +12,15 @@ import ( "errors" "fmt" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/directory" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/file" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/service" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/share" "github.com/stretchr/testify/require" + "strings" "testing" ) @@ -167,3 +170,33 @@ func DeleteDirectory(ctx context.Context, _require *require.Assertions, dirClien _, err := dirClient.Delete(ctx, nil) _require.NoError(err) } + +func GetFileClientFromShare(fileName string, shareClient *share.Client) *file.Client { + return shareClient.NewRootDirectoryClient().NewFileClient(fileName) +} + +func CreateNewFileFromShare(ctx context.Context, _require *require.Assertions, fileName string, fileSize int64, shareClient *share.Client) *file.Client { + fClient := GetFileClientFromShare(fileName, shareClient) + + _, err := fClient.Create(ctx, fileSize, nil) + _require.NoError(err) + + return fClient +} + +func CreateNewFileFromShareWithData(ctx context.Context, _require *require.Assertions, fileName string, shareClient *share.Client) *file.Client { + fClient := GetFileClientFromShare(fileName, shareClient) + + _, err := fClient.Create(ctx, int64(len(FileDefaultData)), nil) + _require.NoError(err) + + _, err = fClient.UploadRange(ctx, 0, streaming.NopCloser(strings.NewReader(FileDefaultData)), nil) + _require.NoError(err) + + return fClient +} + +func DeleteFile(ctx context.Context, _require *require.Assertions, fileClient *file.Client) { + _, err := fileClient.Delete(ctx, nil) + _require.NoError(err) +} diff --git a/sdk/storage/azfile/internal/testcommon/common.go b/sdk/storage/azfile/internal/testcommon/common.go index f81f0b982cad..e83f8d00114d 100644 --- a/sdk/storage/azfile/internal/testcommon/common.go +++ b/sdk/storage/azfile/internal/testcommon/common.go @@ -8,12 +8,16 @@ package testcommon import ( + "bytes" "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/fileerror" "github.com/stretchr/testify/require" + "io" "os" + "strconv" "strings" "testing" ) @@ -22,6 +26,7 @@ const ( SharePrefix = "gos" DirectoryPrefix = "godir" FilePrefix = "gotestfile" + FileDefaultData = "GoFileDefaultData" ) func GenerateShareName(testName string) string { @@ -40,6 +45,34 @@ func GenerateFileName(testName string) string { return FilePrefix + GenerateEntityName(testName) } +const random64BString string = "2SDgZj6RkKYzJpu04sweQek4uWHO8ndPnYlZ0tnFS61hjnFZ5IkvIGGY44eKABov" + +func GenerateData(sizeInBytes int) (io.ReadSeekCloser, []byte) { + data := make([]byte, sizeInBytes) + _len := len(random64BString) + if sizeInBytes > _len { + count := sizeInBytes / _len + if sizeInBytes%_len != 0 { + count = count + 1 + } + copy(data[:], strings.Repeat(random64BString, count)) + } else { + copy(data[:], random64BString) + } + return streaming.NopCloser(bytes.NewReader(data)), data +} + +func ValidateHTTPErrorCode(_require *require.Assertions, err error, code int) { + _require.Error(err) + var responseErr *azcore.ResponseError + errors.As(err, &responseErr) + if responseErr != nil { + _require.Equal(responseErr.StatusCode, code) + } else { + _require.Equal(strings.Contains(err.Error(), strconv.Itoa(code)), true) + } +} + func ValidateFileErrorCode(_require *require.Assertions, err error, code fileerror.Code) { _require.Error(err) var responseErr *azcore.ResponseError diff --git a/sdk/storage/azfile/sas/service.go b/sdk/storage/azfile/sas/service.go index 061b8336ee18..50192f9ef58b 100644 --- a/sdk/storage/azfile/sas/service.go +++ b/sdk/storage/azfile/sas/service.go @@ -20,21 +20,21 @@ import ( // For more information on creating service sas, see https://docs.microsoft.com/rest/api/storageservices/constructing-a-service-sas // User Delegation SAS not supported for files service type SignatureValues struct { - Version string `param:"sv"` // If not specified, this defaults to Version - Protocol Protocol `param:"spr"` // See the Protocol* constants - StartTime time.Time `param:"st"` // Not specified if IsZero - ExpiryTime time.Time `param:"se"` // Not specified if IsZero - SnapshotTime time.Time - Permissions string `param:"sp"` // Create by initializing SharePermissions or FilePermissions and then call String() - IPRange IPRange `param:"sip"` - Identifier string `param:"si"` - ShareName string - DirectoryOrFilePath string // Ex: "directory/FileName". Use "" to create a Share SAS, directory path for Directory SAS and file path for File SAS. - CacheControl string // rscc - ContentDisposition string // rscd - ContentEncoding string // rsce - ContentLanguage string // rscl - ContentType string // rsct + Version string `param:"sv"` // If not specified, this defaults to Version + Protocol Protocol `param:"spr"` // See the Protocol* constants + StartTime time.Time `param:"st"` // Not specified if IsZero + ExpiryTime time.Time `param:"se"` // Not specified if IsZero + SnapshotTime time.Time + Permissions string `param:"sp"` // Create by initializing SharePermissions or FilePermissions and then call String() + IPRange IPRange `param:"sip"` + Identifier string `param:"si"` + ShareName string + FilePath string // Ex: "directory/FileName". Use "" to create a Share SAS and file path for File SAS. + CacheControl string // rscc + ContentDisposition string // rscd + ContentEncoding string // rsce + ContentLanguage string // rscl + ContentType string // rsct } // SignWithSharedKey uses an account's SharedKeyCredential to sign this signature values to produce the proper SAS query parameters. @@ -44,7 +44,7 @@ func (v SignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCredent } resource := "s" - if v.DirectoryOrFilePath == "" { + if v.FilePath == "" { //Make sure the permission characters are in the correct order perms, err := parseSharePermissions(v.Permissions) if err != nil { @@ -71,7 +71,7 @@ func (v SignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCredent v.Permissions, startTime, expiryTime, - getCanonicalName(sharedKeyCredential.AccountName(), v.ShareName, v.DirectoryOrFilePath), + getCanonicalName(sharedKeyCredential.AccountName(), v.ShareName, v.FilePath), v.Identifier, v.IPRange.String(), string(v.Protocol), diff --git a/sdk/storage/azfile/service/client_test.go b/sdk/storage/azfile/service/client_test.go index 236da07ba3f7..d9c3642c4628 100644 --- a/sdk/storage/azfile/service/client_test.go +++ b/sdk/storage/azfile/service/client_test.go @@ -365,6 +365,90 @@ func (s *ServiceRecordedTestsSuite) TestSASServiceClientSignNegative() { Create: true, } expiry := time.Time{} - _, err = serviceClient.GetSASURL(resources, permissions, expiry, nil) + + // zero expiry time + _, err = serviceClient.GetSASURL(resources, permissions, expiry, &service.GetSASURLOptions{StartTime: to.Ptr(time.Now())}) + _require.Equal(err.Error(), "account SAS is missing at least one of these: ExpiryTime, Permissions, Service, or ResourceType") + + // zero start and expiry time + _, err = serviceClient.GetSASURL(resources, permissions, expiry, &service.GetSASURLOptions{}) + _require.Equal(err.Error(), "account SAS is missing at least one of these: ExpiryTime, Permissions, Service, or ResourceType") + + // empty permissions + _, err = serviceClient.GetSASURL(sas.AccountResourceTypes{}, sas.AccountPermissions{}, expiry, nil) _require.Equal(err.Error(), "account SAS is missing at least one of these: ExpiryTime, Permissions, Service, or ResourceType") } + +func (s *ServiceRecordedTestsSuite) TestServiceSetPropertiesDefault() { + _require := require.New(s.T()) + + svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + _, err = svcClient.SetProperties(context.Background(), nil) + _require.NoError(err) +} + +func (s *ServiceRecordedTestsSuite) TestServiceCreateDeleteRestoreShare() { + _require := require.New(s.T()) + testName := s.T().Name() + + svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareName := testcommon.GenerateShareName(testName) + + _, err = svcClient.CreateShare(context.Background(), shareName, nil) + _require.NoError(err) + + defer func() { + _, err := svcClient.DeleteShare(context.Background(), shareName, nil) + _require.NoError(err) + }() + + _, err = svcClient.DeleteShare(context.Background(), shareName, nil) + _require.NoError(err) + + // wait for share deletion + time.Sleep(60 * time.Second) + + sharesCnt := 0 + shareVersion := "" + + pager := svcClient.NewListSharesPager(&service.ListSharesOptions{ + Include: service.ListSharesInclude{Deleted: true}, + Prefix: &shareName, + }) + + for pager.More() { + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + for _, s := range resp.Shares { + if s.Deleted != nil && *s.Deleted { + _require.NotNil(s.Version) + shareVersion = *s.Version + } else { + sharesCnt++ + } + } + } + + _require.Equal(sharesCnt, 0) + _require.NotEmpty(shareVersion) + + restoreResp, err := svcClient.RestoreShare(context.Background(), shareName, shareVersion, nil) + _require.NoError(err) + _require.NotNil(restoreResp.RequestID) + + sharesCnt = 0 + pager = svcClient.NewListSharesPager(&service.ListSharesOptions{ + Prefix: &shareName, + }) + + for pager.More() { + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + sharesCnt += len(resp.Shares) + } + _require.Equal(sharesCnt, 1) +} diff --git a/sdk/storage/azfile/share/client_test.go b/sdk/storage/azfile/share/client_test.go index 80574dd3cbba..30bc1db399de 100644 --- a/sdk/storage/azfile/share/client_test.go +++ b/sdk/storage/azfile/share/client_test.go @@ -1446,6 +1446,15 @@ func (s *ShareRecordedTestsSuite) TestSASShareClientSignNegative() { } expiry := time.Time{} - _, err = shareClient.GetSASURL(permissions, expiry, nil) + // zero expiry time + _, err = shareClient.GetSASURL(permissions, expiry, &share.GetSASURLOptions{StartTime: to.Ptr(time.Now())}) + _require.Equal(err.Error(), "service SAS is missing at least one of these: ExpiryTime or Permissions") + + // zero start and expiry time + _, err = shareClient.GetSASURL(permissions, expiry, &share.GetSASURLOptions{}) + _require.Equal(err.Error(), "service SAS is missing at least one of these: ExpiryTime or Permissions") + + // empty permissions + _, err = shareClient.GetSASURL(sas.SharePermissions{}, expiry, nil) _require.Equal(err.Error(), "service SAS is missing at least one of these: ExpiryTime or Permissions") }