Skip to content

Commit 696ae86

Browse files
author
Chris Gilmer
committed
Move shared code and tests for that code
1 parent 544bee4 commit 696ae86

File tree

4 files changed

+203
-147
lines changed

4 files changed

+203
-147
lines changed

cmd/common.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"os"
7+
8+
"github.com/99designs/aws-vault/prompt"
9+
"github.com/99designs/aws-vault/vault"
10+
"github.com/99designs/keyring"
11+
"github.com/aws/aws-sdk-go/aws/endpoints"
12+
"github.com/pkg/browser"
13+
"github.com/skip2/go-qrcode"
14+
"gopkg.in/ini.v1"
15+
)
16+
17+
func promptMFAtoken(messagePrefix string, logger *log.Logger) string {
18+
var token string
19+
for attempts := maxMFATokenPromptAttempts; token == "" && attempts > 0; attempts-- {
20+
t, err := prompt.TerminalPrompt(fmt.Sprintf("%sMFA token (%d attempts remaining): ", messagePrefix, attempts))
21+
if err != nil {
22+
logger.Println(err)
23+
continue
24+
}
25+
err = validate.Var(t, "numeric,len=6")
26+
if err != nil {
27+
logger.Println("MFA token must be 6 digits. Please try again.")
28+
continue
29+
}
30+
token = t
31+
}
32+
return token
33+
}
34+
35+
func getMFATokenPair(logger *log.Logger) MFATokenPair {
36+
var mfaTokenPair MFATokenPair
37+
for attempts := maxMFATokenPromptAttempts; attempts > 0; attempts-- {
38+
logger.Printf("Two unique MFA tokens needed to activate MFA device (%d attempts remaining)\n", attempts)
39+
authToken1 := promptMFAtoken("First ", logger)
40+
authToken2 := promptMFAtoken("Second ", logger)
41+
42+
mfaTokenPair = MFATokenPair{
43+
Token1: authToken1,
44+
Token2: authToken2,
45+
}
46+
err := validate.Struct(mfaTokenPair)
47+
if err != nil {
48+
logger.Println(err)
49+
} else {
50+
break
51+
}
52+
}
53+
return mfaTokenPair
54+
}
55+
56+
// UpdateAWSProfile updates or creates a single AWS profile to the AWS config file
57+
func UpdateAWSProfile(iniFile *ini.File, profile *vault.ProfileSection, sourceProfile *string, output string, logger *log.Logger) error {
58+
logger.Printf("Adding the profile %q to the AWS config file", profile.Name)
59+
60+
sectionName := fmt.Sprintf("profile %s", profile.Name)
61+
62+
// Get or create section before updating
63+
var err error
64+
var section *ini.Section
65+
section = iniFile.Section(sectionName)
66+
if section == nil {
67+
section, err = iniFile.NewSection(sectionName)
68+
if err != nil {
69+
return fmt.Errorf("error creating section %q: %w", profile.Name, err)
70+
}
71+
}
72+
73+
// Add the source profile when provided
74+
if sourceProfile != nil {
75+
_, err = section.NewKey("source_profile", *sourceProfile)
76+
if err != nil {
77+
return fmt.Errorf("unable to add source profile: %w", err)
78+
}
79+
}
80+
81+
if err = section.ReflectFrom(&profile); err != nil {
82+
return fmt.Errorf("error mapping profile to ini file: %w", err)
83+
}
84+
_, err = section.NewKey("output", output)
85+
if err != nil {
86+
return fmt.Errorf("unable to add output key: %w", err)
87+
}
88+
return nil
89+
}
90+
91+
func getKeyring(keychainName string) (*keyring.Keyring, error) {
92+
ring, err := keyring.Open(keyring.Config{
93+
ServiceName: "aws-vault",
94+
AllowedBackends: []keyring.BackendType{
95+
keyring.KeychainBackend,
96+
keyring.FileBackend,
97+
},
98+
KeychainName: keychainName,
99+
KeychainTrustApplication: true,
100+
})
101+
if err != nil {
102+
return nil, fmt.Errorf("error opening keyring: %w", err)
103+
}
104+
105+
return &ring, nil
106+
}
107+
108+
func deleteSession(profile string, keyring *keyring.Keyring, logger *log.Logger) error {
109+
credsKeyring := vault.CredentialKeyring{Keyring: *keyring}
110+
sessions := credsKeyring.Sessions()
111+
112+
if n, _ := sessions.Delete(profile); n > 0 {
113+
logger.Printf("Deleted %d existing sessions.\n", n)
114+
}
115+
116+
return nil
117+
}
118+
119+
func generateQrCode(payload string, tempFile *os.File) error {
120+
// Creates QR Code
121+
q, err := qrcode.New(payload, qrcode.Medium)
122+
if err != nil {
123+
return fmt.Errorf("unable to create qr code: %w", err)
124+
}
125+
126+
// Generates a QR PNG 256 x 256, returns []byte
127+
qr, err := q.PNG(256)
128+
if err != nil {
129+
return fmt.Errorf("unable to generate PNG: %w", err)
130+
}
131+
132+
// Write the QR PNG to the Temp File
133+
if _, err := tempFile.Write(qr); err != nil {
134+
_ = tempFile.Close()
135+
return err
136+
}
137+
return nil
138+
}
139+
140+
func openQrCode(tempFile *os.File) error {
141+
err := browser.OpenFile(tempFile.Name())
142+
if err != nil {
143+
return fmt.Errorf("unable to open QR Code PNG: %w", err)
144+
}
145+
146+
if err := tempFile.Close(); err != nil {
147+
return fmt.Errorf("unable to close QR Code: %w", err)
148+
}
149+
return nil
150+
}
151+
152+
func getPartition(region string) (string, error) {
153+
partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), region)
154+
if !ok {
155+
return "", fmt.Errorf("Error finding partition for region: %s", region)
156+
}
157+
return partition.ID(), nil
158+
}

