Skip to content

Commit 08a0f72

Browse files
author
Chris Gilmer
committed
Add tests for cli
1 parent 7ba7728 commit 08a0f72

File tree

9 files changed

+223
-50
lines changed

9 files changed

+223
-50
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
bin/
22
dist/
3+
coverage.out

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ bin/setup-new-aws-user: ## Build setup-new-aws-user
1919
test:
2020
go test -v ./cmd/...
2121

22+
.PHONY: test_coverage
23+
test_coverage:
24+
go test -v -coverprofile=coverage.out -covermode=count ./cmd/...
25+
go tool cover -html=coverage.out
26+
2227
.PHONY: clean
2328
clean:
2429
rm -f .*.stamp

cmd/add_profile.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ import (
1616
"gopkg.in/ini.v1"
1717
)
1818

19-
func addProfileInitFlags(flag *pflag.FlagSet) {
19+
// AddProfileInitFlags sets up the CLI flags for the 'add-profile' subcommand
20+
func AddProfileInitFlags(flag *pflag.FlagSet) {
2021

2122
flag.String(VaultAWSKeychainNameFlag, VaultAWSKeychainNameDefault, "The aws-vault keychain name")
22-
flag.String(VaultAWSProfileFlag, "", "The aws-vault profile name")
23-
flag.String(VaultAWSNewProfileFlag, "", "A comma separated list of new AWS 'PROFILE1:ACCOUNTID1,PROFILE2:ACCOUNTID2,...'")
23+
flag.StringSlice(AWSProfileAccountFlag, []string{}, "A comma separated list of AWS profiles and account IDs 'PROFILE1:ACCOUNTID1,PROFILE2:ACCOUNTID2,...'")
2424
flag.String(AWSRegionFlag, endpoints.UsWest2RegionID, "The AWS region")
2525
flag.String(IAMUserFlag, "", "The IAM user name to setup")
2626
flag.String(IAMRoleFlag, "", "The IAM role name assigned to the user being setup")
@@ -32,7 +32,8 @@ func addProfileInitFlags(flag *pflag.FlagSet) {
3232
flag.SortFlags = false
3333
}
3434

35-
func addProfileCheckConfig(v *viper.Viper) error {
35+
// AddProfileCheckConfig checks the CLI flag configuration for the 'add-profile' subcommand
36+
func AddProfileCheckConfig(v *viper.Viper) error {
3637

3738
if err := checkVault(v); err != nil {
3839
return fmt.Errorf("aws-vault check failed: %w", err)
@@ -84,7 +85,7 @@ func (sc *SetupConfig) AddProfile() error {
8485
}
8586
mfaSerial := mfaSerialKey.String()
8687

87-
for _, element := range *sc.NewProfiles {
88+
for _, element := range sc.NewProfiles {
8889
profileName := strings.Split(element, ":")[0]
8990
awsAccountID := strings.Split(element, ":")[1]
9091

@@ -145,16 +146,15 @@ func addProfileFunction(cmd *cobra.Command, args []string) error {
145146
}
146147

147148
// Check the config and exit with usage details if there is a problem
148-
checkConfigErr := addProfileCheckConfig(v)
149+
checkConfigErr := AddProfileCheckConfig(v)
149150
if checkConfigErr != nil {
150151
return checkConfigErr
151152
}
152153

153154
// Get command line flag values
154155
awsRegion := v.GetString(AWSRegionFlag)
155156
awsVaultKeychainName := v.GetString(VaultAWSKeychainNameFlag)
156-
awsVaultProfile := v.GetString(VaultAWSProfileFlag)
157-
awsVaultNewProfile := v.GetStringSlice(VaultAWSNewProfileFlag)
157+
awsVaultProfileAccount := v.GetStringSlice(AWSProfileAccountFlag)
158158
iamUser := v.GetString(IAMUserFlag)
159159
iamRole := v.GetString(IAMRoleFlag)
160160
output := v.GetString(OutputFlag)
@@ -184,8 +184,8 @@ func addProfileFunction(cmd *cobra.Command, args []string) error {
184184
Role: iamRole,
185185
Region: awsRegion,
186186
Partition: partition,
187-
RoleProfileName: &awsVaultProfile,
188-
NewProfiles: &awsVaultNewProfile,
187+
RoleProfileName: &awsVaultProfileAccount[0],
188+
NewProfiles: awsVaultProfileAccount[1:],
189189
Output: output,
190190
Config: config,
191191
Keyring: keyring,

cmd/cli.go

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package main
33
import (
44
"fmt"
55
"regexp"
6+
"strings"
67

78
"github.com/aws/aws-sdk-go/aws/endpoints"
89
"github.com/spf13/viper"
@@ -20,10 +21,11 @@ const (
2021
VaultAWSKeychainNameDefault string = "login"
2122
// VaultAWSProfileFlag is the aws-vault profile name Flag
2223
VaultAWSProfileFlag string = "aws-profile"
23-
// VaultAWSBaseProfileFlag is the aws-vault base profile name Flag
24-
VaultAWSBaseProfileFlag string = "aws-base-profile"
25-
// VaultAWSNewProfileFlag is the aws-vault flag to add new profiles
26-
VaultAWSNewProfileFlag string = "aws-new-profile"
24+
25+
// AWSProfileAccountFlag is the combined AWS profile name and account ID Flag
26+
AWSProfileAccountFlag string = "aws-profile-account"
27+
//// AWSBaseProfileFlag is the AWS base profile name Flag
28+
//AWSBaseProfileFlag string = "aws-base-profile"
2729

2830
// IAMUserFlag is the IAM User name Flag
2931
IAMUserFlag string = "iam-user"
@@ -57,22 +59,13 @@ func (e *errInvalidKeychainName) Error() string {
5759
return fmt.Sprintf("invalid keychain name '%s'", e.KeychainName)
5860
}
5961

60-
type errInvalidAWSProfile struct {
61-
Profile string
62-
}
63-
64-
func (e *errInvalidAWSProfile) Error() string {
65-
return fmt.Sprintf("invalid aws profile '%s'", e.Profile)
66-
}
67-
68-
type errInvalidVault struct {
69-
KeychainName string
70-
Profile string
71-
}
72-
73-
func (e *errInvalidVault) Error() string {
74-
return fmt.Sprintf("invalid keychain name %q or profile %q", e.KeychainName, e.Profile)
75-
}
62+
// type errInvalidAWSProfile struct {
63+
// Profile string
64+
// }
65+
//
66+
// func (e *errInvalidAWSProfile) Error() string {
67+
// return fmt.Sprintf("invalid aws profile '%s'", e.Profile)
68+
// }
7669

7770
func checkVault(v *viper.Viper) error {
7871
// Both keychain name and profile are required or both must be missing
@@ -84,10 +77,10 @@ func checkVault(v *viper.Viper) error {
8477
return fmt.Errorf("%s is invalid, expected %v: %w", VaultAWSKeychainNameFlag, keychainNames, &errInvalidKeychainName{KeychainName: keychainName})
8578
}
8679

87-
awsProfile := v.GetString(VaultAWSProfileFlag)
88-
if len(awsProfile) == 0 {
89-
return fmt.Errorf("%s must not be empty: %w", VaultAWSProfileFlag, &errInvalidAWSProfile{Profile: awsProfile})
90-
}
80+
// awsProfile := v.GetString(VaultAWSProfileFlag)
81+
// if len(awsProfile) == 0 {
82+
// return fmt.Errorf("%s must not be empty: %w", VaultAWSProfileFlag, &errInvalidAWSProfile{Profile: awsProfile})
83+
// }
9184

9285
return nil
9386
}
@@ -100,6 +93,7 @@ func (e *errInvalidRegion) Error() string {
10093
return fmt.Sprintf("invalid region %q", e.Region)
10194
}
10295

96+
// Note: Testing the partition is not really the best check here, but its sufficient
10397
func checkRegion(v *viper.Viper) error {
10498

10599
r := v.GetString(AWSRegionFlag)
@@ -118,15 +112,60 @@ func (e *errInvalidAccountID) Error() string {
118112
return fmt.Sprintf("invalid Account ID %q", e.AccountID)
119113
}
120114

121-
func checkAccountID(v *viper.Viper) error {
122-
id := v.GetString(AWSAccountIDFlag)
115+
func checkAccountID(id string) error {
123116
if matched, err := regexp.Match(`^\d{12}$`, []byte(id)); !matched || err != nil {
124117
return fmt.Errorf("%s must be a 12 digit number: %w", AWSAccountIDFlag, &errInvalidAccountID{AccountID: id})
125118
}
126119

127120
return nil
128121
}
129122

123+
type errInvalidProfileName struct {
124+
ProfileName string
125+
}
126+
127+
func (e *errInvalidProfileName) Error() string {
128+
return fmt.Sprintf("invalid Account ID %q", e.ProfileName)
129+
}
130+
131+
func checkProfileName(profileName string) error {
132+
if matched, err := regexp.Match(`[A-Za-z0-9\-\_]+`, []byte(profileName)); !matched || err != nil {
133+
return fmt.Errorf("AWS Profile Name must be can only contain letters, numbers, hyphens, and underscores: %w", &errInvalidProfileName{ProfileName: profileName})
134+
}
135+
136+
return nil
137+
}
138+
139+
type errInvalidProfileAccount struct {
140+
ProfileAccount string
141+
}
142+
143+
func (e *errInvalidProfileAccount) Error() string {
144+
return fmt.Sprintf("invalid Profile Name and Account ID %q", e.ProfileAccount)
145+
}
146+
147+
func checkProfileAccount(v *viper.Viper) error {
148+
profileAccounts := v.GetStringSlice(AWSProfileAccountFlag)
149+
for _, profileAccount := range profileAccounts {
150+
// Validate the profile name and account are separated by a colon
151+
if !strings.Contains(profileAccount, ":") {
152+
return fmt.Errorf("Each Profile Name and Account ID must be separated by a colon ':': %w", &errInvalidProfileAccount{ProfileAccount: profileAccount})
153+
}
154+
// Split out the profile name and account ID
155+
profileAccountParts := strings.Split(profileAccount, ":")
156+
profileName := profileAccountParts[0]
157+
accountID := profileAccountParts[1]
158+
159+
if err := checkProfileName(profileName); err != nil {
160+
return err
161+
}
162+
if err := checkAccountID(accountID); err != nil {
163+
return err
164+
}
165+
}
166+
return nil
167+
}
168+
130169
type errInvalidIAMUser struct {
131170
IAMUser string
132171
}

cmd/cli_test.go

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"github.com/aws/aws-sdk-go/aws/endpoints"
1010
"github.com/spf13/viper"
11+
"github.com/stretchr/testify/assert"
1112
"github.com/stretchr/testify/suite"
1213
)
1314

@@ -95,24 +96,64 @@ func (suite *cliTestSuite) TestCheckRegion() {
9596
}
9697
}
9798

99+
func TestCheckProfileName(t *testing.T) {
100+
testValues := []string{
101+
"test-id",
102+
"test-id1",
103+
"test-id_1",
104+
}
105+
for _, testValue := range testValues {
106+
assert.NoError(t, checkProfileName(testValue))
107+
}
108+
testValuesWithErrors := []string{
109+
"",
110+
}
111+
for _, testValue := range testValuesWithErrors {
112+
assert.Error(t, checkProfileName(testValue))
113+
}
114+
}
115+
98116
func (suite *cliTestSuite) TestCheckAccountID() {
99117
suite.Setup()
100118
testValues := []string{
101119
"012345678901",
102120
"123456789012",
103121
}
104122
for _, testValue := range testValues {
105-
suite.viper.Set(AWSAccountIDFlag, testValue)
106-
suite.NoError(checkAccountID(suite.viper))
123+
suite.NoError(checkAccountID(testValue))
107124
}
108125
testValuesWithErrors := []string{
109126
"",
110127
"12345678901",
111128
"1234567890123",
112129
}
113130
for _, testValue := range testValuesWithErrors {
114-
suite.viper.Set(AWSAccountIDFlag, testValue)
115-
suite.Error(checkAccountID(suite.viper))
131+
suite.Error(checkAccountID(testValue))
132+
}
133+
}
134+
135+
func TestCheckProfileAccount(t *testing.T) {
136+
v := viper.New()
137+
138+
testValues := [][]string{
139+
{"test-id:012345678901"},
140+
{"test-id1:012345678901", "test-id2:012345678901", "test-id3:012345678901"},
141+
}
142+
for _, testValue := range testValues {
143+
v.Set(AWSProfileAccountFlag, testValue)
144+
err := checkProfileAccount(v)
145+
assert.NoError(t, err)
146+
}
147+
testValuesWithErrors := [][]string{
148+
{"test-id:0123456789011"},
149+
{":012345678901"},
150+
{"test-id:"},
151+
{"test-id012345678901"},
152+
}
153+
for _, testValue := range testValuesWithErrors {
154+
v.Set(AWSProfileAccountFlag, testValue)
155+
err := checkProfileAccount(v)
156+
assert.Error(t, err)
116157
}
117158
}
118159

cmd/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func main() {
3030
Long: "Setup new AWS user with aws-vault",
3131
RunE: setupUserFunction,
3232
}
33-
setupUserInitFlags(setupUserCommand.Flags())
33+
SetupUserInitFlags(setupUserCommand.Flags())
3434
root.AddCommand(setupUserCommand)
3535

3636
addProfileCommand := &cobra.Command{
@@ -40,7 +40,7 @@ func main() {
4040
Long: "Add new AWS config profile",
4141
RunE: addProfileFunction,
4242
}
43-
addProfileInitFlags(addProfileCommand.Flags())
43+
AddProfileInitFlags(addProfileCommand.Flags())
4444
root.AddCommand(addProfileCommand)
4545

4646
versionCommand := &cobra.Command{

0 commit comments

Comments
 (0)