Skip to content

feat: support OAuth client credentials grant #115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 22, 2023
7 changes: 1 addition & 6 deletions pkg/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/snyk/go-application-framework/internal/api"
"github.com/snyk/go-application-framework/internal/constants"
"github.com/snyk/go-application-framework/internal/utils"
"github.com/snyk/go-application-framework/pkg/auth"
"github.com/snyk/go-application-framework/pkg/configuration"
localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows"
"github.com/snyk/go-application-framework/pkg/workflow"
Expand Down Expand Up @@ -87,11 +86,7 @@ func initConfiguration(engine workflow.Engine, config configuration.Configuratio
})

config.AddDefaultValue(configuration.FF_OAUTH_AUTH_FLOW_ENABLED, func(existingValue any) any {
if existingValue == nil {
return auth.IsKnownOAuthEndpoint(config.GetString(configuration.API_URL))
} else {
return existingValue
}
return true
})

config.AddDefaultValue(configuration.IS_FEDRAMP, func(existingValue any) any {
Expand Down
21 changes: 0 additions & 21 deletions pkg/app/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,24 +157,6 @@ func Test_CreateAppEngineWithRuntimeInfo(t *testing.T) {
assert.Equal(t, ri, engine.GetRuntimeInfo())
}

func Test_initConfiguration_existingValueOfOAuthFFRespected(t *testing.T) {
existingValue := false
endpoint := "https://snykgov.io"

// setup mock
ctrl := gomock.NewController(t)
mockApiClient := mocks.NewMockApiClient(ctrl)

config := configuration.NewInMemory()
initConfiguration(workflow.NewWorkFlowEngine(config), config, mockApiClient, &zlog.Logger)

config.Set(configuration.FF_OAUTH_AUTH_FLOW_ENABLED, existingValue)
config.Set(configuration.API_URL, endpoint)

actualOAuthFF := config.GetBool(configuration.FF_OAUTH_AUTH_FLOW_ENABLED)
assert.Equal(t, existingValue, actualOAuthFF)
}

func Test_initConfiguration_snykgov(t *testing.T) {
endpoint := "https://snykgov.io"

Expand Down Expand Up @@ -206,9 +188,6 @@ func Test_initConfiguration_NOT_snykgov(t *testing.T) {

config.Set(configuration.API_URL, endpoint)

actualOAuthFF := config.GetBool(configuration.FF_OAUTH_AUTH_FLOW_ENABLED)
assert.False(t, actualOAuthFF)

isFedramp := config.GetBool(configuration.IS_FEDRAMP)
assert.False(t, isFedramp)
}
91 changes: 84 additions & 7 deletions pkg/auth/oauth2authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,27 @@ import (

"github.com/pkg/browser"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"

"github.com/snyk/go-application-framework/pkg/configuration"
)

const (
CONFIG_KEY_OAUTH_TOKEN string = "INTERNAL_OAUTH_TOKEN_STORAGE"
OAUTH_CLIENT_ID string = "b56d4c2e-b9e1-4d27-8773-ad47eafb0956"
CALLBACK_HOSTNAME string = "127.0.0.1"
CALLBACK_PATH string = "/authorization-code/callback"
TIMEOUT_SECONDS time.Duration = 120 * time.Second
AUTHENTICATED_MESSAGE = "Your account has been authenticated."
CONFIG_KEY_OAUTH_TOKEN string = "INTERNAL_OAUTH_TOKEN_STORAGE"
OAUTH_CLIENT_ID string = "b56d4c2e-b9e1-4d27-8773-ad47eafb0956"
CALLBACK_HOSTNAME string = "127.0.0.1"
CALLBACK_PATH string = "/authorization-code/callback"
TIMEOUT_SECONDS time.Duration = 120 * time.Second
AUTHENTICATED_MESSAGE = "Your account has been authenticated."
PARAMETER_CLIENT_ID string = "client-id"
PARAMETER_CLIENT_SECRET string = "client-secret"
)

type GrantType int

const (
ClientCredentialsGrant GrantType = iota
AuthorizationCodeGrant
)

var _ Authenticator = (*oAuth2Authenticator)(nil)
Expand All @@ -42,6 +52,7 @@ type oAuth2Authenticator struct {
oauthConfig *oauth2.Config
token *oauth2.Token
headless bool
grantType GrantType
openBrowserFunc func(authUrl string)
shutdownServerFunc func(server *http.Server)
tokenRefresherFunc func(ctx context.Context, oauthConfig *oauth2.Config, token *oauth2.Token) (*oauth2.Token, error)
Expand Down Expand Up @@ -89,6 +100,21 @@ func getOAuthConfiguration(config configuration.Configuration) *oauth2.Config {
AuthURL: authUrl,
},
}

if determineGrantType(config) == ClientCredentialsGrant {
conf.ClientID = config.GetString(PARAMETER_CLIENT_ID)
conf.ClientSecret = config.GetString(PARAMETER_CLIENT_SECRET)
}

return conf
}

func getOAuthConfigurationClientCredentials(in *oauth2.Config) *clientcredentials.Config {
conf := &clientcredentials.Config{
ClientID: in.ClientID,
ClientSecret: in.ClientSecret,
TokenURL: in.Endpoint.TokenURL,
}
return conf
}

Expand Down Expand Up @@ -138,6 +164,20 @@ func RefreshToken(ctx context.Context, oauthConfig *oauth2.Config, token *oauth2
return tokenSource.Token()
}

func refreshTokenClientCredentials(ctx context.Context, oauthConfig *oauth2.Config, token *oauth2.Token) (*oauth2.Token, error) {
conf := getOAuthConfigurationClientCredentials(oauthConfig)
tokenSource := conf.TokenSource(ctx)
return tokenSource.Token()
}

func determineGrantType(config configuration.Configuration) GrantType {
grantType := AuthorizationCodeGrant
if config.IsSet(PARAMETER_CLIENT_SECRET) && config.IsSet(PARAMETER_CLIENT_ID) {
grantType = ClientCredentialsGrant
}
return grantType
}

//goland:noinspection GoUnusedExportedFunction
func NewOAuth2Authenticator(config configuration.Configuration, httpClient *http.Client) Authenticator {
return NewOAuth2AuthenticatorWithOpts(config, WithHttpClient(httpClient))
Expand All @@ -154,7 +194,14 @@ func NewOAuth2AuthenticatorWithOpts(config configuration.Configuration, opts ...
o.httpClient = http.DefaultClient
o.openBrowserFunc = OpenBrowser
o.shutdownServerFunc = ShutdownServer
o.tokenRefresherFunc = RefreshToken
o.grantType = determineGrantType(config)

// set refresh function depending on grant type
if o.grantType == ClientCredentialsGrant {
o.tokenRefresherFunc = refreshTokenClientCredentials
} else {
o.tokenRefresherFunc = RefreshToken
}

// apply options
for _, opt := range opts {
Expand Down Expand Up @@ -193,6 +240,36 @@ func (o *oAuth2Authenticator) persistToken(token *oauth2.Token) {
}

func (o *oAuth2Authenticator) Authenticate() error {
var err error

if o.grantType == ClientCredentialsGrant {
err = o.authenticateWithClientCredentialsGrant()
} else {
err = o.authenticateWithAuthorizationCode()
}

return err
}

func (o *oAuth2Authenticator) authenticateWithClientCredentialsGrant() error {
ctx := context.Background()
config := getOAuthConfigurationClientCredentials(o.oauthConfig)

// Use the custom HTTP client when requesting a token.
if o.httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, o.httpClient)
}

// get token
token, err := config.Token(ctx)
if err == nil {
o.persistToken(token)
}

return err
}

func (o *oAuth2Authenticator) authenticateWithAuthorizationCode() error {
var responseCode string
var responseState string
var responseError string
Expand Down
62 changes: 61 additions & 1 deletion pkg/auth/oauth2authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import (
"testing"
"time"

"github.com/snyk/go-application-framework/pkg/configuration"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"

"github.com/snyk/go-application-framework/pkg/configuration"
)

func Test_GetVerifier(t *testing.T) {
Expand Down Expand Up @@ -205,3 +206,62 @@ func Test_AddAuthenticationHeader_expiredToken_somebodyUpdated(t *testing.T) {
assert.Equal(t, *newToken, *actualToken)
assert.Equal(t, *newToken, *authenticator.(*oAuth2Authenticator).token)
}

func Test_determineGrantType(t *testing.T) {
config := configuration.NewInMemory()
expected := AuthorizationCodeGrant
actual := determineGrantType(config)
assert.Equal(t, expected, actual)

config.Set(PARAMETER_CLIENT_SECRET, "secret")
expected = AuthorizationCodeGrant
actual = determineGrantType(config)
assert.Equal(t, expected, actual)

config.Set(PARAMETER_CLIENT_ID, "id")
expected = ClientCredentialsGrant
actual = determineGrantType(config)
assert.Equal(t, expected, actual)
}

func Test_Authenticate_CredentialsGrant(t *testing.T) {
go func() {
mux := http.NewServeMux()
srv := &http.Server{
Handler: mux,
Addr: "localhost:3221",
}
mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
newToken := &oauth2.Token{
AccessToken: "a",
TokenType: "b",
Expiry: time.Now().Add(60 * time.Second).UTC(),
}
data, err := json.Marshal(newToken)
assert.Nil(t, err)

w.Header().Set("Content-Type", "application/json;charset=UTF-8")
_, err = w.Write(data)
assert.Nil(t, err)
})

timer := time.AfterFunc(3*time.Second, func() {
srv.Shutdown(context.Background())
})

srv.ListenAndServe()
timer.Stop()
}()

config := configuration.NewInMemory()
config.Set(PARAMETER_CLIENT_SECRET, "secret")
config.Set(PARAMETER_CLIENT_ID, "id")
config.Set(configuration.API_URL, "http://localhost:3221")

authenticator := NewOAuth2AuthenticatorWithOpts(config, WithHttpClient(http.DefaultClient))
err := authenticator.Authenticate()
assert.Nil(t, err)

token := config.GetString(CONFIG_KEY_OAUTH_TOKEN)
assert.NotEmpty(t, token)
}
5 changes: 4 additions & 1 deletion pkg/local_workflows/auth_workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"fmt"
"os"

"github.com/spf13/pflag"

"github.com/snyk/go-application-framework/pkg/auth"
"github.com/snyk/go-application-framework/pkg/configuration"
"github.com/snyk/go-application-framework/pkg/workflow"
"github.com/spf13/pflag"
)

const (
Expand Down Expand Up @@ -40,6 +41,8 @@ func InitAuth(engine workflow.Engine) error {
config := pflag.NewFlagSet(workflowNameAuth, pflag.ExitOnError)
config.String(authTypeParameter, "", authTypeDescription)
config.Bool(headlessFlag, false, "Enable headless OAuth authentication")
config.String(auth.PARAMETER_CLIENT_SECRET, "", "Client Credential Grant, client secret")
config.String(auth.PARAMETER_CLIENT_ID, "", "Client Credential Grant, client id")

_, err := engine.Register(WORKFLOWID_AUTH, workflow.ConfigurationOptionsFromFlagset(config), authEntryPoint)
return err
Expand Down