cmd/common_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package main
2+
3+
import (
4+
"io/ioutil"
5+
"os"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func newConfigFile(t *testing.T, b []byte) string {
12+
f, err := ioutil.TempFile("", "aws-config")
13+
if err != nil {
14+
t.Fatal(err)
15+
}
16+
if err := ioutil.WriteFile(f.Name(), b, 0600); err != nil {
17+
t.Fatal(err)
18+
}
19+
return f.Name()
20+
}
21+
22+
func TestGenerateQrCode(t *testing.T) {
23+
tempFile, err := ioutil.TempFile("", "temp-qr.*.png")
24+
assert.NoError(t, err)
25+
defer func() {
26+
errRemove := os.Remove(tempFile.Name())
27+
assert.NoError(t, errRemove)
28+
}()
29+
30+
err = generateQrCode("otpauth://totp/super@top?secret=secret", tempFile)
31+
assert.NoError(t, err)
32+
}
33+
34+
func TestGetPartition(t *testing.T) {
35+
commPartition, err := getPartition("us-west-2")
36+
assert.Equal(t, commPartition, "aws")
37+
assert.NoError(t, err)
38+
39+
govPartition, err := getPartition("us-gov-west-1")
40+
assert.Equal(t, govPartition, "aws-us-gov")
41+
assert.NoError(t, err)
42+
43+
_, err = getPartition("aws-under-the-sea")
44+
assert.Error(t, err)
45+
}

cmd/setup.go

Lines changed: 0 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ import (
1717
"github.com/aws/aws-sdk-go/aws/session"
1818
"github.com/aws/aws-sdk-go/service/iam"
1919
"github.com/aws/aws-sdk-go/service/sts"
20-
"github.com/pkg/browser"
21-
"github.com/skip2/go-qrcode"
2220
"github.com/spf13/cobra"
2321
"github.com/spf13/pflag"
2422
"github.com/spf13/viper"
@@ -300,45 +298,6 @@ func (sc *SetupConfig) CreateVirtualMFADevice() error {
300298
return nil
301299
}
302300

303-
func promptMFAtoken(messagePrefix string, logger *log.Logger) string {
304-
var token string
305-
for attempts := maxMFATokenPromptAttempts; token == "" && attempts > 0; attempts-- {
306-
t, err := prompt.TerminalPrompt(fmt.Sprintf("%sMFA token (%d attempts remaining): ", messagePrefix, attempts))
307-
if err != nil {
308-
logger.Println(err)
309-
continue
310-
}
311-
err = validate.Var(t, "numeric,len=6")
312-
if err != nil {
313-
logger.Println("MFA token must be 6 digits. Please try again.")
314-
continue
315-
}
316-
token = t
317-
}
318-
return token
319-
}
320-
321-
func getMFATokenPair(logger *log.Logger) MFATokenPair {
322-
var mfaTokenPair MFATokenPair
323-
for attempts := maxMFATokenPromptAttempts; attempts > 0; attempts-- {
324-
logger.Printf("Two unique MFA tokens needed to activate MFA device (%d attempts remaining)\n", attempts)
325-
authToken1 := promptMFAtoken("First ", logger)
326-
authToken2 := promptMFAtoken("Second ", logger)
327-
328-
mfaTokenPair = MFATokenPair{
329-
Token1: authToken1,
330-
Token2: authToken2,
331-
}
332-
err := validate.Struct(mfaTokenPair)
333-
if err != nil {
334-
logger.Println(err)
335-
} else {
336-
break
337-
}
338-
}
339-
return mfaTokenPair
340-
}
341-
342301
// EnableVirtualMFADevice enables the user's MFA device
343302
func (sc *SetupConfig) EnableVirtualMFADevice() error {
344303
sc.Logger.Println("Enabling the virtual MFA device")
@@ -536,75 +495,6 @@ func (sc *SetupConfig) RemoveVaultSession() error {
536495
return nil
537496
}
538497

539-
func getKeyring(keychainName string) (*keyring.Keyring, error) {
540-
ring, err := keyring.Open(keyring.Config{
541-
ServiceName: "aws-vault",
542-
AllowedBackends: []keyring.BackendType{
543-
keyring.KeychainBackend,
544-
keyring.FileBackend,
545-
},
546-
KeychainName: keychainName,
547-
KeychainTrustApplication: true,
548-
})
549-
if err != nil {
550-
return nil, fmt.Errorf("error opening keyring: %w", err)
551-
}
552-
553-
return &ring, nil
554-
}
555-
556-
func deleteSession(profile string, keyring *keyring.Keyring, logger *log.Logger) error {
557-
credsKeyring := vault.CredentialKeyring{Keyring: *keyring}
558-
sessions := credsKeyring.Sessions()
559-
560-
if n, _ := sessions.Delete(profile); n > 0 {
561-
logger.Printf("Deleted %d existing sessions.\n", n)
562-
}
563-
564-
return nil
565-
}
566-
567-
func generateQrCode(payload string, tempFile *os.File) error {
568-
// Creates QR Code
569-
q, err := qrcode.New(payload, qrcode.Medium)
570-
if err != nil {
571-
return fmt.Errorf("unable to create qr code: %w", err)
572-
}
573-
574-
// Generates a QR PNG 256 x 256, returns []byte
575-
qr, err := q.PNG(256)
576-
if err != nil {
577-
return fmt.Errorf("unable to generate PNG: %w", err)
578-
}
579-
580-
// Write the QR PNG to the Temp File
581-
if _, err := tempFile.Write(qr); err != nil {
582-
_ = tempFile.Close()
583-
return err
584-
}
585-
return nil
586-
}
587-
588-
func openQrCode(tempFile *os.File) error {
589-
err := browser.OpenFile(tempFile.Name())
590-
if err != nil {
591-
return fmt.Errorf("unable to open QR Code PNG: %w", err)
592-
}
593-
594-
if err := tempFile.Close(); err != nil {
595-
return fmt.Errorf("unable to close QR Code: %w", err)
596-
}
597-
return nil
598-
}
599-
600-
func getPartition(region string) (string, error) {
601-
partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), region)
602-
if !ok {
603-
return "", fmt.Errorf("Error finding partition for region: %s", region)
604-
}
605-
return partition.ID(), nil
606-
}
607-
608498
func setupUserFunction(cmd *cobra.Command, args []string) error {
609499
defer func() {
610500
if r := recover(); r != nil {

cmd/setup_test.go

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package main
22

33
import (
4-
"io/ioutil"
54
"log"
65
"os"
76
"testing"
@@ -10,17 +9,6 @@ import (
109
"github.com/stretchr/testify/assert"
1110
)
1211

13-
func newConfigFile(t *testing.T, b []byte) string {
14-
f, err := ioutil.TempFile("", "aws-config")
15-
if err != nil {
16-
t.Fatal(err)
17-
}
18-
if err := ioutil.WriteFile(f.Name(), b, 0600); err != nil {
19-
t.Fatal(err)
20-
}
21-
return f.Name()
22-
}
23-
2412
func TestUpdateAWSConfigFile(t *testing.T) {
2513

2614
// Test logger
@@ -85,28 +73,3 @@ output=json
8573
assert.Equal(t, testSection.Region, "us-west-2")
8674
// assert.Equal(t, testBaseSection.Output, "json")
8775
}
88-
89-
func TestGenerateQrCode(t *testing.T) {
90-
tempFile, err := ioutil.TempFile("", "temp-qr.*.png")
91-
assert.NoError(t, err)
92-
defer func() {
93-
errRemove := os.Remove(tempFile.Name())
94-
assert.NoError(t, errRemove)
95-
}()
96-
97-
err = generateQrCode("otpauth://totp/super@top?secret=secret", tempFile)
98-
assert.NoError(t, err)
99-
}
100-
101-
func TestGetPartition(t *testing.T) {
102-
commPartition, err := getPartition("us-west-2")
103-
assert.Equal(t, commPartition, "aws")
104-
assert.NoError(t, err)
105-
106-
govPartition, err := getPartition("us-gov-west-1")
107-
assert.Equal(t, govPartition, "aws-us-gov")
108-
assert.NoError(t, err)
109-
110-
_, err = getPartition("aws-under-the-sea")
111-
assert.Error(t, err)
112-
}

0 commit comments

Comments
 (0)