Skip to content

Commit 0abbe92

Browse files
authored
Use default STS endpoint (#2044)
1 parent 5757f2c commit 0abbe92

11 files changed

+149
-76
lines changed

api.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,7 @@ func (c *Client) CredContext() *credentials.CredContext {
10261026
httpClient = http.DefaultClient
10271027
}
10281028
return &credentials.CredContext{
1029-
Client: httpClient,
1029+
Client: httpClient,
1030+
Endpoint: c.endpointURL.String(),
10301031
}
10311032
}

pkg/credentials/assume_role.go

+21-12
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,6 @@ type STSAssumeRoleOptions struct {
109109
// NewSTSAssumeRole returns a pointer to a new
110110
// Credentials object wrapping the STSAssumeRole.
111111
func NewSTSAssumeRole(stsEndpoint string, opts STSAssumeRoleOptions) (*Credentials, error) {
112-
if stsEndpoint == "" {
113-
return nil, errors.New("STS endpoint cannot be empty")
114-
}
115112
if opts.AccessKey == "" || opts.SecretKey == "" {
116113
return nil, errors.New("AssumeRole credentials access/secretkey is mandatory")
117114
}
@@ -220,12 +217,30 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume
220217
return a, nil
221218
}
222219

223-
func (m *STSAssumeRole) retrieve(cc *CredContext) (Value, error) {
220+
// RetrieveWithCredContext retrieves credentials from the MinIO service.
221+
// Error will be returned if the request fails, optional cred context.
222+
func (m *STSAssumeRole) RetrieveWithCredContext(cc *CredContext) (Value, error) {
223+
if cc == nil {
224+
cc = defaultCredContext
225+
}
226+
224227
client := m.Client
225228
if client == nil {
226229
client = cc.Client
227230
}
228-
a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options)
231+
if client == nil {
232+
client = defaultCredContext.Client
233+
}
234+
235+
stsEndpoint := m.STSEndpoint
236+
if stsEndpoint == "" {
237+
stsEndpoint = cc.Endpoint
238+
}
239+
if stsEndpoint == "" {
240+
return Value{}, errors.New("STS endpoint unknown")
241+
}
242+
243+
a, err := getAssumeRoleCredentials(client, stsEndpoint, m.Options)
229244
if err != nil {
230245
return Value{}, err
231246
}
@@ -242,14 +257,8 @@ func (m *STSAssumeRole) retrieve(cc *CredContext) (Value, error) {
242257
}, nil
243258
}
244259

245-
// RetrieveWithCredContext retrieves credentials from the MinIO service.
246-
// Error will be returned if the request fails, optional cred context.
247-
func (m *STSAssumeRole) RetrieveWithCredContext(cc *CredContext) (Value, error) {
248-
return m.retrieve(cc)
249-
}
250-
251260
// Retrieve retrieves credentials from the MinIO service.
252261
// Error will be returned if the request fails.
253262
func (m *STSAssumeRole) Retrieve() (Value, error) {
254-
return m.retrieve(defaultCredContext)
263+
return m.RetrieveWithCredContext(nil)
255264
}

pkg/credentials/chain.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func (c *Chain) RetrieveWithCredContext(cc *CredContext) (Value, error) {
8080
// to IsExpired() will return the expired state of the cached provider.
8181
func (c *Chain) Retrieve() (Value, error) {
8282
for _, p := range c.Providers {
83-
creds, _ := p.RetrieveWithCredContext(defaultCredContext)
83+
creds, _ := p.Retrieve()
8484
// Always prioritize non-anonymous providers, if any.
8585
if creds.AccessKeyID == "" && creds.SecretAccessKey == "" {
8686
continue

pkg/credentials/credentials.go

+12-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ type Provider interface {
6464

6565
// Retrieve returns nil if it successfully retrieved the value.
6666
// Error is returned if the value were not obtainable, or empty.
67+
//
68+
// Deprecated: Retrieve() exists for historical compatibility and should not
69+
// be used. To get new credentials use the RetrieveWithCredContext function
70+
// to ensure the proper context (i.e. HTTP client) will be used.
6771
Retrieve() (Value, error)
6872

6973
// IsExpired returns if the credentials are no longer valid, and need
@@ -77,6 +81,10 @@ type CredContext struct {
7781
// Client specifies the HTTP client that should be used if an HTTP
7882
// request is to be made to fetch the credentials.
7983
Client *http.Client
84+
85+
// Endpoint specifies the MinIO endpoint that will be used if no
86+
// explicit endpoint is provided.
87+
Endpoint string
8088
}
8189

8290
// A Expiry provides shared expiration logic to be used by credentials
@@ -169,7 +177,7 @@ func New(provider Provider) *Credentials {
169177
// used. To get new credentials use the Credentials.GetWithContext function
170178
// to ensure the proper context (i.e. HTTP client) will be used.
171179
func (c *Credentials) Get() (Value, error) {
172-
return c.GetWithContext(defaultCredContext)
180+
return c.GetWithContext(nil)
173181
}
174182

175183
// GetWithContext returns the credentials value, or error if the
@@ -185,6 +193,9 @@ func (c *Credentials) GetWithContext(cc *CredContext) (Value, error) {
185193
if c == nil {
186194
return Value{}, nil
187195
}
196+
if cc == nil {
197+
cc = defaultCredContext
198+
}
188199

189200
c.Lock()
190201
defer c.Unlock()

pkg/credentials/env_aws.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (e *EnvAWS) Retrieve() (Value, error) {
6969
return e.retrieve()
7070
}
7171

72-
// RetrieveWithContext is like Retrieve (no-op input of Cred Context)
72+
// RetrieveWithCredContext is like Retrieve (no-op input of Cred Context)
7373
func (e *EnvAWS) RetrieveWithCredContext(_ *CredContext) (Value, error) {
7474
return e.retrieve()
7575
}

pkg/credentials/iam_aws.go

+14-7
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,12 @@ func NewIAM(endpoint string) *Credentials {
9595
})
9696
}
9797

98-
func (m *IAM) retrieve(cc *CredContext) (Value, error) {
98+
// RetrieveWithCredContext is like Retrieve with Cred Context
99+
func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) {
100+
if cc == nil {
101+
cc = defaultCredContext
102+
}
103+
99104
token := os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN")
100105
if token == "" {
101106
token = m.Container.AuthorizationToken
@@ -143,8 +148,15 @@ func (m *IAM) retrieve(cc *CredContext) (Value, error) {
143148
if client == nil {
144149
client = cc.Client
145150
}
151+
if client == nil {
152+
client = defaultCredContext.Client
153+
}
146154

147155
endpoint := m.Endpoint
156+
if endpoint == "" {
157+
endpoint = cc.Endpoint
158+
}
159+
148160
switch {
149161
case identityFile != "":
150162
if len(endpoint) == 0 {
@@ -228,12 +240,7 @@ func (m *IAM) retrieve(cc *CredContext) (Value, error) {
228240
// Error will be returned if the request fails, or unable to extract
229241
// the desired
230242
func (m *IAM) Retrieve() (Value, error) {
231-
return m.retrieve(defaultCredContext)
232-
}
233-
234-
// RetrieveWithCredContext is like Retrieve with Cred Context
235-
func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) {
236-
return m.retrieve(cc)
243+
return m.RetrieveWithCredContext(nil)
237244
}
238245

239246
// A ec2RoleCredRespBody provides the shape for unmarshaling credential

pkg/credentials/sts_client_grants.go

+20-11
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,6 @@ type STSClientGrants struct {
9191
// NewSTSClientGrants returns a pointer to a new
9292
// Credentials object wrapping the STSClientGrants.
9393
func NewSTSClientGrants(stsEndpoint string, getClientGrantsTokenExpiry func() (*ClientGrantsToken, error)) (*Credentials, error) {
94-
if stsEndpoint == "" {
95-
return nil, errors.New("STS endpoint cannot be empty")
96-
}
9794
if getClientGrantsTokenExpiry == nil {
9895
return nil, errors.New("Client grants access token and expiry retrieval function should be defined")
9996
}
@@ -160,12 +157,29 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string,
160157
return a, nil
161158
}
162159

163-
func (m *STSClientGrants) retrieve(cc *CredContext) (Value, error) {
160+
// RetrieveWithCredContext is like Retrieve() with cred context
161+
func (m *STSClientGrants) RetrieveWithCredContext(cc *CredContext) (Value, error) {
162+
if cc == nil {
163+
cc = defaultCredContext
164+
}
165+
164166
client := m.Client
165167
if client == nil {
166168
client = cc.Client
167169
}
168-
a, err := getClientGrantsCredentials(client, m.STSEndpoint, m.GetClientGrantsTokenExpiry)
170+
if client == nil {
171+
client = defaultCredContext.Client
172+
}
173+
174+
stsEndpoint := m.STSEndpoint
175+
if stsEndpoint == "" {
176+
stsEndpoint = cc.Endpoint
177+
}
178+
if stsEndpoint == "" {
179+
return Value{}, errors.New("STS endpoint unknown")
180+
}
181+
182+
a, err := getClientGrantsCredentials(client, stsEndpoint, m.GetClientGrantsTokenExpiry)
169183
if err != nil {
170184
return Value{}, err
171185
}
@@ -182,13 +196,8 @@ func (m *STSClientGrants) retrieve(cc *CredContext) (Value, error) {
182196
}, nil
183197
}
184198

185-
// RetrieveWithCredContext is like Retrieve() with cred context
186-
func (m *STSClientGrants) RetrieveWithCredContext(cc *CredContext) (Value, error) {
187-
return m.retrieve(cc)
188-
}
189-
190199
// Retrieve retrieves credentials from the MinIO service.
191200
// Error will be returned if the request fails.
192201
func (m *STSClientGrants) Retrieve() (Value, error) {
193-
return m.retrieve(defaultCredContext)
202+
return m.RetrieveWithCredContext(nil)
194203
}

pkg/credentials/sts_custom_identity.go

+19-8
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,21 @@ type CustomTokenIdentity struct {
7171
RequestedExpiry time.Duration
7272
}
7373

74-
func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error) {
75-
u, err := url.Parse(c.STSEndpoint)
74+
// RetrieveWithCredContext with Retrieve optionally cred context
75+
func (c *CustomTokenIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
76+
if cc == nil {
77+
cc = defaultCredContext
78+
}
79+
80+
stsEndpoint := c.STSEndpoint
81+
if stsEndpoint == "" {
82+
stsEndpoint = cc.Endpoint
83+
}
84+
if stsEndpoint == "" {
85+
return Value{}, errors.New("STS endpoint unknown")
86+
}
87+
88+
u, err := url.Parse(stsEndpoint)
7689
if err != nil {
7790
return value, err
7891
}
@@ -97,6 +110,9 @@ func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error)
97110
if client == nil {
98111
client = cc.Client
99112
}
113+
if client == nil {
114+
client = defaultCredContext.Client
115+
}
100116

101117
resp, err := client.Do(req)
102118
if err != nil {
@@ -126,12 +142,7 @@ func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error)
126142

127143
// Retrieve - to satisfy Provider interface; fetches credentials from MinIO.
128144
func (c *CustomTokenIdentity) Retrieve() (value Value, err error) {
129-
return c.retrieve(defaultCredContext)
130-
}
131-
132-
// RetrieveWithCredContext with Retrieve optionally cred context
133-
func (c *CustomTokenIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
134-
return c.retrieve(cc)
145+
return c.RetrieveWithCredContext(nil)
135146
}
136147

137148
// NewCustomTokenCredentials - returns credentials using the

pkg/credentials/sts_ldap_identity.go

+21-9
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package credentials
2020
import (
2121
"bytes"
2222
"encoding/xml"
23+
"errors"
2324
"fmt"
2425
"io"
2526
"net/http"
@@ -120,8 +121,22 @@ func NewLDAPIdentityWithSessionPolicy(stsEndpoint, ldapUsername, ldapPassword, p
120121
}), nil
121122
}
122123

123-
func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) {
124-
u, err := url.Parse(k.STSEndpoint)
124+
// RetrieveWithCredContext gets the credential by calling the MinIO STS API for
125+
// LDAP on the configured stsEndpoint.
126+
func (k *LDAPIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
127+
if cc == nil {
128+
cc = defaultCredContext
129+
}
130+
131+
stsEndpoint := k.STSEndpoint
132+
if stsEndpoint == "" {
133+
stsEndpoint = cc.Endpoint
134+
}
135+
if stsEndpoint == "" {
136+
return Value{}, errors.New("STS endpoint unknown")
137+
}
138+
139+
u, err := url.Parse(stsEndpoint)
125140
if err != nil {
126141
return value, err
127142
}
@@ -149,6 +164,9 @@ func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) {
149164
if client == nil {
150165
client = cc.Client
151166
}
167+
if client == nil {
168+
client = defaultCredContext.Client
169+
}
152170

153171
resp, err := client.Do(req)
154172
if err != nil {
@@ -194,11 +212,5 @@ func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) {
194212
// Retrieve gets the credential by calling the MinIO STS API for
195213
// LDAP on the configured stsEndpoint.
196214
func (k *LDAPIdentity) Retrieve() (value Value, err error) {
197-
return k.retrieve(defaultCredContext)
198-
}
199-
200-
// RetrieveWithCredContext gets the credential by calling the MinIO STS API for
201-
// LDAP on the configured stsEndpoint.
202-
func (k *LDAPIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
203-
return k.retrieve(cc)
215+
return k.RetrieveWithCredContext(defaultCredContext)
204216
}

pkg/credentials/sts_tls_identity.go

+19-14
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,6 @@ type STSCertificateIdentity struct {
8686
// to the given STS endpoint with the given TLS certificate and retrieves and
8787
// rotates S3 credentials.
8888
func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, options ...CertificateIdentityOption) (*Credentials, error) {
89-
if endpoint == "" {
90-
return nil, errors.New("STS endpoint cannot be empty")
91-
}
92-
if _, err := url.Parse(endpoint); err != nil {
93-
return nil, err
94-
}
9589
identity := &STSCertificateIdentity{
9690
STSEndpoint: endpoint,
9791
Certificate: certificate,
@@ -102,8 +96,21 @@ func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, opt
10296
return New(identity), nil
10397
}
10498

105-
func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) {
106-
endpointURL, err := url.Parse(i.STSEndpoint)
99+
// RetrieveWithCredContext is Retrieve with cred context
100+
func (i *STSCertificateIdentity) RetrieveWithCredContext(cc *CredContext) (Value, error) {
101+
if cc == nil {
102+
cc = defaultCredContext
103+
}
104+
105+
stsEndpoint := i.STSEndpoint
106+
if stsEndpoint == "" {
107+
stsEndpoint = cc.Endpoint
108+
}
109+
if stsEndpoint == "" {
110+
return Value{}, errors.New("STS endpoint unknown")
111+
}
112+
113+
endpointURL, err := url.Parse(stsEndpoint)
107114
if err != nil {
108115
return Value{}, err
109116
}
@@ -130,6 +137,9 @@ func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) {
130137
if client == nil {
131138
client = cc.Client
132139
}
140+
if client == nil {
141+
client = defaultCredContext.Client
142+
}
133143

134144
tr, ok := client.Transport.(*http.Transport)
135145
if !ok {
@@ -192,14 +202,9 @@ func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) {
192202
}, nil
193203
}
194204

195-
// RetrieveWithCredContext is Retrieve with cred context
196-
func (i *STSCertificateIdentity) RetrieveWithCredContext(cc *CredContext) (Value, error) {
197-
return i.retrieve(cc)
198-
}
199-
200205
// Retrieve fetches a new set of S3 credentials from the configured STS API endpoint.
201206
func (i *STSCertificateIdentity) Retrieve() (Value, error) {
202-
return i.retrieve(defaultCredContext)
207+
return i.RetrieveWithCredContext(defaultCredContext)
203208
}
204209

205210
// Expiration returns the expiration time of the current S3 credentials.

0 commit comments

Comments
 (0)