diff --git a/sdk/storage/azdatalake/common.go b/sdk/storage/azdatalake/common.go index 32a7862e8256..03fe643423db 100644 --- a/sdk/storage/azdatalake/common.go +++ b/sdk/storage/azdatalake/common.go @@ -16,8 +16,6 @@ type ClientOptions struct { azcore.ClientOptions } -const SnapshotTimeFormat = "2006-01-02T15:04:05.0000000Z07:00" - // AccessConditions identifies container-specific access conditions which you optionally set. type AccessConditions struct { ModifiedAccessConditions *ModifiedAccessConditions diff --git a/sdk/storage/azdatalake/internal/exported/exported.go b/sdk/storage/azdatalake/internal/exported/exported.go index 9bc1ca47df84..6a91ea05453a 100644 --- a/sdk/storage/azdatalake/internal/exported/exported.go +++ b/sdk/storage/azdatalake/internal/exported/exported.go @@ -11,6 +11,8 @@ import ( "strconv" ) +const SnapshotTimeFormat = "2006-01-02T15:04:05.0000000Z07:00" + // HTTPRange defines a range of bytes within an HTTP resource, starting at offset and // ending at offset+count. A zero-value HTTPRange indicates the entire resource. An HTTPRange // which has an offset but no zero value count indicates from the offset to the resource's end. diff --git a/sdk/storage/azdatalake/internal/exported/user_delegation_credential.go b/sdk/storage/azdatalake/internal/exported/user_delegation_credential.go new file mode 100644 index 000000000000..91b933bf5737 --- /dev/null +++ b/sdk/storage/azdatalake/internal/exported/user_delegation_credential.go @@ -0,0 +1,64 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service" +) + +// NewUserDelegationCredential creates a new UserDelegationCredential using a Storage account's Name and a user delegation Key from it +func NewUserDelegationCredential(accountName string, udk UserDelegationKey) *UserDelegationCredential { + return &UserDelegationCredential{ + accountName: accountName, + userDelegationKey: udk, + } +} + +// UserDelegationKey contains UserDelegationKey. +type UserDelegationKey = service.UserDelegationKey + +// UserDelegationCredential contains an account's name and its user delegation key. +type UserDelegationCredential struct { + accountName string + userDelegationKey UserDelegationKey +} + +// getAccountName returns the Storage account's Name +func (f *UserDelegationCredential) getAccountName() string { + return f.accountName +} + +// GetAccountName is a helper method for accessing the user delegation key parameters outside this package. +func GetAccountName(udc *UserDelegationCredential) string { + return udc.getAccountName() +} + +// computeHMACSHA256 generates a hash signature for an HTTP request or for a SAS. +func (f *UserDelegationCredential) computeHMACSHA256(message string) (string, error) { + bytes, _ := base64.StdEncoding.DecodeString(*f.userDelegationKey.Value) + h := hmac.New(sha256.New, bytes) + _, err := h.Write([]byte(message)) + return base64.StdEncoding.EncodeToString(h.Sum(nil)), err +} + +// ComputeUDCHMACSHA256 is a helper method for computing the signed string outside this package. 
+func ComputeUDCHMACSHA256(udc *UserDelegationCredential, message string) (string, error) {
+	return udc.computeHMACSHA256(message)
+}
+
+// getUDKParams returns UserDelegationKey
+func (f *UserDelegationCredential) getUDKParams() *UserDelegationKey {
+	return &f.userDelegationKey
+}
+
+// GetUDKParams is a helper method for accessing the user delegation key parameters outside this package.
+func GetUDKParams(udc *UserDelegationCredential) *UserDelegationKey {
+	return udc.getUDKParams()
+}
diff --git a/sdk/storage/azdatalake/sas/account.go b/sdk/storage/azdatalake/sas/account.go
new file mode 100644
index 000000000000..e5681c4ebdba
--- /dev/null
+++ b/sdk/storage/azdatalake/sas/account.go
@@ -0,0 +1,226 @@
+//go:build go1.18
+// +build go1.18
+
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See License.txt in the project root for license information.
+
+package sas
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/exported"
+)
+
+// SharedKeyCredential contains an account's name and its primary or secondary key.
+type SharedKeyCredential = exported.SharedKeyCredential
+
+// UserDelegationCredential contains an account's name and its user delegation key.
+type UserDelegationCredential = exported.UserDelegationCredential
+
+// AccountSignatureValues is used to generate a Shared Access Signature (SAS) for an Azure Storage account.
+// For more information, see https://docs.microsoft.com/rest/api/storageservices/constructing-an-account-sas
+type AccountSignatureValues struct {
+	Version       string    `param:"sv"`  // If not specified, this defaults to Version
+	Protocol      Protocol  `param:"spr"` // See the Protocol* constants
+	StartTime     time.Time `param:"st"`  // Not specified if IsZero
+	ExpiryTime    time.Time `param:"se"`  // Not specified if IsZero
+	Permissions   string    `param:"sp"`  // Create by initializing AccountPermissions and then call String()
+	IPRange       IPRange   `param:"sip"`
+	ResourceTypes string    `param:"srt"` // Create by initializing AccountResourceTypes and then call String()
+}
+
+// SignWithSharedKey uses an account's shared key credential to sign this signature values to produce
+// the proper SAS query parameters.
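+//
+// A rough usage sketch (cred is assumed to be a *SharedKeyCredential for the target
+// account, obtained elsewhere; it is not constructed in this package):
+//
+//	rt := AccountResourceTypes{Service: true, Container: true, Object: true}
+//	perms := AccountPermissions{Read: true, List: true}
+//	values := AccountSignatureValues{
+//		Protocol:      ProtocolHTTPS,
+//		ExpiryTime:    time.Now().UTC().Add(48 * time.Hour),
+//		Permissions:   perms.String(),
+//		ResourceTypes: rt.String(),
+//	}
+//	queryParams, err := values.SignWithSharedKey(cred)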
+func (v AccountSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCredential) (QueryParameters, error) { + // https://docs.microsoft.com/en-us/rest/api/storageservices/Constructing-an-Account-SAS + if v.ExpiryTime.IsZero() || v.Permissions == "" || v.ResourceTypes == "" { + return QueryParameters{}, errors.New("account SAS is missing at least one of these: ExpiryTime, Permissions, Service, or ResourceType") + } + if v.Version == "" { + v.Version = Version + } + perms, err := parseAccountPermissions(v.Permissions) + if err != nil { + return QueryParameters{}, err + } + v.Permissions = perms.String() + + resources, err := parseAccountResourceTypes(v.ResourceTypes) + if err != nil { + return QueryParameters{}, err + } + v.ResourceTypes = resources.String() + + startTime, expiryTime := formatTimesForSigning(v.StartTime, v.ExpiryTime) + + stringToSign := strings.Join([]string{ + sharedKeyCredential.AccountName(), + v.Permissions, + "b", // blob service + v.ResourceTypes, + startTime, + expiryTime, + v.IPRange.String(), + string(v.Protocol), + v.Version, + ""}, // That is right, the account SAS requires a terminating extra newline + "\n") + + signature, err := exported.ComputeHMACSHA256(sharedKeyCredential, stringToSign) + if err != nil { + return QueryParameters{}, err + } + p := QueryParameters{ + // Common SAS parameters + version: v.Version, + protocol: v.Protocol, + startTime: v.StartTime, + expiryTime: v.ExpiryTime, + permissions: v.Permissions, + ipRange: v.IPRange, + + // Account-specific SAS parameters + services: "b", // will always be "b" + resourceTypes: v.ResourceTypes, + + // Calculated SAS signature + signature: signature, + } + + return p, nil +} + +// AccountPermissions type simplifies creating the permissions string for an Azure Storage Account SAS. +// Initialize an instance of this type and then call its String method to set AccountSignatureValues' Permissions field. +type AccountPermissions struct { + Read, Write, Delete, DeletePreviousVersion, PermanentDelete, List, Add, Create, Update, Process, FilterByTags, Tag, SetImmutabilityPolicy bool +} + +// String produces the SAS permissions string for an Azure Storage account. +// Call this method to set AccountSignatureValues' Permissions field. +func (p *AccountPermissions) String() string { + var buffer bytes.Buffer + if p.Read { + buffer.WriteRune('r') + } + if p.Write { + buffer.WriteRune('w') + } + if p.Delete { + buffer.WriteRune('d') + } + if p.DeletePreviousVersion { + buffer.WriteRune('x') + } + if p.PermanentDelete { + buffer.WriteRune('y') + } + if p.List { + buffer.WriteRune('l') + } + if p.Add { + buffer.WriteRune('a') + } + if p.Create { + buffer.WriteRune('c') + } + if p.Update { + buffer.WriteRune('u') + } + if p.Process { + buffer.WriteRune('p') + } + if p.FilterByTags { + buffer.WriteRune('f') + } + if p.Tag { + buffer.WriteRune('t') + } + if p.SetImmutabilityPolicy { + buffer.WriteRune('i') + } + return buffer.String() +} + +// Parse initializes the AccountPermissions' fields from a string. 
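+//
+// For example (illustrative input):
+//
+//	perms, err := parseAccountPermissions("rwl") // sets Read, Write and List; order does not matter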
+func parseAccountPermissions(s string) (AccountPermissions, error) { + p := AccountPermissions{} // Clear out the flags + for _, r := range s { + switch r { + case 'r': + p.Read = true + case 'w': + p.Write = true + case 'd': + p.Delete = true + case 'x': + p.DeletePreviousVersion = true + case 'y': + p.PermanentDelete = true + case 'l': + p.List = true + case 'a': + p.Add = true + case 'c': + p.Create = true + case 'u': + p.Update = true + case 'p': + p.Process = true + case 't': + p.Tag = true + case 'f': + p.FilterByTags = true + case 'i': + p.SetImmutabilityPolicy = true + default: + return AccountPermissions{}, fmt.Errorf("invalid permission character: '%v'", r) + } + } + return p, nil +} + +// AccountResourceTypes type simplifies creating the resource types string for an Azure Storage Account SAS. +// Initialize an instance of this type and then call its String method to set AccountSignatureValues' ResourceTypes field. +type AccountResourceTypes struct { + Service, Container, Object bool +} + +// String produces the SAS resource types string for an Azure Storage account. +// Call this method to set AccountSignatureValues' ResourceTypes field. +func (rt *AccountResourceTypes) String() string { + var buffer bytes.Buffer + if rt.Service { + buffer.WriteRune('s') + } + if rt.Container { + buffer.WriteRune('c') + } + if rt.Object { + buffer.WriteRune('o') + } + return buffer.String() +} + +// parseAccountResourceTypes initializes the AccountResourceTypes' fields from a string. +func parseAccountResourceTypes(s string) (AccountResourceTypes, error) { + rt := AccountResourceTypes{} + for _, r := range s { + switch r { + case 's': + rt.Service = true + case 'c': + rt.Container = true + case 'o': + rt.Object = true + default: + return AccountResourceTypes{}, fmt.Errorf("invalid resource type character: '%v'", r) + } + } + return rt, nil +} diff --git a/sdk/storage/azdatalake/sas/account_test.go b/sdk/storage/azdatalake/sas/account_test.go new file mode 100644 index 000000000000..c995fe393b23 --- /dev/null +++ b/sdk/storage/azdatalake/sas/account_test.go @@ -0,0 +1,169 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. 
+ +package sas + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestAccountPermissions_String(t *testing.T) { + testdata := []struct { + input AccountPermissions + expected string + }{ + {input: AccountPermissions{Read: true}, expected: "r"}, + {input: AccountPermissions{Write: true}, expected: "w"}, + {input: AccountPermissions{Delete: true}, expected: "d"}, + {input: AccountPermissions{DeletePreviousVersion: true}, expected: "x"}, + {input: AccountPermissions{PermanentDelete: true}, expected: "y"}, + {input: AccountPermissions{List: true}, expected: "l"}, + {input: AccountPermissions{Add: true}, expected: "a"}, + {input: AccountPermissions{Create: true}, expected: "c"}, + {input: AccountPermissions{Update: true}, expected: "u"}, + {input: AccountPermissions{Process: true}, expected: "p"}, + {input: AccountPermissions{Tag: true}, expected: "t"}, + {input: AccountPermissions{FilterByTags: true}, expected: "f"}, + {input: AccountPermissions{SetImmutabilityPolicy: true}, expected: "i"}, + {input: AccountPermissions{ + Read: true, + Write: true, + Delete: true, + DeletePreviousVersion: true, + PermanentDelete: true, + List: true, + Add: true, + Create: true, + Update: true, + Process: true, + Tag: true, + FilterByTags: true, + SetImmutabilityPolicy: true, + }, expected: "rwdxylacupfti"}, + } + for _, c := range testdata { + require.Equal(t, c.expected, c.input.String()) + } +} + +func TestAccountPermissions_Parse(t *testing.T) { + testdata := []struct { + input string + expected AccountPermissions + }{ + {expected: AccountPermissions{Read: true}, input: "r"}, + {expected: AccountPermissions{Write: true}, input: "w"}, + {expected: AccountPermissions{Delete: true}, input: "d"}, + {expected: AccountPermissions{DeletePreviousVersion: true}, input: "x"}, + {expected: AccountPermissions{PermanentDelete: true}, input: "y"}, + {expected: AccountPermissions{List: true}, input: "l"}, + {expected: AccountPermissions{Add: true}, input: "a"}, + {expected: AccountPermissions{Create: true}, input: "c"}, + {expected: AccountPermissions{Update: true}, input: "u"}, + {expected: AccountPermissions{Process: true}, input: "p"}, + {expected: AccountPermissions{Tag: true}, input: "t"}, + {expected: AccountPermissions{FilterByTags: true}, input: "f"}, + {expected: AccountPermissions{SetImmutabilityPolicy: true}, input: "i"}, + {expected: AccountPermissions{ + Read: true, + Write: true, + Delete: true, + DeletePreviousVersion: true, + PermanentDelete: true, + List: true, + Add: true, + Create: true, + Update: true, + Process: true, + Tag: true, + FilterByTags: true, + SetImmutabilityPolicy: true, + }, input: "rwdxylacupfti"}, + {expected: AccountPermissions{ + Read: true, + Write: true, + Delete: true, + DeletePreviousVersion: true, + PermanentDelete: true, + List: true, + Add: true, + Create: true, + Update: true, + Process: true, + Tag: true, + FilterByTags: true, + SetImmutabilityPolicy: true, + }, input: "trwlapdixfycu"}, + } + for _, c := range testdata { + permissions, err := parseAccountPermissions(c.input) + require.Nil(t, err) + require.Equal(t, c.expected, permissions) + } +} + +func TestAccountPermissions_ParseNegative(t *testing.T) { + _, err := parseAccountPermissions("trwlapdixfycuz") // Here 'z' is invalid + require.NotNil(t, err) + require.Contains(t, err.Error(), "122") +} + +func TestAccountResourceTypes_String(t *testing.T) { + testdata := []struct { + input AccountResourceTypes + expected string + }{ + {input: AccountResourceTypes{Service: true}, expected: "s"}, + 
{input: AccountResourceTypes{Container: true}, expected: "c"}, + {input: AccountResourceTypes{Object: true}, expected: "o"}, + {input: AccountResourceTypes{ + Service: true, + Container: true, + Object: true, + }, expected: "sco"}, + } + for _, c := range testdata { + require.Equal(t, c.expected, c.input.String()) + } +} + +func TestAccountResourceTypes_Parse(t *testing.T) { + testdata := []struct { + input string + expected AccountResourceTypes + }{ + {expected: AccountResourceTypes{Service: true}, input: "s"}, + {expected: AccountResourceTypes{Container: true}, input: "c"}, + {expected: AccountResourceTypes{Object: true}, input: "o"}, + {expected: AccountResourceTypes{ + Service: true, + Container: true, + Object: true, + }, input: "sco"}, + {expected: AccountResourceTypes{ + Service: true, + Container: true, + Object: true, + }, input: "osc"}, + } + for _, c := range testdata { + permissions, err := parseAccountResourceTypes(c.input) + require.Nil(t, err) + require.Equal(t, c.expected, permissions) + } +} + +func TestAccountResourceTypes_ParseNegative(t *testing.T) { + _, err := parseAccountResourceTypes("scoz") // Here 'z' is invalid + require.NotNil(t, err) + require.Contains(t, err.Error(), "122") +} + +// TODO: Sign With Shared Key +// Negative Case +// Version not provided +// SignWithSharedKey tests diff --git a/sdk/storage/azdatalake/sas/query_params.go b/sdk/storage/azdatalake/sas/query_params.go new file mode 100644 index 000000000000..ed3cd252d93f --- /dev/null +++ b/sdk/storage/azdatalake/sas/query_params.go @@ -0,0 +1,506 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package sas + +import ( + "errors" + "net" + "net/url" + "strings" + "time" +) + +// timeFormat represents the format of a SAS start or expiry time. Use it when formatting/parsing a time.Time. +const ( + timeFormat = "2006-01-02T15:04:05Z" // "2017-07-27T00:00:00Z" // ISO 8601 +) + +var ( + // Version is the default version encoded in the SAS token. + Version = "2020-02-10" +) + +// TimeFormats ISO 8601 format. +// Please refer to https://docs.microsoft.com/en-us/rest/api/storageservices/constructing-a-service-sas for more details. +var timeFormats = []string{"2006-01-02T15:04:05.0000000Z", timeFormat, "2006-01-02T15:04Z", "2006-01-02"} + +// Protocol indicates the http/https. +type Protocol string + +const ( + // ProtocolHTTPS can be specified for a SAS protocol. + ProtocolHTTPS Protocol = "https" + + // ProtocolHTTPSandHTTP can be specified for a SAS protocol. + ProtocolHTTPSandHTTP Protocol = "https,http" +) + +// FormatTimesForSigning converts a time.Time to a snapshotTimeFormat string suitable for a +// Field's StartTime or ExpiryTime fields. Returns "" if value.IsZero(). +func formatTimesForSigning(startTime, expiryTime time.Time) (string, string) { + ss := "" + if !startTime.IsZero() { + ss = formatTimeWithDefaultFormat(&startTime) + } + se := "" + if !expiryTime.IsZero() { + se = formatTimeWithDefaultFormat(&expiryTime) + } + return ss, se +} + +// formatTimeWithDefaultFormat format time with ISO 8601 in "yyyy-MM-ddTHH:mm:ssZ". +func formatTimeWithDefaultFormat(t *time.Time) string { + return formatTime(t, timeFormat) // By default, "yyyy-MM-ddTHH:mm:ssZ" is used +} + +// formatTime format time with given format, use ISO 8601 in "yyyy-MM-ddTHH:mm:ssZ" by default. 
+func formatTime(t *time.Time, format string) string { + if format != "" { + return t.Format(format) + } + return t.Format(timeFormat) // By default, "yyyy-MM-ddTHH:mm:ssZ" is used +} + +// ParseTime try to parse a SAS time string. +func parseTime(val string) (t time.Time, timeFormat string, err error) { + for _, sasTimeFormat := range timeFormats { + t, err = time.Parse(sasTimeFormat, val) + if err == nil { + timeFormat = sasTimeFormat + break + } + } + + if err != nil { + err = errors.New("fail to parse time with IOS 8601 formats, please refer to https://docs.microsoft.com/en-us/rest/api/storageservices/constructing-a-service-sas for more details") + } + + return +} + +// IPRange represents a SAS IP range's start IP and (optionally) end IP. +type IPRange struct { + Start net.IP // Not specified if length = 0 + End net.IP // Not specified if length = 0 +} + +// String returns a string representation of an IPRange. +func (ipr *IPRange) String() string { + if len(ipr.Start) == 0 { + return "" + } + start := ipr.Start.String() + if len(ipr.End) == 0 { + return start + } + return start + "-" + ipr.End.String() +} + +// https://docs.microsoft.com/en-us/rest/api/storageservices/constructing-a-service-sas + +// QueryParameters object represents the components that make up an Azure Storage SAS' query parameters. +// You parse a map of query parameters into its fields by calling NewQueryParameters(). You add the components +// to a query parameter map by calling AddToValues(). +// NOTE: Changing any field requires computing a new SAS signature using a XxxSASSignatureValues type. +// This type defines the components used by all Azure Storage resources (Containers, Blobs, Files, & Queues). +type QueryParameters struct { + // All members are immutable or values so copies of this struct are goroutine-safe. + version string `param:"sv"` + services string `param:"ss"` + resourceTypes string `param:"srt"` + protocol Protocol `param:"spr"` + startTime time.Time `param:"st"` + expiryTime time.Time `param:"se"` + snapshotTime time.Time `param:"snapshot"` + ipRange IPRange `param:"sip"` + identifier string `param:"si"` + resource string `param:"sr"` + permissions string `param:"sp"` + signature string `param:"sig"` + cacheControl string `param:"rscc"` + contentDisposition string `param:"rscd"` + contentEncoding string `param:"rsce"` + contentLanguage string `param:"rscl"` + contentType string `param:"rsct"` + signedOID string `param:"skoid"` + signedTID string `param:"sktid"` + signedStart time.Time `param:"skt"` + signedService string `param:"sks"` + signedExpiry time.Time `param:"ske"` + signedVersion string `param:"skv"` + signedDirectoryDepth string `param:"sdd"` + authorizedObjectID string `param:"saoid"` + unauthorizedObjectID string `param:"suoid"` + correlationID string `param:"scid"` + // private member used for startTime and expiryTime formatting. + stTimeFormat string + seTimeFormat string +} + +// AuthorizedObjectID returns authorizedObjectID. +func (p *QueryParameters) AuthorizedObjectID() string { + return p.authorizedObjectID +} + +// UnauthorizedObjectID returns unauthorizedObjectID. +func (p *QueryParameters) UnauthorizedObjectID() string { + return p.unauthorizedObjectID +} + +// SignedCorrelationID returns signedCorrelationID. +func (p *QueryParameters) SignedCorrelationID() string { + return p.correlationID +} + +// SignedOID returns signedOID. +func (p *QueryParameters) SignedOID() string { + return p.signedOID +} + +// SignedTID returns signedTID. 
+func (p *QueryParameters) SignedTID() string { + return p.signedTID +} + +// SignedStart returns signedStart. +func (p *QueryParameters) SignedStart() time.Time { + return p.signedStart +} + +// SignedExpiry returns signedExpiry. +func (p *QueryParameters) SignedExpiry() time.Time { + return p.signedExpiry +} + +// SignedService returns signedService. +func (p *QueryParameters) SignedService() string { + return p.signedService +} + +// SignedVersion returns signedVersion. +func (p *QueryParameters) SignedVersion() string { + return p.signedVersion +} + +// SnapshotTime returns snapshotTime. +func (p *QueryParameters) SnapshotTime() time.Time { + return p.snapshotTime +} + +// Version returns version. +func (p *QueryParameters) Version() string { + return p.version +} + +// Services returns services. +func (p *QueryParameters) Services() string { + return p.services +} + +// ResourceTypes returns resourceTypes. +func (p *QueryParameters) ResourceTypes() string { + return p.resourceTypes +} + +// Protocol returns protocol. +func (p *QueryParameters) Protocol() Protocol { + return p.protocol +} + +// StartTime returns startTime. +func (p *QueryParameters) StartTime() time.Time { + return p.startTime +} + +// ExpiryTime returns expiryTime. +func (p *QueryParameters) ExpiryTime() time.Time { + return p.expiryTime +} + +// IPRange returns ipRange. +func (p *QueryParameters) IPRange() IPRange { + return p.ipRange +} + +// Identifier returns identifier. +func (p *QueryParameters) Identifier() string { + return p.identifier +} + +// Resource returns resource. +func (p *QueryParameters) Resource() string { + return p.resource +} + +// Permissions returns permissions. +func (p *QueryParameters) Permissions() string { + return p.permissions +} + +// Signature returns signature. +func (p *QueryParameters) Signature() string { + return p.signature +} + +// CacheControl returns cacheControl. +func (p *QueryParameters) CacheControl() string { + return p.cacheControl +} + +// ContentDisposition returns contentDisposition. +func (p *QueryParameters) ContentDisposition() string { + return p.contentDisposition +} + +// ContentEncoding returns contentEncoding. +func (p *QueryParameters) ContentEncoding() string { + return p.contentEncoding +} + +// ContentLanguage returns contentLanguage. +func (p *QueryParameters) ContentLanguage() string { + return p.contentLanguage +} + +// ContentType returns contentType. +func (p *QueryParameters) ContentType() string { + return p.contentType +} + +// SignedDirectoryDepth returns signedDirectoryDepth. +func (p *QueryParameters) SignedDirectoryDepth() string { + return p.signedDirectoryDepth +} + +// Encode encodes the SAS query parameters into URL encoded form sorted by key. 
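+//
+// A common pattern is to append the encoded string to a resource URL (sketch; queryParams is
+// assumed to be a QueryParameters value returned by one of the Sign methods, and the host and
+// path placeholders are illustrative):
+//
+//	sasURL := "https://<account>.dfs.core.windows.net/<filesystem>/<path>?" + queryParams.Encode()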
+func (p *QueryParameters) Encode() string { + v := url.Values{} + + if p.version != "" { + v.Add("sv", p.version) + } + if p.services != "" { + v.Add("ss", p.services) + } + if p.resourceTypes != "" { + v.Add("srt", p.resourceTypes) + } + if p.protocol != "" { + v.Add("spr", string(p.protocol)) + } + if !p.startTime.IsZero() { + v.Add("st", formatTime(&(p.startTime), p.stTimeFormat)) + } + if !p.expiryTime.IsZero() { + v.Add("se", formatTime(&(p.expiryTime), p.seTimeFormat)) + } + if len(p.ipRange.Start) > 0 { + v.Add("sip", p.ipRange.String()) + } + if p.identifier != "" { + v.Add("si", p.identifier) + } + if p.resource != "" { + v.Add("sr", p.resource) + } + if p.permissions != "" { + v.Add("sp", p.permissions) + } + if p.signedOID != "" { + v.Add("skoid", p.signedOID) + v.Add("sktid", p.signedTID) + v.Add("skt", p.signedStart.Format(timeFormat)) + v.Add("ske", p.signedExpiry.Format(timeFormat)) + v.Add("sks", p.signedService) + v.Add("skv", p.signedVersion) + } + if p.signature != "" { + v.Add("sig", p.signature) + } + if p.cacheControl != "" { + v.Add("rscc", p.cacheControl) + } + if p.contentDisposition != "" { + v.Add("rscd", p.contentDisposition) + } + if p.contentEncoding != "" { + v.Add("rsce", p.contentEncoding) + } + if p.contentLanguage != "" { + v.Add("rscl", p.contentLanguage) + } + if p.contentType != "" { + v.Add("rsct", p.contentType) + } + if p.signedDirectoryDepth != "" { + v.Add("sdd", p.signedDirectoryDepth) + } + if p.authorizedObjectID != "" { + v.Add("saoid", p.authorizedObjectID) + } + if p.unauthorizedObjectID != "" { + v.Add("suoid", p.unauthorizedObjectID) + } + if p.correlationID != "" { + v.Add("scid", p.correlationID) + } + + return v.Encode() +} + +// NewQueryParameters creates and initializes a QueryParameters object based on the +// query parameter map's passed-in values. If a key is unrecognized, it is ignored +func NewQueryParameters(values url.Values) QueryParameters { + p := QueryParameters{} + for k, v := range values { + val := v[0] + switch strings.ToLower(k) { + case "sv": + p.version = val + case "ss": + p.services = val + case "srt": + p.resourceTypes = val + case "spr": + p.protocol = Protocol(val) + case "st": + p.startTime, p.stTimeFormat, _ = parseTime(val) + case "se": + p.expiryTime, p.seTimeFormat, _ = parseTime(val) + case "sip": + dashIndex := strings.Index(val, "-") + if dashIndex == -1 { + p.ipRange.Start = net.ParseIP(val) + } else { + p.ipRange.Start = net.ParseIP(val[:dashIndex]) + p.ipRange.End = net.ParseIP(val[dashIndex+1:]) + } + case "si": + p.identifier = val + case "sr": + p.resource = val + case "sp": + p.permissions = val + case "sig": + p.signature = val + case "rscc": + p.cacheControl = val + case "rscd": + p.contentDisposition = val + case "rsce": + p.contentEncoding = val + case "rscl": + p.contentLanguage = val + case "rsct": + p.contentType = val + case "skoid": + p.signedOID = val + case "sktid": + p.signedTID = val + case "skt": + p.signedStart, _ = time.Parse(timeFormat, val) + case "ske": + p.signedExpiry, _ = time.Parse(timeFormat, val) + case "sks": + p.signedService = val + case "skv": + p.signedVersion = val + case "sdd": + p.signedDirectoryDepth = val + case "saoid": + p.authorizedObjectID = val + case "suoid": + p.unauthorizedObjectID = val + case "scid": + p.correlationID = val + default: + continue // query param didn't get recognized + } + } + return p +} + +// newQueryParameters creates and initializes a QueryParameters object based on the +// query parameter map's passed-in values. 
If deleteSASParametersFromValues is true, +// all SAS-related query parameters are removed from the passed-in map. If +// deleteSASParametersFromValues is false, the map passed-in map is unaltered. +func newQueryParameters(values url.Values, deleteSASParametersFromValues bool) QueryParameters { + p := QueryParameters{} + for k, v := range values { + val := v[0] + isSASKey := true + switch strings.ToLower(k) { + case "sv": + p.version = val + case "ss": + p.services = val + case "srt": + p.resourceTypes = val + case "spr": + p.protocol = Protocol(val) + case "st": + p.startTime, p.stTimeFormat, _ = parseTime(val) + case "se": + p.expiryTime, p.seTimeFormat, _ = parseTime(val) + case "sip": + dashIndex := strings.Index(val, "-") + if dashIndex == -1 { + p.ipRange.Start = net.ParseIP(val) + } else { + p.ipRange.Start = net.ParseIP(val[:dashIndex]) + p.ipRange.End = net.ParseIP(val[dashIndex+1:]) + } + case "si": + p.identifier = val + case "sr": + p.resource = val + case "sp": + p.permissions = val + //case "snapshot": + // p.snapshotTime, _ = time.Parse(exported.SnapshotTimeFormat, val) + case "sig": + p.signature = val + case "rscc": + p.cacheControl = val + case "rscd": + p.contentDisposition = val + case "rsce": + p.contentEncoding = val + case "rscl": + p.contentLanguage = val + case "rsct": + p.contentType = val + case "skoid": + p.signedOID = val + case "sktid": + p.signedTID = val + case "skt": + p.signedStart, _ = time.Parse(timeFormat, val) + case "ske": + p.signedExpiry, _ = time.Parse(timeFormat, val) + case "sks": + p.signedService = val + case "skv": + p.signedVersion = val + case "sdd": + p.signedDirectoryDepth = val + case "saoid": + p.authorizedObjectID = val + case "suoid": + p.unauthorizedObjectID = val + case "scid": + p.correlationID = val + default: + isSASKey = false // We didn't recognize the query parameter + } + if isSASKey && deleteSASParametersFromValues { + delete(values, k) + } + } + return p +} diff --git a/sdk/storage/azdatalake/sas/query_params_test.go b/sdk/storage/azdatalake/sas/query_params_test.go new file mode 100644 index 000000000000..1e50647a2ca3 --- /dev/null +++ b/sdk/storage/azdatalake/sas/query_params_test.go @@ -0,0 +1,231 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. 
+ +package sas + +import ( + "fmt" + "net" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFormatTimesForSigning(t *testing.T) { + testdata := []struct { + inputStart time.Time + inputEnd time.Time + inputSnapshot time.Time + expectedStart string + expectedEnd string + expectedSnapshot string + }{ + {expectedStart: "", expectedEnd: "", expectedSnapshot: ""}, + {inputStart: time.Date(1955, 6, 25, 22, 15, 56, 345456, time.UTC), expectedStart: "1955-06-25T22:15:56Z", expectedEnd: "", expectedSnapshot: ""}, + {inputEnd: time.Date(2023, 4, 5, 8, 50, 27, 4500, time.UTC), expectedStart: "", expectedEnd: "2023-04-05T08:50:27Z", expectedSnapshot: ""}, + {inputSnapshot: time.Date(2021, 1, 5, 22, 15, 33, 1234879, time.UTC), expectedStart: "", expectedEnd: "", expectedSnapshot: "2021-01-05T22:15:33.0012348Z"}, + { + inputStart: time.Date(1955, 6, 25, 22, 15, 56, 345456, time.UTC), + inputEnd: time.Date(2023, 4, 5, 8, 50, 27, 4500, time.UTC), + inputSnapshot: time.Date(2021, 1, 5, 22, 15, 33, 1234879, time.UTC), + expectedStart: "1955-06-25T22:15:56Z", + expectedEnd: "2023-04-05T08:50:27Z", + expectedSnapshot: "2021-01-05T22:15:33.0012348Z", + }, + } + for _, c := range testdata { + start, end := formatTimesForSigning(c.inputStart, c.inputEnd) + require.Equal(t, c.expectedStart, start) + require.Equal(t, c.expectedEnd, end) + } +} + +func TestFormatTimeWithDefaultFormat(t *testing.T) { + testdata := []struct { + input time.Time + expectedTime string + }{ + {input: time.Date(1955, 4, 5, 8, 50, 27, 4500, time.UTC), expectedTime: "1955-04-05T08:50:27Z"}, + {input: time.Date(1917, 3, 9, 16, 22, 56, 0, time.UTC), expectedTime: "1917-03-09T16:22:56Z"}, + {input: time.Date(2021, 1, 5, 22, 15, 0, 0, time.UTC), expectedTime: "2021-01-05T22:15:00Z"}, + {input: time.Date(2023, 6, 25, 0, 0, 0, 0, time.UTC), expectedTime: "2023-06-25T00:00:00Z"}, + } + for _, c := range testdata { + formattedTime := formatTimeWithDefaultFormat(&c.input) + require.Equal(t, c.expectedTime, formattedTime) + } +} + +func TestFormatTime(t *testing.T) { + testdata := []struct { + input time.Time + format string + expectedTime string + }{ + {input: time.Date(1955, 4, 5, 8, 50, 27, 4500, time.UTC), format: "2006-01-02T15:04:05.0000000Z", expectedTime: "1955-04-05T08:50:27.0000045Z"}, + {input: time.Date(1955, 4, 5, 8, 50, 27, 4500, time.UTC), format: "", expectedTime: "1955-04-05T08:50:27Z"}, + {input: time.Date(1917, 3, 9, 16, 22, 56, 0, time.UTC), format: "2006-01-02T15:04:05Z", expectedTime: "1917-03-09T16:22:56Z"}, + {input: time.Date(1917, 3, 9, 16, 22, 56, 0, time.UTC), format: "", expectedTime: "1917-03-09T16:22:56Z"}, + {input: time.Date(2021, 1, 5, 22, 15, 0, 0, time.UTC), format: "2006-01-02T15:04Z", expectedTime: "2021-01-05T22:15Z"}, + {input: time.Date(2021, 1, 5, 22, 15, 0, 0, time.UTC), format: "", expectedTime: "2021-01-05T22:15:00Z"}, + {input: time.Date(2023, 6, 25, 0, 0, 0, 0, time.UTC), format: "2006-01-02", expectedTime: "2023-06-25"}, + {input: time.Date(2023, 6, 25, 0, 0, 0, 0, time.UTC), format: "", expectedTime: "2023-06-25T00:00:00Z"}, + } + for _, c := range testdata { + formattedTime := formatTime(&c.input, c.format) + require.Equal(t, c.expectedTime, formattedTime) + } +} + +func TestParseTime(t *testing.T) { + testdata := []struct { + input string + expectedTime time.Time + expectedFormat string + }{ + {input: "1955-04-05T08:50:27.0000045Z", expectedTime: time.Date(1955, 4, 5, 8, 50, 27, 4500, time.UTC), expectedFormat: "2006-01-02T15:04:05.0000000Z"}, 
+ {input: "1917-03-09T16:22:56Z", expectedTime: time.Date(1917, 3, 9, 16, 22, 56, 0, time.UTC), expectedFormat: "2006-01-02T15:04:05Z"}, + {input: "2021-01-05T22:15Z", expectedTime: time.Date(2021, 1, 5, 22, 15, 0, 0, time.UTC), expectedFormat: "2006-01-02T15:04Z"}, + {input: "2023-06-25", expectedTime: time.Date(2023, 6, 25, 0, 0, 0, 0, time.UTC), expectedFormat: "2006-01-02"}, + } + for _, c := range testdata { + parsedTime, format, err := parseTime(c.input) + require.Nil(t, err) + require.Equal(t, c.expectedTime, parsedTime) + require.Equal(t, c.expectedFormat, format) + } +} + +func TestParseTimeNegative(t *testing.T) { + _, _, err := parseTime("notatime") + require.Error(t, err, "fail to parse time with IOS 8601 formats, please refer to https://docs.microsoft.com/en-us/rest/api/storageservices/constructing-a-service-sas for more details") +} + +func TestIPRange_String(t *testing.T) { + testdata := []struct { + inputStart net.IP + inputEnd net.IP + expected string + }{ + {expected: ""}, + {inputStart: net.IPv4(10, 255, 0, 0), expected: "10.255.0.0"}, + {inputStart: net.IPv4(10, 255, 0, 0), inputEnd: net.IPv4(10, 255, 0, 50), expected: "10.255.0.0-10.255.0.50"}, + } + for _, c := range testdata { + var ipRange IPRange + if c.inputStart != nil { + ipRange.Start = c.inputStart + } + if c.inputEnd != nil { + ipRange.End = c.inputEnd + } + require.Equal(t, c.expected, ipRange.String()) + } +} + +func TestSAS(t *testing.T) { + // Note: This is a totally invalid fake SAS, this is just testing our ability to parse different query parameters on a SAS + const sas = "sv=2019-12-12&sr=b&st=2111-01-09T01:42:34.936Z&se=2222-03-09T01:42:34.936Z&sp=rw&sip=168.1.5.60-168.1.5.70&spr=https,http&si=myIdentifier&ss=bf&srt=s&rscc=cc&rscd=cd&rsce=ce&rscl=cl&rsct=ct&skoid=oid&sktid=tid&skt=2111-01-09T01:42:34.936Z&ske=2222-03-09T01:42:34.936Z&sks=s&skv=v&sdd=3&saoid=oid&suoid=oid&scid=cid&sig=clNxbtnkKSHw7f3KMEVVc4agaszoRFdbZr%2FWBmPNsrw%3D" + _url := fmt.Sprintf("https://teststorageaccount.blob.core.windows.net/testcontainer/testpath?%s", sas) + _uri, err := url.Parse(_url) + require.NoError(t, err) + sasQueryParams := newQueryParameters(_uri.Query(), true) + validateSAS(t, sas, sasQueryParams) +} + +func validateSAS(t *testing.T, sas string, parameters QueryParameters) { + sasCompMap := make(map[string]string) + for _, sasComp := range strings.Split(sas, "&") { + comp := strings.Split(sasComp, "=") + sasCompMap[comp[0]] = comp[1] + } + + require.Equal(t, parameters.Version(), sasCompMap["sv"]) + require.Equal(t, parameters.Services(), sasCompMap["ss"]) + require.Equal(t, parameters.ResourceTypes(), sasCompMap["srt"]) + require.Equal(t, string(parameters.Protocol()), sasCompMap["spr"]) + if _, ok := sasCompMap["st"]; ok { + startTime, _, err := parseTime(sasCompMap["st"]) + require.NoError(t, err) + require.Equal(t, parameters.StartTime(), startTime) + } + if _, ok := sasCompMap["se"]; ok { + endTime, _, err := parseTime(sasCompMap["se"]) + require.NoError(t, err) + require.Equal(t, parameters.ExpiryTime(), endTime) + } + + if _, ok := sasCompMap["snapshot"]; ok { + snapshotTime, _, err := parseTime(sasCompMap["snapshot"]) + require.NoError(t, err) + require.Equal(t, parameters.SnapshotTime(), snapshotTime) + } + ipRange := parameters.IPRange() + require.Equal(t, ipRange.String(), sasCompMap["sip"]) + require.Equal(t, parameters.Identifier(), sasCompMap["si"]) + require.Equal(t, parameters.Resource(), sasCompMap["sr"]) + require.Equal(t, parameters.Permissions(), sasCompMap["sp"]) + + sign, err := 
url.QueryUnescape(sasCompMap["sig"]) + require.NoError(t, err) + + require.Equal(t, parameters.Signature(), sign) + require.Equal(t, parameters.CacheControl(), sasCompMap["rscc"]) + require.Equal(t, parameters.ContentDisposition(), sasCompMap["rscd"]) + require.Equal(t, parameters.ContentEncoding(), sasCompMap["rsce"]) + require.Equal(t, parameters.ContentLanguage(), sasCompMap["rscl"]) + require.Equal(t, parameters.ContentType(), sasCompMap["rsct"]) + require.Equal(t, parameters.SignedOID(), sasCompMap["skoid"]) + require.Equal(t, parameters.SignedTID(), sasCompMap["sktid"]) + + if _, ok := sasCompMap["skt"]; ok { + signedStart, _, err := parseTime(sasCompMap["skt"]) + require.NoError(t, err) + require.Equal(t, parameters.SignedStart(), signedStart) + } + require.Equal(t, parameters.SignedService(), sasCompMap["sks"]) + + if _, ok := sasCompMap["ske"]; ok { + signedExpiry, _, err := parseTime(sasCompMap["ske"]) + require.NoError(t, err) + require.Equal(t, parameters.SignedExpiry(), signedExpiry) + } + + require.Equal(t, parameters.SignedVersion(), sasCompMap["skv"]) + require.Equal(t, parameters.SignedDirectoryDepth(), sasCompMap["sdd"]) + require.Equal(t, parameters.AuthorizedObjectID(), sasCompMap["saoid"]) + require.Equal(t, parameters.UnauthorizedObjectID(), sasCompMap["suoid"]) + require.Equal(t, parameters.SignedCorrelationID(), sasCompMap["scid"]) +} + +func TestSASInvalidQueryParameter(t *testing.T) { + // Signature is invalid below + const sas = "sv=2019-12-12&signature=clNxbtnkKSHw7f3KMEVVc4agaszoRFdbZr%2FWBmPNsrw%3D&sr=b" + _url := fmt.Sprintf("https://teststorageaccount.blob.core.windows.net/testcontainer/testpath?%s", sas) + _uri, err := url.Parse(_url) + require.NoError(t, err) + newQueryParameters(_uri.Query(), true) + // NewQueryParameters should not delete signature + require.Contains(t, _uri.Query(), "signature") +} + +func TestEncode(t *testing.T) { + // Note: This is a totally invalid fake SAS, this is just testing our ability to parse different query parameters on a SAS + expected := "rscc=cc&rscd=cd&rsce=ce&rscl=cl&rsct=ct&saoid=oid&scid=cid&sdd=3&se=2222-03-09T01%3A42%3A34Z&si=myIdentifier&sig=clNxbtnkKSHw7f3KMEVVc4agaszoRFdbZr%2FWBmPNsrw%3D&sip=168.1.5.60-168.1.5.70&ske=2222-03-09T01%3A42%3A34Z&skoid=oid&sks=s&skt=2111-01-09T01%3A42%3A34Z&sktid=tid&skv=v&sp=rw&spr=https%2Chttp&sr=b&srt=sco&ss=bf&st=2111-01-09T01%3A42%3A34Z&suoid=oid&sv=2019-12-12" + randomOrder := "sdd=3&scid=cid&se=2222-03-09T01:42:34.936Z&rsce=ce&ss=bf&skoid=oid&si=myIdentifier&ske=2222-03-09T01:42:34.936Z&saoid=oid&sip=168.1.5.60-168.1.5.70&rscc=cc&srt=sco&sig=clNxbtnkKSHw7f3KMEVVc4agaszoRFdbZr%2FWBmPNsrw%3D&rsct=ct&skt=2111-01-09T01:42:34.936Z&rscl=cl&suoid=oid&sv=2019-12-12&sr=b&st=2111-01-09T01:42:34.936Z&rscd=cd&sp=rw&sktid=tid&spr=https,http&sks=s&skv=v" + testdata := []string{expected, randomOrder} + + for _, sas := range testdata { + _url := fmt.Sprintf("https://teststorageaccount.blob.core.windows.net/testcontainer/testpath?%s", sas) + _uri, err := url.Parse(_url) + require.NoError(t, err) + queryParams := newQueryParameters(_uri.Query(), true) + require.Equal(t, expected, queryParams.Encode()) + } +} diff --git a/sdk/storage/azdatalake/sas/service.go b/sdk/storage/azdatalake/sas/service.go new file mode 100644 index 000000000000..23518b25676d --- /dev/null +++ b/sdk/storage/azdatalake/sas/service.go @@ -0,0 +1,413 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package sas + +import ( + "bytes" + "errors" + "fmt" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/exported" +) + +// DatalakeSignatureValues is used to generate a Shared Access Signature (SAS) for an Azure Storage filesystem or path. +// For more information on creating service sas, see https://docs.microsoft.com/rest/api/storageservices/constructing-a-service-sas +// For more information on creating user delegation sas, see https://docs.microsoft.com/rest/api/storageservices/create-user-delegation-sas +type DatalakeSignatureValues struct { + Version string `param:"sv"` // If not specified, this defaults to Version + Protocol Protocol `param:"spr"` // See the Protocol* constants + StartTime time.Time `param:"st"` // Not specified if IsZero + ExpiryTime time.Time `param:"se"` // Not specified if IsZero + Permissions string `param:"sp"` // Create by initializing FilesystemPermissions, FilePermissions or DirectoryPermissions and then call String() + IPRange IPRange `param:"sip"` + Identifier string `param:"si"` + FilesystemName string + // Use "" to create a Filesystem SAS + // DirectoryPath will set this to "" if it is passed + FilePath string + // Not nil for a directory SAS (ie sr=d) + // Use "" to create a Filesystem SAS + DirectoryPath string + CacheControl string // rscc + ContentDisposition string // rscd + ContentEncoding string // rsce + ContentLanguage string // rscl + ContentType string // rsct + AuthorizedObjectID string // saoid + UnauthorizedObjectID string // suoid + CorrelationID string // scid +} + +//TODO: add snapshot and versioning support in the future + +func getDirectoryDepth(path string) string { + if path == "" { + return "" + } + return fmt.Sprint(strings.Count(path, "/") + 1) +} + +// SignWithSharedKey uses an account's SharedKeyCredential to sign this signature values to produce the proper SAS query parameters. 
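+//
+// A rough sketch for a file (path) SAS (cred is assumed to be a *SharedKeyCredential for the
+// target account, obtained elsewhere; the filesystem and path names are illustrative):
+//
+//	perms := PathPermissions{Read: true, Write: true}
+//	values := DatalakeSignatureValues{
+//		Protocol:       ProtocolHTTPS,
+//		ExpiryTime:     time.Now().UTC().Add(2 * time.Hour),
+//		Permissions:    perms.String(),
+//		FilesystemName: "myfilesystem",
+//		FilePath:       "dir/myfile.txt",
+//	}
+//	queryParams, err := values.SignWithSharedKey(cred)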
+func (v DatalakeSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCredential) (QueryParameters, error) { + if v.ExpiryTime.IsZero() || v.Permissions == "" { + return QueryParameters{}, errors.New("service SAS is missing at least one of these: ExpiryTime or Permissions") + } + + //Make sure the permission characters are in the correct order + perms, err := parsePathPermissions(v.Permissions) + if err != nil { + return QueryParameters{}, err + } + v.Permissions = perms.String() + + resource := "c" + if v.DirectoryPath != "" { + resource = "d" + v.FilePath = "" + } else if v.FilePath == "" { + // do nothing + } else { + resource = "b" + } + + if v.Version == "" { + v.Version = Version + } + startTime, expiryTime := formatTimesForSigning(v.StartTime, v.ExpiryTime) + + signedIdentifier := v.Identifier + + // String to sign: http://msdn.microsoft.com/en-us/library/azure/dn140255.aspx + stringToSign := strings.Join([]string{ + v.Permissions, + startTime, + expiryTime, + getCanonicalName(sharedKeyCredential.AccountName(), v.FilesystemName, v.FilePath, v.DirectoryPath), + signedIdentifier, + v.IPRange.String(), + string(v.Protocol), + v.Version, + resource, + "", //snapshot not supported + v.CacheControl, // rscc + v.ContentDisposition, // rscd + v.ContentEncoding, // rsce + v.ContentLanguage, // rscl + v.ContentType}, // rsct + "\n") + + signature, err := exported.ComputeHMACSHA256(sharedKeyCredential, stringToSign) + if err != nil { + return QueryParameters{}, err + } + + p := QueryParameters{ + // Common SAS parameters + version: v.Version, + protocol: v.Protocol, + startTime: v.StartTime, + expiryTime: v.ExpiryTime, + permissions: v.Permissions, + ipRange: v.IPRange, + + // Container/Blob-specific SAS parameters + resource: resource, + identifier: v.Identifier, + cacheControl: v.CacheControl, + contentDisposition: v.ContentDisposition, + contentEncoding: v.ContentEncoding, + contentLanguage: v.ContentLanguage, + contentType: v.ContentType, + signedDirectoryDepth: getDirectoryDepth(v.DirectoryPath), + authorizedObjectID: v.AuthorizedObjectID, + unauthorizedObjectID: v.UnauthorizedObjectID, + correlationID: v.CorrelationID, + // Calculated SAS signature + signature: signature, + } + + return p, nil +} + +// SignWithUserDelegation uses an account's UserDelegationCredential to sign this signature values to produce the proper SAS query parameters. 
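+//
+// A rough sketch for a filesystem-scoped user delegation SAS (udc is assumed to be a
+// *UserDelegationCredential built from a user delegation key obtained from the service;
+// the filesystem name is illustrative):
+//
+//	perms := FilesystemPermissions{Read: true, List: true}
+//	values := DatalakeSignatureValues{
+//		Protocol:       ProtocolHTTPS,
+//		ExpiryTime:     time.Now().UTC().Add(2 * time.Hour),
+//		Permissions:    perms.String(),
+//		FilesystemName: "myfilesystem",
+//	}
+//	queryParams, err := values.SignWithUserDelegation(udc)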
+func (v DatalakeSignatureValues) SignWithUserDelegation(userDelegationCredential *UserDelegationCredential) (QueryParameters, error) { + if userDelegationCredential == nil { + return QueryParameters{}, fmt.Errorf("cannot sign SAS query without User Delegation Key") + } + + if v.ExpiryTime.IsZero() || v.Permissions == "" { + return QueryParameters{}, errors.New("user delegation SAS is missing at least one of these: ExpiryTime or Permissions") + } + + // Parse the resource + resource := "c" + if v.DirectoryPath != "" { + resource = "d" + v.FilePath = "" + } else if v.FilePath == "" { + // do nothing + } else { + resource = "b" + } + // make sure the permission characters are in the correct order + if resource == "c" { + perms, err := parseFilesystemPermissions(v.Permissions) + if err != nil { + return QueryParameters{}, err + } + v.Permissions = perms.String() + } else { + perms, err := parsePathPermissions(v.Permissions) + if err != nil { + return QueryParameters{}, err + } + v.Permissions = perms.String() + } + + if v.Version == "" { + v.Version = Version + } + startTime, expiryTime := formatTimesForSigning(v.StartTime, v.ExpiryTime) + + udk := exported.GetUDKParams(userDelegationCredential) + + udkStart, udkExpiry := formatTimesForSigning(*udk.SignedStart, *udk.SignedExpiry) + + stringToSign := strings.Join([]string{ + v.Permissions, + startTime, + expiryTime, + getCanonicalName(exported.GetAccountName(userDelegationCredential), v.FilesystemName, v.FilePath, v.DirectoryPath), + *udk.SignedOID, + *udk.SignedTID, + udkStart, + udkExpiry, + *udk.SignedService, + *udk.SignedVersion, + v.AuthorizedObjectID, + v.UnauthorizedObjectID, + v.CorrelationID, + v.IPRange.String(), + string(v.Protocol), + v.Version, + resource, + "", //snapshot not supported + v.CacheControl, // rscc + v.ContentDisposition, // rscd + v.ContentEncoding, // rsce + v.ContentLanguage, // rscl + v.ContentType}, // rsct + "\n") + + signature, err := exported.ComputeUDCHMACSHA256(userDelegationCredential, stringToSign) + if err != nil { + return QueryParameters{}, err + } + + p := QueryParameters{ + // Common SAS parameters + version: v.Version, + protocol: v.Protocol, + startTime: v.StartTime, + expiryTime: v.ExpiryTime, + permissions: v.Permissions, + ipRange: v.IPRange, + + // Container/Blob-specific SAS parameters + resource: resource, + identifier: v.Identifier, + cacheControl: v.CacheControl, + contentDisposition: v.ContentDisposition, + contentEncoding: v.ContentEncoding, + contentLanguage: v.ContentLanguage, + contentType: v.ContentType, + signedDirectoryDepth: getDirectoryDepth(v.DirectoryPath), + authorizedObjectID: v.AuthorizedObjectID, + unauthorizedObjectID: v.UnauthorizedObjectID, + correlationID: v.CorrelationID, + // Calculated SAS signature + signature: signature, + } + + //User delegation SAS specific parameters + p.signedOID = *udk.SignedOID + p.signedTID = *udk.SignedTID + p.signedStart = *udk.SignedStart + p.signedExpiry = *udk.SignedExpiry + p.signedService = *udk.SignedService + p.signedVersion = *udk.SignedVersion + + return p, nil +} + +// getCanonicalName computes the canonical name for a container or blob resource for SAS signing. 
+func getCanonicalName(account string, filesystemName string, fileName string, directoryName string) string { + // Container: "/blob/account/containername" + // Blob: "/blob/account/containername/blobname" + elements := []string{"/blob/", account, "/", filesystemName} + if fileName != "" { + elements = append(elements, "/", strings.Replace(fileName, "\\", "/", -1)) + } else if directoryName != "" { + elements = append(elements, "/", directoryName) + } + return strings.Join(elements, "") +} + +// FilesystemPermissions type simplifies creating the permissions string for an Azure Storage container SAS. +// Initialize an instance of this type and then call its String method to set BlobSignatureValues' Permissions field. +// All permissions descriptions can be found here: https://docs.microsoft.com/en-us/rest/api/storageservices/create-service-sas#permissions-for-a-directory-container-or-blob +type FilesystemPermissions struct { + Read, Add, Create, Write, Delete, List, Move bool + Execute, ModifyOwnership, ModifyPermissions bool // Meant for hierarchical namespace accounts +} + +// String produces the SAS permissions string for an Azure Storage container. +// Call this method to set BlobSignatureValues' Permissions field. +func (p *FilesystemPermissions) String() string { + var b bytes.Buffer + if p.Read { + b.WriteRune('r') + } + if p.Add { + b.WriteRune('a') + } + if p.Create { + b.WriteRune('c') + } + if p.Write { + b.WriteRune('w') + } + if p.Delete { + b.WriteRune('d') + } + if p.List { + b.WriteRune('l') + } + if p.Move { + b.WriteRune('m') + } + if p.Execute { + b.WriteRune('e') + } + if p.ModifyOwnership { + b.WriteRune('o') + } + if p.ModifyPermissions { + b.WriteRune('p') + } + return b.String() +} + +// Parse initializes ContainerPermissions' fields from a string. +func parseFilesystemPermissions(s string) (FilesystemPermissions, error) { + p := FilesystemPermissions{} // Clear the flags + for _, r := range s { + switch r { + case 'r': + p.Read = true + case 'a': + p.Add = true + case 'c': + p.Create = true + case 'w': + p.Write = true + case 'd': + p.Delete = true + case 'l': + p.List = true + case 'm': + p.Move = true + case 'e': + p.Execute = true + case 'o': + p.ModifyOwnership = true + case 'p': + p.ModifyPermissions = true + default: + return FilesystemPermissions{}, fmt.Errorf("invalid permission: '%v'", r) + } + } + return p, nil +} + +// PathPermissions type simplifies creating the permissions string for an Azure Storage blob SAS. +// Initialize an instance of this type and then call its String method to set BlobSignatureValues' Permissions field. +type PathPermissions struct { + Read, Add, Create, Write, Delete, List, Move bool + Execute, Ownership, Permissions bool +} + +// String produces the SAS permissions string for an Azure Storage blob. +// Call this method to set BlobSignatureValues' Permissions field. +func (p *PathPermissions) String() string { + var b bytes.Buffer + if p.Read { + b.WriteRune('r') + } + if p.Add { + b.WriteRune('a') + } + if p.Create { + b.WriteRune('c') + } + if p.Write { + b.WriteRune('w') + } + if p.Delete { + b.WriteRune('d') + } + if p.List { + b.WriteRune('l') + } + if p.Move { + b.WriteRune('m') + } + if p.Execute { + b.WriteRune('e') + } + if p.Ownership { + b.WriteRune('o') + } + if p.Permissions { + b.WriteRune('p') + } + return b.String() +} + +// Parse initializes BlobPermissions' fields from a string. 
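+//
+// For example (illustrative input):
+//
+//	perms, err := parsePathPermissions("racwd") // sets Read, Add, Create, Write and Delete; order does not matter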
+func parsePathPermissions(s string) (PathPermissions, error) {
+	p := PathPermissions{} // Clear the flags
+	for _, r := range s {
+		switch r {
+		case 'r':
+			p.Read = true
+		case 'a':
+			p.Add = true
+		case 'c':
+			p.Create = true
+		case 'w':
+			p.Write = true
+		case 'd':
+			p.Delete = true
+		case 'l':
+			p.List = true
+		case 'm':
+			p.Move = true
+		case 'e':
+			p.Execute = true
+		case 'o':
+			p.Ownership = true
+		case 'p':
+			p.Permissions = true
+		default:
+			return PathPermissions{}, fmt.Errorf("invalid permission: '%v'", r)
+		}
+	}
+	return p, nil
+}
diff --git a/sdk/storage/azdatalake/sas/service_test.go b/sdk/storage/azdatalake/sas/service_test.go
new file mode 100644
index 000000000000..218a6c116eac
--- /dev/null
+++ b/sdk/storage/azdatalake/sas/service_test.go
@@ -0,0 +1,219 @@
+//go:build go1.18
+// +build go1.18
+
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See License.txt in the project root for license information.
+
+package sas
+
+import (
+	"github.com/stretchr/testify/require"
+	"testing"
+)
+
+func TestFilesystemPermissions_String(t *testing.T) {
+	testdata := []struct {
+		input    FilesystemPermissions
+		expected string
+	}{
+		{input: FilesystemPermissions{Read: true}, expected: "r"},
+		{input: FilesystemPermissions{Add: true}, expected: "a"},
+		{input: FilesystemPermissions{Create: true}, expected: "c"},
+		{input: FilesystemPermissions{Write: true}, expected: "w"},
+		{input: FilesystemPermissions{Delete: true}, expected: "d"},
+		{input: FilesystemPermissions{List: true}, expected: "l"},
+		{input: FilesystemPermissions{Move: true}, expected: "m"},
+		{input: FilesystemPermissions{Execute: true}, expected: "e"},
+		{input: FilesystemPermissions{ModifyOwnership: true}, expected: "o"},
+		{input: FilesystemPermissions{ModifyPermissions: true}, expected: "p"},
+		{input: FilesystemPermissions{
+			Read:              true,
+			Add:               true,
+			Create:            true,
+			Write:             true,
+			Delete:            true,
+			List:              true,
+			Move:              true,
+			Execute:           true,
+			ModifyOwnership:   true,
+			ModifyPermissions: true,
+		}, expected: "racwdlmeop"},
+	}
+	for _, c := range testdata {
+		require.Equal(t, c.expected, c.input.String())
+	}
+}
+
+func TestFilesystemPermissions_Parse(t *testing.T) {
+	testdata := []struct {
+		input    string
+		expected FilesystemPermissions
+	}{
+		{expected: FilesystemPermissions{Read: true}, input: "r"},
+		{expected: FilesystemPermissions{Add: true}, input: "a"},
+		{expected: FilesystemPermissions{Create: true}, input: "c"},
+		{expected: FilesystemPermissions{Write: true}, input: "w"},
+		{expected: FilesystemPermissions{Delete: true}, input: "d"},
+		{expected: FilesystemPermissions{List: true}, input: "l"},
+		{expected: FilesystemPermissions{Move: true}, input: "m"},
+		{expected: FilesystemPermissions{Execute: true}, input: "e"},
+		{expected: FilesystemPermissions{ModifyOwnership: true}, input: "o"},
+		{expected: FilesystemPermissions{ModifyPermissions: true}, input: "p"},
+		{expected: FilesystemPermissions{
+			Read:              true,
+			Add:               true,
+			Create:            true,
+			Write:             true,
+			Delete:            true,
+			List:              true,
+			Move:              true,
+			Execute:           true,
+			ModifyOwnership:   true,
+			ModifyPermissions: true,
+		}, input: "racwdlmeop"},
+		{expected: FilesystemPermissions{
+			Read:              true,
+			Add:               true,
+			Create:            true,
+			Write:             true,
+			Delete:            true,
+			List:              true,
+			Move:              true,
+			Execute:           true,
+			ModifyOwnership:   true,
+			ModifyPermissions: true,
+		}, input: "cpwmreodal"}, // Wrong order parses correctly
+	}
+	for _, c := range testdata {
+		permissions, err := parseFilesystemPermissions(c.input)
+		require.Nil(t, err)
+		require.Equal(t, c.expected, permissions)
+	}
+}
+
+func TestFilesystemPermissions_ParseNegative(t *testing.T) {
+	_, err := parseFilesystemPermissions("cpwmreodalz") // Here 'z' is invalid
+	require.NotNil(t, err)
+	require.Contains(t, err.Error(), "122")
+}
+
+func TestPathPermissions_String(t *testing.T) {
+	testdata := []struct {
+		input    PathPermissions
+		expected string
+	}{
+		{input: PathPermissions{Read: true}, expected: "r"},
+		{input: PathPermissions{Add: true}, expected: "a"},
+		{input: PathPermissions{Create: true}, expected: "c"},
+		{input: PathPermissions{Write: true}, expected: "w"},
+		{input: PathPermissions{Delete: true}, expected: "d"},
+		{input: PathPermissions{List: true}, expected: "l"},
+		{input: PathPermissions{Move: true}, expected: "m"},
+		{input: PathPermissions{Execute: true}, expected: "e"},
+		{input: PathPermissions{Ownership: true}, expected: "o"},
+		{input: PathPermissions{Permissions: true}, expected: "p"},
+		{input: PathPermissions{
+			Read:        true,
+			Add:         true,
+			Create:      true,
+			Write:       true,
+			Delete:      true,
+			List:        true,
+			Move:        true,
+			Execute:     true,
+			Ownership:   true,
+			Permissions: true,
+		}, expected: "racwdlmeop"},
+	}
+	for _, c := range testdata {
+		require.Equal(t, c.expected, c.input.String())
+	}
+}
+
+func TestPathPermissions_Parse(t *testing.T) {
+	testdata := []struct {
+		expected PathPermissions
+		input    string
+	}{
+		{expected: PathPermissions{Read: true}, input: "r"},
+		{expected: PathPermissions{Add: true}, input: "a"},
+		{expected: PathPermissions{Create: true}, input: "c"},
+		{expected: PathPermissions{Write: true}, input: "w"},
+		{expected: PathPermissions{Delete: true}, input: "d"},
+		{expected: PathPermissions{List: true}, input: "l"},
+		{expected: PathPermissions{Move: true}, input: "m"},
+		{expected: PathPermissions{Execute: true}, input: "e"},
+		{expected: PathPermissions{Ownership: true}, input: "o"},
+		{expected: PathPermissions{Permissions: true}, input: "p"},
+		{expected: PathPermissions{
+			Read:        true,
+			Add:         true,
+			Create:      true,
+			Write:       true,
+			Delete:      true,
+			List:        true,
+			Move:        true,
+			Execute:     true,
+			Ownership:   true,
+			Permissions: true,
+		}, input: "racwdlmeop"},
+		{expected: PathPermissions{
+			Read:        true,
+			Add:         true,
+			Create:      true,
+			Write:       true,
+			Delete:      true,
+			List:        true,
+			Move:        true,
+			Execute:     true,
+			Ownership:   true,
+			Permissions: true,
+		}, input: "apwecrdlmo"}, // Wrong order parses correctly
+	}
+	for _, c := range testdata {
+		permissions, err := parsePathPermissions(c.input)
+		require.Nil(t, err)
+		require.Equal(t, c.expected, permissions)
+	}
+}
+
+func TestParsePermissionsNegative(t *testing.T) {
+	_, err := parsePathPermissions("awecrdlfmo") // Here 'f' is invalid
+	require.NotNil(t, err)
+	require.Contains(t, err.Error(), "102")
+}
+
+func TestGetCanonicalName(t *testing.T) {
+	testdata := []struct {
+		inputAccount   string
+		inputContainer string
+		inputBlob      string
+		inputDirectory string
+		expected       string
+	}{
+		{inputAccount: "fakestorageaccount", inputContainer: "fakestoragecontainer", expected: "/blob/fakestorageaccount/fakestoragecontainer"},
+		{inputAccount: "fakestorageaccount", inputContainer: "fakestoragecontainer", inputBlob: "fakestorageblob", expected: "/blob/fakestorageaccount/fakestoragecontainer/fakestorageblob"},
+		{inputAccount: "fakestorageaccount", inputContainer: "fakestoragecontainer", inputBlob: "fakestoragevirtualdir/fakestorageblob", expected: "/blob/fakestorageaccount/fakestoragecontainer/fakestoragevirtualdir/fakestorageblob"},
+		{inputAccount: "fakestorageaccount", inputContainer: "fakestoragecontainer", inputBlob: "fakestoragevirtualdir\\fakestorageblob", expected: "/blob/fakestorageaccount/fakestoragecontainer/fakestoragevirtualdir/fakestorageblob"},
+		{inputAccount: "fakestorageaccount", inputContainer: "fakestoragecontainer", inputBlob: "fakestoragedirectory", expected: "/blob/fakestorageaccount/fakestoragecontainer/fakestoragedirectory"},
+	}
+	for _, c := range testdata {
+		require.Equal(t, c.expected, getCanonicalName(c.inputAccount, c.inputContainer, c.inputBlob, c.inputDirectory))
+	}
+}
+
+func TestGetDirectoryDepth(t *testing.T) {
+	testdata := []struct {
+		input    string
+		expected string
+	}{
+		{input: "", expected: ""},
+		{input: "myfile", expected: "1"},
+		{input: "mydirectory", expected: "1"},
+		{input: "mydirectory/myfile", expected: "2"},
+		{input: "mydirectory/mysubdirectory", expected: "2"},
+	}
+	for _, c := range testdata {
+		require.Equal(t, c.expected, getDirectoryDepth(c.input))
+	}
+}
diff --git a/sdk/storage/azdatalake/sas/url_parts.go b/sdk/storage/azdatalake/sas/url_parts.go
new file mode 100644
index 000000000000..8af46a1329ad
--- /dev/null
+++ b/sdk/storage/azdatalake/sas/url_parts.go
@@ -0,0 +1,109 @@
+//go:build go1.18
+// +build go1.18
+
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See License.txt in the project root for license information.
+
+package sas
+
+import (
+	"net/url"
+	"strings"
+
+	"github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/shared"
+)
+
+// IPEndpointStyleInfo is used for IP endpoint style URL when working with Azure storage emulator.
+// Ex: "https://10.132.141.33/accountname/containername"
+type IPEndpointStyleInfo struct {
+	AccountName string // "" if not using IP endpoint style
+}
+
+// URLParts object represents the components that make up an Azure Storage Container/Blob URL.
+// NOTE: Changing any SAS-related field requires computing a new SAS signature.
+type URLParts struct {
+	Scheme              string // Ex: "https://"
+	Host                string // Ex: "account.blob.core.windows.net", "10.132.141.33", "10.132.141.33:80"
+	IPEndpointStyleInfo IPEndpointStyleInfo
+	FilesystemName      string // "" if no container
+	PathName            string // "" if no blob
+	SAS                 QueryParameters
+	UnparsedParams      string
+}
+
+// ParseURL parses a URL initializing URLParts' fields including any SAS-related & snapshot query parameters.
+// Any other query parameters remain in the UnparsedParams field.
+func ParseURL(u string) (URLParts, error) {
+	uri, err := url.Parse(u)
+	if err != nil {
+		return URLParts{}, err
+	}
+
+	up := URLParts{
+		Scheme: uri.Scheme,
+		Host:   uri.Host,
+	}
+
+	// Find the container & blob names (if any)
+	if uri.Path != "" {
+		path := uri.Path
+		if path[0] == '/' {
+			path = path[1:] // If path starts with a slash, remove it
+		}
+		if shared.IsIPEndpointStyle(up.Host) {
+			if accountEndIndex := strings.Index(path, "/"); accountEndIndex == -1 { // Slash not found; path has account name & no container name or blob
+				up.IPEndpointStyleInfo.AccountName = path
+				path = "" // No ContainerName present in the URL so path should be empty
+			} else {
+				up.IPEndpointStyleInfo.AccountName = path[:accountEndIndex] // The account name is the part between the slashes
+				path = path[accountEndIndex+1:]                             // path refers to portion after the account name now (container & blob names)
+			}
+		}
+
+		filesystemEndIndex := strings.Index(path, "/") // Find the next slash (if it exists)
+		if filesystemEndIndex == -1 {                  // Slash not found; path has container name & no blob name
+			up.FilesystemName = path
+		} else {
+			up.FilesystemName = path[:filesystemEndIndex] // The container name is the part between the slashes
+			up.PathName = path[filesystemEndIndex+1:]     // The blob name is after the container slash
+		}
+	}
+
+	// Convert the query parameters to a case-sensitive map & trim whitespace
+	paramsMap := uri.Query()
+	up.SAS = newQueryParameters(paramsMap, true)
+	up.UnparsedParams = paramsMap.Encode()
+	return up, nil
+}
+
+// String returns a URL object whose fields are initialized from the URLParts fields. The URL's RawQuery
+// field contains the SAS, snapshot, and unparsed query parameters.
+func (up URLParts) String() string {
+	path := ""
+	if shared.IsIPEndpointStyle(up.Host) && up.IPEndpointStyleInfo.AccountName != "" {
+		path += "/" + up.IPEndpointStyleInfo.AccountName
+	}
+	// Concatenate container & blob names (if they exist)
+	if up.FilesystemName != "" {
+		path += "/" + up.FilesystemName
+		if up.PathName != "" {
+			path += "/" + up.PathName
+		}
+	}
+
+	rawQuery := up.UnparsedParams
+	sas := up.SAS.Encode()
+	if sas != "" {
+		if len(rawQuery) > 0 {
+			rawQuery += "&"
+		}
+		rawQuery += sas
+	}
+	u := url.URL{
+		Scheme:   up.Scheme,
+		Host:     up.Host,
+		Path:     path,
+		RawQuery: rawQuery,
+	}
+	return u.String()
+}
diff --git a/sdk/storage/azdatalake/sas/url_parts_test.go b/sdk/storage/azdatalake/sas/url_parts_test.go
new file mode 100644
index 000000000000..88b8d94f7a10
--- /dev/null
+++ b/sdk/storage/azdatalake/sas/url_parts_test.go
@@ -0,0 +1,61 @@
+//go:build go1.18
+// +build go1.18
+
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See License.txt in the project root for license information.
+
+package sas
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestParseURLIPStyle(t *testing.T) {
+	urlWithIP := "https://127.0.0.1:5000/fakestorageaccount"
+	blobURLParts, err := ParseURL(urlWithIP)
+	require.NoError(t, err)
+	require.Equal(t, blobURLParts.Scheme, "https")
+	require.Equal(t, blobURLParts.Host, "127.0.0.1:5000")
+	require.Equal(t, blobURLParts.IPEndpointStyleInfo.AccountName, "fakestorageaccount")
+
+	urlWithIP = "https://127.0.0.1:5000/fakestorageaccount/fakecontainer"
+	blobURLParts, err = ParseURL(urlWithIP)
+	require.NoError(t, err)
+	require.Equal(t, blobURLParts.Scheme, "https")
+	require.Equal(t, blobURLParts.Host, "127.0.0.1:5000")
+	require.Equal(t, blobURLParts.IPEndpointStyleInfo.AccountName, "fakestorageaccount")
+	require.Equal(t, blobURLParts.FilesystemName, "fakecontainer")
+
+	urlWithIP = "https://127.0.0.1:5000/fakestorageaccount/fakecontainer/fakeblob"
+	blobURLParts, err = ParseURL(urlWithIP)
+	require.NoError(t, err)
+	require.Equal(t, blobURLParts.Scheme, "https")
+	require.Equal(t, blobURLParts.Host, "127.0.0.1:5000")
+	require.Equal(t, blobURLParts.IPEndpointStyleInfo.AccountName, "fakestorageaccount")
+	require.Equal(t, blobURLParts.FilesystemName, "fakecontainer")
+	require.Equal(t, blobURLParts.PathName, "fakeblob")
+}
+
+func TestParseURL(t *testing.T) {
+	testStorageAccount := "fakestorageaccount"
+	host := fmt.Sprintf("%s.blob.core.windows.net", testStorageAccount)
+	testContainer := "fakecontainer"
+	fileNames := []string{"/._.TESTT.txt", "/.gitignore/dummyfile1"}
+
+	const sasStr = "sv=2019-12-12&sr=b&st=2111-01-09T01:42:34.936Z&se=2222-03-09T01:42:34.936Z&sp=rw&sip=168.1.5.60-168.1.5.70&spr=https,http&si=myIdentifier&ss=bf&srt=s&sig=clNxbtnkKSHw7f3KMEVVc4agaszoRFdbZr%2FWBmPNsrw%3D"
+
+	for _, fileName := range fileNames {
+		urlWithSAS := fmt.Sprintf("https://%s.blob.core.windows.net/%s%s?%s", testStorageAccount, testContainer, fileName, sasStr)
+		blobURLParts, err := ParseURL(urlWithSAS)
+		require.NoError(t, err)
+
+		require.Equal(t, blobURLParts.Scheme, "https")
+		require.Equal(t, blobURLParts.Host, host)
+		require.Equal(t, blobURLParts.FilesystemName, testContainer)
+
+		validateSAS(t, sasStr, blobURLParts.SAS)
+	}
+}
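
Reviewer note: the following is a minimal, hypothetical sketch of how the new ParseURL and URLParts.String in the sas package are intended to round-trip a Data Lake URL. It is not part of the change; the account, filesystem, and path names are placeholders, and the import path simply mirrors the module layout shown in this diff.

package main

import (
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/sas"
)

func main() {
	// Split a Data Lake URL into its components (placeholder names).
	parts, err := sas.ParseURL("https://myaccount.dfs.core.windows.net/myfilesystem/mydir/myfile.txt")
	if err != nil {
		panic(err)
	}
	fmt.Println(parts.Scheme)         // "https"
	fmt.Println(parts.Host)           // "myaccount.dfs.core.windows.net"
	fmt.Println(parts.FilesystemName) // "myfilesystem"
	fmt.Println(parts.PathName)       // "mydir/myfile.txt"

	// String reassembles the URL from the parsed fields; any SAS parsed out of
	// the query string would be re-appended to RawQuery.
	fmt.Println(parts.String())
}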
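In the same spirit, a small sketch of the permission types exercised by service_test.go above. Per those tests, the String methods emit flags in the canonical "racwdlmeop" order, and FilesystemPermissions uses ModifyOwnership/ModifyPermissions where PathPermissions uses Ownership/Permissions; the specific flag combinations below are illustrative only.

package main

import (
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/sas"
)

func main() {
	// Path-level flags: 'o' and 'p' come from Ownership and Permissions.
	pathPerms := sas.PathPermissions{Read: true, Write: true, Execute: true}
	fmt.Println(pathPerms.String()) // "rwe" (canonical racwdlmeop order)

	// Filesystem-level flags: 'o' and 'p' come from ModifyOwnership and ModifyPermissions.
	fsPerms := sas.FilesystemPermissions{Read: true, List: true, ModifyPermissions: true}
	fmt.Println(fsPerms.String()) // "rlp"
}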