Skip to content

Commit c767e4a

Browse files
authored
feat: support OAuth client credentials grant (#115)
1 parent bbb0d20 commit c767e4a

File tree

5 files changed

+166
-32
lines changed

5 files changed

+166
-32
lines changed

pkg/app/app.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"github.com/snyk/go-application-framework/internal/api"
1313
"github.com/snyk/go-application-framework/internal/constants"
1414
"github.com/snyk/go-application-framework/internal/utils"
15-
"github.com/snyk/go-application-framework/pkg/auth"
1615
"github.com/snyk/go-application-framework/pkg/configuration"
1716
localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows"
1817
"github.com/snyk/go-application-framework/pkg/workflow"
@@ -88,7 +87,7 @@ func initConfiguration(engine workflow.Engine, config configuration.Configuratio
8887

8988
config.AddDefaultValue(configuration.FF_OAUTH_AUTH_FLOW_ENABLED, func(existingValue any) any {
9089
if existingValue == nil {
91-
return auth.IsKnownOAuthEndpoint(config.GetString(configuration.API_URL))
90+
return true
9291
} else {
9392
return existingValue
9493
}

pkg/app/app_test.go

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -157,24 +157,6 @@ func Test_CreateAppEngineWithRuntimeInfo(t *testing.T) {
157157
assert.Equal(t, ri, engine.GetRuntimeInfo())
158158
}
159159

160-
func Test_initConfiguration_existingValueOfOAuthFFRespected(t *testing.T) {
161-
existingValue := false
162-
endpoint := "https://snykgov.io"
163-
164-
// setup mock
165-
ctrl := gomock.NewController(t)
166-
mockApiClient := mocks.NewMockApiClient(ctrl)
167-
168-
config := configuration.NewInMemory()
169-
initConfiguration(workflow.NewWorkFlowEngine(config), config, mockApiClient, &zlog.Logger)
170-
171-
config.Set(configuration.FF_OAUTH_AUTH_FLOW_ENABLED, existingValue)
172-
config.Set(configuration.API_URL, endpoint)
173-
174-
actualOAuthFF := config.GetBool(configuration.FF_OAUTH_AUTH_FLOW_ENABLED)
175-
assert.Equal(t, existingValue, actualOAuthFF)
176-
}
177-
178160
func Test_initConfiguration_snykgov(t *testing.T) {
179161
endpoint := "https://snykgov.io"
180162

@@ -206,9 +188,6 @@ func Test_initConfiguration_NOT_snykgov(t *testing.T) {
206188

207189
config.Set(configuration.API_URL, endpoint)
208190

209-
actualOAuthFF := config.GetBool(configuration.FF_OAUTH_AUTH_FLOW_ENABLED)
210-
assert.False(t, actualOAuthFF)
211-
212191
isFedramp := config.GetBool(configuration.IS_FEDRAMP)
213192
assert.False(t, isFedramp)
214193
}

pkg/auth/oauth2authenticator.go

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,27 @@ import (
1818

1919
"github.com/pkg/browser"
2020
"golang.org/x/oauth2"
21+
"golang.org/x/oauth2/clientcredentials"
2122

2223
"github.com/snyk/go-application-framework/pkg/configuration"
2324
)
2425

2526
const (
26-
CONFIG_KEY_OAUTH_TOKEN string = "INTERNAL_OAUTH_TOKEN_STORAGE"
27-
OAUTH_CLIENT_ID string = "b56d4c2e-b9e1-4d27-8773-ad47eafb0956"
28-
CALLBACK_HOSTNAME string = "127.0.0.1"
29-
CALLBACK_PATH string = "/authorization-code/callback"
30-
TIMEOUT_SECONDS time.Duration = 120 * time.Second
31-
AUTHENTICATED_MESSAGE = "Your account has been authenticated."
27+
CONFIG_KEY_OAUTH_TOKEN string = "INTERNAL_OAUTH_TOKEN_STORAGE"
28+
OAUTH_CLIENT_ID string = "b56d4c2e-b9e1-4d27-8773-ad47eafb0956"
29+
CALLBACK_HOSTNAME string = "127.0.0.1"
30+
CALLBACK_PATH string = "/authorization-code/callback"
31+
TIMEOUT_SECONDS time.Duration = 120 * time.Second
32+
AUTHENTICATED_MESSAGE = "Your account has been authenticated."
33+
PARAMETER_CLIENT_ID string = "client-id"
34+
PARAMETER_CLIENT_SECRET string = "client-secret"
35+
)
36+
37+
type GrantType int
38+
39+
const (
40+
ClientCredentialsGrant GrantType = iota
41+
AuthorizationCodeGrant
3242
)
3343

3444
var _ Authenticator = (*oAuth2Authenticator)(nil)
@@ -42,6 +52,7 @@ type oAuth2Authenticator struct {
4252
oauthConfig *oauth2.Config
4353
token *oauth2.Token
4454
headless bool
55+
grantType GrantType
4556
openBrowserFunc func(authUrl string)
4657
shutdownServerFunc func(server *http.Server)
4758
tokenRefresherFunc func(ctx context.Context, oauthConfig *oauth2.Config, token *oauth2.Token) (*oauth2.Token, error)
@@ -89,6 +100,21 @@ func getOAuthConfiguration(config configuration.Configuration) *oauth2.Config {
89100
AuthURL: authUrl,
90101
},
91102
}
103+
104+
if determineGrantType(config) == ClientCredentialsGrant {
105+
conf.ClientID = config.GetString(PARAMETER_CLIENT_ID)
106+
conf.ClientSecret = config.GetString(PARAMETER_CLIENT_SECRET)
107+
}
108+
109+
return conf
110+
}
111+
112+
func getOAuthConfigurationClientCredentials(in *oauth2.Config) *clientcredentials.Config {
113+
conf := &clientcredentials.Config{
114+
ClientID: in.ClientID,
115+
ClientSecret: in.ClientSecret,
116+
TokenURL: in.Endpoint.TokenURL,
117+
}
92118
return conf
93119
}
94120

@@ -138,6 +164,20 @@ func RefreshToken(ctx context.Context, oauthConfig *oauth2.Config, token *oauth2
138164
return tokenSource.Token()
139165
}
140166

167+
func refreshTokenClientCredentials(ctx context.Context, oauthConfig *oauth2.Config, token *oauth2.Token) (*oauth2.Token, error) {
168+
conf := getOAuthConfigurationClientCredentials(oauthConfig)
169+
tokenSource := conf.TokenSource(ctx)
170+
return tokenSource.Token()
171+
}
172+
173+
func determineGrantType(config configuration.Configuration) GrantType {
174+
grantType := AuthorizationCodeGrant
175+
if config.IsSet(PARAMETER_CLIENT_SECRET) && config.IsSet(PARAMETER_CLIENT_ID) {
176+
grantType = ClientCredentialsGrant
177+
}
178+
return grantType
179+
}
180+
141181
//goland:noinspection GoUnusedExportedFunction
142182
func NewOAuth2Authenticator(config configuration.Configuration, httpClient *http.Client) Authenticator {
143183
return NewOAuth2AuthenticatorWithOpts(config, WithHttpClient(httpClient))
@@ -154,7 +194,14 @@ func NewOAuth2AuthenticatorWithOpts(config configuration.Configuration, opts ...
154194
o.httpClient = http.DefaultClient
155195
o.openBrowserFunc = OpenBrowser
156196
o.shutdownServerFunc = ShutdownServer
157-
o.tokenRefresherFunc = RefreshToken
197+
o.grantType = determineGrantType(config)
198+
199+
// set refresh function depending on grant type
200+
if o.grantType == ClientCredentialsGrant {
201+
o.tokenRefresherFunc = refreshTokenClientCredentials
202+
} else {
203+
o.tokenRefresherFunc = RefreshToken
204+
}
158205

159206
// apply options
160207
for _, opt := range opts {
@@ -193,6 +240,37 @@ func (o *oAuth2Authenticator) persistToken(token *oauth2.Token) {
193240
}
194241

195242
func (o *oAuth2Authenticator) Authenticate() error {
243+
var err error
244+
245+
if o.grantType == ClientCredentialsGrant {
246+
err = o.authenticateWithClientCredentialsGrant()
247+
} else {
248+
err = o.authenticateWithAuthorizationCode()
249+
}
250+
251+
return err
252+
}
253+
254+
func (o *oAuth2Authenticator) authenticateWithClientCredentialsGrant() error {
255+
ctx := context.Background()
256+
config := getOAuthConfigurationClientCredentials(o.oauthConfig)
257+
258+
// Use the custom HTTP client when requesting a token.
259+
if o.httpClient != nil {
260+
ctx = context.WithValue(ctx, oauth2.HTTPClient, o.httpClient)
261+
}
262+
263+
// get token
264+
token, err := config.Token(ctx)
265+
if err != nil {
266+
return err
267+
}
268+
269+
o.persistToken(token)
270+
return err
271+
}
272+
273+
func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error {
196274
var responseCode string
197275
var responseState string
198276
var responseError string

pkg/auth/oauth2authenticator_test.go

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ import (
77
"testing"
88
"time"
99

10-
"github.com/snyk/go-application-framework/pkg/configuration"
1110
"github.com/stretchr/testify/assert"
1211
"golang.org/x/oauth2"
12+
13+
"github.com/snyk/go-application-framework/pkg/configuration"
1314
)
1415

1516
func Test_GetVerifier(t *testing.T) {
@@ -205,3 +206,77 @@ func Test_AddAuthenticationHeader_expiredToken_somebodyUpdated(t *testing.T) {
205206
assert.Equal(t, *newToken, *actualToken)
206207
assert.Equal(t, *newToken, *authenticator.(*oAuth2Authenticator).token)
207208
}
209+
210+
func Test_determineGrantType_empty(t *testing.T) {
211+
config := configuration.NewInMemory()
212+
expected := AuthorizationCodeGrant
213+
actual := determineGrantType(config)
214+
assert.Equal(t, expected, actual)
215+
}
216+
217+
func Test_determineGrantType_secret_only(t *testing.T) {
218+
config := configuration.NewInMemory()
219+
config.Set(PARAMETER_CLIENT_SECRET, "secret")
220+
expected := AuthorizationCodeGrant
221+
actual := determineGrantType(config)
222+
assert.Equal(t, expected, actual)
223+
}
224+
225+
func Test_determineGrantType_id_only(t *testing.T) {
226+
config := configuration.NewInMemory()
227+
config.Set(PARAMETER_CLIENT_ID, "id")
228+
expected := AuthorizationCodeGrant
229+
actual := determineGrantType(config)
230+
assert.Equal(t, expected, actual)
231+
}
232+
233+
func Test_determineGrantType_both(t *testing.T) {
234+
config := configuration.NewInMemory()
235+
config.Set(PARAMETER_CLIENT_ID, "id")
236+
config.Set(PARAMETER_CLIENT_SECRET, "secret")
237+
expected := ClientCredentialsGrant
238+
actual := determineGrantType(config)
239+
assert.Equal(t, expected, actual)
240+
}
241+
242+
func Test_Authenticate_CredentialsGrant(t *testing.T) {
243+
go func() {
244+
mux := http.NewServeMux()
245+
srv := &http.Server{
246+
Handler: mux,
247+
Addr: "localhost:3221",
248+
}
249+
mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
250+
newToken := &oauth2.Token{
251+
AccessToken: "a",
252+
TokenType: "b",
253+
Expiry: time.Now().Add(60 * time.Second).UTC(),
254+
}
255+
data, err := json.Marshal(newToken)
256+
assert.Nil(t, err)
257+
258+
w.Header().Set("Content-Type", "application/json;charset=UTF-8")
259+
_, err = w.Write(data)
260+
assert.Nil(t, err)
261+
})
262+
263+
timer := time.AfterFunc(3*time.Second, func() {
264+
srv.Shutdown(context.Background())
265+
})
266+
267+
srv.ListenAndServe()
268+
timer.Stop()
269+
}()
270+
271+
config := configuration.NewInMemory()
272+
config.Set(PARAMETER_CLIENT_SECRET, "secret")
273+
config.Set(PARAMETER_CLIENT_ID, "id")
274+
config.Set(configuration.API_URL, "http://localhost:3221")
275+
276+
authenticator := NewOAuth2AuthenticatorWithOpts(config, WithHttpClient(http.DefaultClient))
277+
err := authenticator.Authenticate()
278+
assert.Nil(t, err)
279+
280+
token := config.GetString(CONFIG_KEY_OAUTH_TOKEN)
281+
assert.NotEmpty(t, token)
282+
}

pkg/local_workflows/auth_workflow.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ import (
44
"fmt"
55
"os"
66

7+
"github.com/spf13/pflag"
8+
79
"github.com/snyk/go-application-framework/pkg/auth"
810
"github.com/snyk/go-application-framework/pkg/configuration"
911
"github.com/snyk/go-application-framework/pkg/workflow"
10-
"github.com/spf13/pflag"
1112
)
1213

1314
const (
@@ -40,6 +41,8 @@ func InitAuth(engine workflow.Engine) error {
4041
config := pflag.NewFlagSet(workflowNameAuth, pflag.ExitOnError)
4142
config.String(authTypeParameter, "", authTypeDescription)
4243
config.Bool(headlessFlag, false, "Enable headless OAuth authentication")
44+
config.String(auth.PARAMETER_CLIENT_SECRET, "", "Client Credential Grant, client secret")
45+
config.String(auth.PARAMETER_CLIENT_ID, "", "Client Credential Grant, client id")
4346

4447
_, err := engine.Register(WORKFLOWID_AUTH, workflow.ConfigurationOptionsFromFlagset(config), authEntryPoint)
4548
return err

0 commit comments

Comments
 (0)