Skip to content

Commit c0ce81d

Browse files
author
Benjamin Stockwell
authored
Merge pull request #123 from benjsto/benjsto/internal-proxy-provider-url
sso-proxy: internal host should apply to redeem, /refresh, /validate, /profile
2 parents 3f58974 + 66fd6c6 commit c0ce81d

File tree

7 files changed

+97
-175
lines changed

7 files changed

+97
-175
lines changed

internal/proxy/options.go

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ import (
4848
type Options struct {
4949
Port int `envconfig:"PORT" default:"4180"`
5050

51-
ProviderURLString string `envconfig:"PROVIDER_URL"`
52-
ProxyProviderURLString string `envconfig:"PROXY_PROVIDER_URL"`
53-
UpstreamConfigsFile string `envconfig:"UPSTREAM_CONFIGS"`
54-
Cluster string `envconfig:"CLUSTER"`
55-
Scheme string `envconfig:"SCHEME" default:"https"`
51+
ProviderURLString string `envconfig:"PROVIDER_URL"`
52+
ProviderURLInternalString string `envconfig:"PROVIDER_URL_INTERNAL"`
53+
UpstreamConfigsFile string `envconfig:"UPSTREAM_CONFIGS"`
54+
Cluster string `envconfig:"CLUSTER"`
55+
Scheme string `envconfig:"SCHEME" default:"https"`
5656

5757
SkipAuthPreflight bool `envconfig:"SKIP_AUTH_PREFLIGHT"`
5858

@@ -133,9 +133,6 @@ func (o *Options) Validate() error {
133133
if o.ProviderURLString == "" {
134134
msgs = append(msgs, "missing setting: provider-url")
135135
}
136-
if o.ProxyProviderURLString == "" {
137-
o.ProxyProviderURLString = o.ProviderURLString
138-
}
139136
if o.UpstreamConfigsFile == "" {
140137
msgs = append(msgs, "missing setting: upstream-configs")
141138
}
@@ -226,23 +223,27 @@ func parseProviderInfo(o *Options) error {
226223
return errors.New("provider-url must include scheme and host")
227224
}
228225

229-
proxyProviderURL, err := url.Parse(o.ProxyProviderURLString)
230-
if err != nil {
231-
return err
232-
}
233-
if proxyProviderURL.Scheme == "" || proxyProviderURL.Host == "" {
234-
return errors.New("proxy provider url must include scheme and host")
226+
var providerURLInternal *url.URL
227+
228+
if o.ProviderURLInternalString != "" {
229+
providerURLInternal, err = url.Parse(o.ProviderURLInternalString)
230+
if err != nil {
231+
return err
232+
}
233+
if providerURLInternal.Scheme == "" || providerURLInternal.Host == "" {
234+
return errors.New("proxy provider url must include scheme and host")
235+
}
235236
}
236237

237238
providerData := &providers.ProviderData{
238-
ClientID: o.ClientID,
239-
ClientSecret: o.ClientSecret,
240-
ProviderURL: providerURL,
241-
ProxyProviderURL: proxyProviderURL,
242-
Scope: o.Scope,
243-
SessionLifetimeTTL: o.SessionLifetimeTTL,
244-
SessionValidTTL: o.SessionValidTTL,
245-
GracePeriodTTL: o.GracePeriodTTL,
239+
ClientID: o.ClientID,
240+
ClientSecret: o.ClientSecret,
241+
ProviderURL: providerURL,
242+
ProviderURLInternal: providerURLInternal,
243+
Scope: o.Scope,
244+
SessionLifetimeTTL: o.SessionLifetimeTTL,
245+
SessionValidTTL: o.SessionValidTTL,
246+
GracePeriodTTL: o.GracePeriodTTL,
246247
}
247248

248249
p := providers.New(o.Provider, providerData, o.StatsdClient)

internal/proxy/options_test.go

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ func TestDefaultProviderApiSettings(t *testing.T) {
7171
p.SignOutURL.String())
7272
testutil.Equal(t, "https://www.example.com/redeem",
7373
p.RedeemURL.String())
74-
testutil.Equal(t, "https://www.example.com/redeem",
75-
p.ProxyRedeemURL.String())
7674
testutil.Equal(t, "https://www.example.com/validate",
7775
p.ValidateURL.String())
7876
testutil.Equal(t, "https://www.example.com/profile",
@@ -82,12 +80,12 @@ func TestDefaultProviderApiSettings(t *testing.T) {
8280

8381
func TestProviderURLValidation(t *testing.T) {
8482
testCases := []struct {
85-
name string
86-
providerURLString string
87-
proxyProviderURLString string
88-
expectedError string
89-
expectedProxyProviderURLString string
90-
expectedSignInURL string
83+
name string
84+
providerURLString string
85+
providerURLInternalString string
86+
expectedError string
87+
expectedProviderURLInternalString string
88+
expectedSignInURL string
9189
}{
9290
{
9391
name: "http scheme preserved",
@@ -100,15 +98,15 @@ func TestProviderURLValidation(t *testing.T) {
10098
expectedSignInURL: "https://provider.example.com/sign_in",
10199
},
102100
{
103-
name: "proxy provider url string based on providerURL",
104-
providerURLString: "https://provider.example.com",
105-
expectedProxyProviderURLString: "https://provider.example.com",
101+
name: "proxy provider url string based on providerURL",
102+
providerURLString: "https://provider.example.com",
103+
expectedProviderURLInternalString: "",
106104
},
107105
{
108-
name: "proxy provider url string based on proxyProviderURL",
109-
providerURLString: "https://provider.example.com",
110-
proxyProviderURLString: "https://provider-internal.example.com",
111-
expectedProxyProviderURLString: "https://provider-internal.example.com",
106+
name: "proxy provider url string based on proxyProviderURL",
107+
providerURLString: "https://provider.example.com",
108+
providerURLInternalString: "https://provider-internal.example.com",
109+
expectedProviderURLInternalString: "https://provider-internal.example.com",
112110
},
113111
{
114112
name: "scheme required",
@@ -136,7 +134,7 @@ func TestProviderURLValidation(t *testing.T) {
136134
t.Run(tc.name, func(t *testing.T) {
137135
o := testOptions()
138136
o.ProviderURLString = tc.providerURLString
139-
o.ProxyProviderURLString = tc.proxyProviderURLString
137+
o.ProviderURLInternalString = tc.providerURLInternalString
140138
err := o.Validate()
141139
if tc.expectedError != "" {
142140
if err == nil {
@@ -148,8 +146,8 @@ func TestProviderURLValidation(t *testing.T) {
148146
if tc.expectedSignInURL != "" {
149147
testutil.Equal(t, o.provider.Data().SignInURL.String(), tc.expectedSignInURL)
150148
}
151-
if tc.expectedProxyProviderURLString != "" {
152-
testutil.Equal(t, o.provider.Data().ProxyProviderURL.String(), tc.expectedProxyProviderURLString)
149+
if tc.expectedProviderURLInternalString != "" {
150+
testutil.Equal(t, o.provider.Data().ProviderURLInternal.String(), tc.expectedProviderURLInternalString)
153151
}
154152
})
155153
}

internal/proxy/providers/provider_data.go

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,18 @@ import (
88
// ProviderData holds the fields associated with providers
99
// necessary to implement the Provider interface.
1010
type ProviderData struct {
11-
ProviderName string
12-
ProviderURL *url.URL
13-
ProxyProviderURL *url.URL
14-
ClientID string
15-
ClientSecret string
16-
SignInURL *url.URL
17-
SignOutURL *url.URL
18-
RedeemURL *url.URL
19-
ProxyRedeemURL *url.URL
20-
RefreshURL *url.URL
21-
ProfileURL *url.URL
22-
ValidateURL *url.URL
23-
Scope string
11+
ProviderName string
12+
ProviderURL *url.URL
13+
ProviderURLInternal *url.URL
14+
ClientID string
15+
ClientSecret string
16+
SignInURL *url.URL
17+
SignOutURL *url.URL
18+
RedeemURL *url.URL
19+
RefreshURL *url.URL
20+
ProfileURL *url.URL
21+
ValidateURL *url.URL
22+
Scope string
2423

2524
SessionValidTTL time.Duration
2625
SessionLifetimeTTL time.Duration

internal/proxy/providers/sso.go

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,26 +57,34 @@ func init() {
5757
func NewSSOProvider(p *ProviderData, sc *statsd.Client) *SSOProvider {
5858
p.ProviderName = "SSO"
5959
base := p.ProviderURL
60+
internalBase := base
61+
62+
if p.ProviderURLInternal != nil {
63+
internalBase = p.ProviderURLInternal
64+
}
65+
6066
p.SignInURL = base.ResolveReference(&url.URL{Path: "/sign_in"})
6167
p.SignOutURL = base.ResolveReference(&url.URL{Path: "/sign_out"})
62-
p.RedeemURL = base.ResolveReference(&url.URL{Path: "/redeem"})
63-
p.RefreshURL = base.ResolveReference(&url.URL{Path: "/refresh"})
64-
p.ValidateURL = base.ResolveReference(&url.URL{Path: "/validate"})
65-
p.ProfileURL = base.ResolveReference(&url.URL{Path: "/profile"})
66-
p.ProxyRedeemURL = p.ProxyProviderURL.ResolveReference(&url.URL{Path: "/redeem"})
68+
69+
p.RedeemURL = internalBase.ResolveReference(&url.URL{Path: "/redeem"})
70+
p.RefreshURL = internalBase.ResolveReference(&url.URL{Path: "/refresh"})
71+
p.ValidateURL = internalBase.ResolveReference(&url.URL{Path: "/validate"})
72+
p.ProfileURL = internalBase.ResolveReference(&url.URL{Path: "/profile"})
73+
6774
return &SSOProvider{
6875
ProviderData: p,
6976
StatsdClient: sc,
7077
}
7178
}
7279

73-
func newRequest(method, url string, body io.Reader) (*http.Request, error) {
80+
func (p *SSOProvider) newRequest(method, url string, body io.Reader) (*http.Request, error) {
7481
req, err := http.NewRequest(method, url, body)
7582
if err != nil {
7683
return nil, err
7784
}
7885
req.Header.Set("User-Agent", userAgentString)
7986
req.Header.Set("Accept", "application/json")
87+
req.Host = p.ProviderData.ProviderURL.Host
8088
return req, nil
8189
}
8290

@@ -109,11 +117,10 @@ func (p *SSOProvider) Redeem(redirectURL, code string) (*SessionState, error) {
109117
params.Add("code", code)
110118
params.Add("grant_type", "authorization_code")
111119

112-
req, err := newRequest("POST", p.ProxyRedeemURL.String(), bytes.NewBufferString(params.Encode()))
120+
req, err := p.newRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
113121
if err != nil {
114122
return nil, err
115123
}
116-
req.Host = p.RedeemURL.Host
117124
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
118125
resp, err := httpClient.Do(req)
119126
if err != nil {
@@ -194,7 +201,7 @@ func (p *SSOProvider) UserGroups(email string, groups []string) ([]string, error
194201
params.Add("client_id", p.ClientID)
195202
params.Add("groups", strings.Join(groups, ","))
196203

197-
req, err := newRequest("GET", fmt.Sprintf("%s?%s", p.ProfileURL.String(), params.Encode()), nil)
204+
req, err := p.newRequest("GET", fmt.Sprintf("%s?%s", p.ProfileURL.String(), params.Encode()), nil)
198205
if err != nil {
199206
return nil, err
200207
}
@@ -284,7 +291,7 @@ func (p *SSOProvider) redeemRefreshToken(refreshToken string) (token string, exp
284291
params.Add("client_secret", p.ClientSecret)
285292
params.Add("refresh_token", refreshToken)
286293
var req *http.Request
287-
req, err = newRequest("POST", p.RefreshURL.String(), bytes.NewBufferString(params.Encode()))
294+
req, err = p.newRequest("POST", p.RefreshURL.String(), bytes.NewBufferString(params.Encode()))
288295
if err != nil {
289296
return
290297
}
@@ -329,7 +336,7 @@ func (p *SSOProvider) ValidateSessionState(s *SessionState, allowedGroups []stri
329336
// we validate the user's access token is valid
330337
params := url.Values{}
331338
params.Add("client_id", p.ClientID)
332-
req, err := newRequest("GET", fmt.Sprintf("%s?%s", p.ValidateURL.String(), params.Encode()), nil)
339+
req, err := p.newRequest("GET", fmt.Sprintf("%s?%s", p.ValidateURL.String(), params.Encode()), nil)
333340
if err != nil {
334341
logger.WithUser(s.Email).Error(err, "error validating session state")
335342
return false

0 commit comments

Comments
 (0)