|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "regexp" |
| 6 | + |
| 7 | + "github.com/aws/aws-sdk-go/aws/endpoints" |
| 8 | + "github.com/spf13/viper" |
| 9 | +) |
| 10 | + |
| 11 | +const ( |
| 12 | + // AWSRegionFlag is the generic AWS Region Flag |
| 13 | + AWSRegionFlag string = "aws-region" |
| 14 | + // AWSAccountIDFlag is the AWS AccountID Flag |
| 15 | + AWSAccountIDFlag string = "aws-account-id" |
| 16 | + |
| 17 | + // VaultAWSKeychainNameFlag is the aws-vault keychain name Flag |
| 18 | + VaultAWSKeychainNameFlag string = "aws-vault-keychain-name" |
| 19 | + // VaultAWSKeychainNameDefault is the aws-vault default keychain name |
| 20 | + VaultAWSKeychainNameDefault string = "login" |
| 21 | + // VaultAWSProfileFlag is the aws-vault profile name Flag |
| 22 | + VaultAWSProfileFlag string = "aws-profile" |
| 23 | + |
| 24 | + // IAMUserFlag is the IAM User name Flag |
| 25 | + IAMUserFlag string = "iam-user" |
| 26 | + // IAMRoleFlag is the IAM Role name Flag |
| 27 | + IAMRoleFlag string = "iam-role" |
| 28 | + |
| 29 | + // OutputFlag is the Output Flag |
| 30 | + OutputFlag = "output" |
| 31 | + |
| 32 | + // VerboseFlag is the Verbose Flag |
| 33 | + VerboseFlag string = "verbose" |
| 34 | +) |
| 35 | + |
| 36 | +func stringSliceContains(stringSlice []string, value string) bool { |
| 37 | + for _, x := range stringSlice { |
| 38 | + if value == x { |
| 39 | + return true |
| 40 | + } |
| 41 | + } |
| 42 | + return false |
| 43 | +} |
| 44 | + |
| 45 | +type errInvalidKeychainName struct { |
| 46 | + KeychainName string |
| 47 | +} |
| 48 | + |
| 49 | +func (e *errInvalidKeychainName) Error() string { |
| 50 | + return fmt.Sprintf("invalid keychain name '%s'", e.KeychainName) |
| 51 | +} |
| 52 | + |
| 53 | +type errInvalidAWSProfile struct { |
| 54 | + Profile string |
| 55 | +} |
| 56 | + |
| 57 | +func (e *errInvalidAWSProfile) Error() string { |
| 58 | + return fmt.Sprintf("invalid aws profile '%s'", e.Profile) |
| 59 | +} |
| 60 | + |
| 61 | +type errInvalidVault struct { |
| 62 | + KeychainName string |
| 63 | + Profile string |
| 64 | +} |
| 65 | + |
| 66 | +func (e *errInvalidVault) Error() string { |
| 67 | + return fmt.Sprintf("invalid keychain name %q or profile %q", e.KeychainName, e.Profile) |
| 68 | +} |
| 69 | + |
| 70 | +func checkVault(v *viper.Viper) error { |
| 71 | + // Both keychain name and profile are required or both must be missing |
| 72 | + keychainName := v.GetString(VaultAWSKeychainNameFlag) |
| 73 | + keychainNames := []string{ |
| 74 | + VaultAWSKeychainNameDefault, |
| 75 | + } |
| 76 | + if len(keychainName) > 0 && !stringSliceContains(keychainNames, keychainName) { |
| 77 | + return fmt.Errorf("%s is invalid, expected %v: %w", VaultAWSKeychainNameFlag, keychainNames, &errInvalidKeychainName{KeychainName: keychainName}) |
| 78 | + } |
| 79 | + |
| 80 | + awsProfile := v.GetString(VaultAWSProfileFlag) |
| 81 | + if len(awsProfile) == 0 { |
| 82 | + return fmt.Errorf("%s must not be empty: %w", VaultAWSProfileFlag, &errInvalidAWSProfile{Profile: awsProfile}) |
| 83 | + } |
| 84 | + |
| 85 | + return nil |
| 86 | +} |
| 87 | + |
| 88 | +type errInvalidRegion struct { |
| 89 | + Region string |
| 90 | +} |
| 91 | + |
| 92 | +func (e *errInvalidRegion) Error() string { |
| 93 | + return fmt.Sprintf("invalid region %q", e.Region) |
| 94 | +} |
| 95 | + |
| 96 | +func checkRegion(v *viper.Viper) error { |
| 97 | + |
| 98 | + r := v.GetString(AWSRegionFlag) |
| 99 | + if _, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), r); !ok { |
| 100 | + return fmt.Errorf("%s is invalid: %w", AWSRegionFlag, &errInvalidRegion{Region: r}) |
| 101 | + } |
| 102 | + |
| 103 | + return nil |
| 104 | +} |
| 105 | + |
| 106 | +type errInvalidAccountID struct { |
| 107 | + AccountID string |
| 108 | +} |
| 109 | + |
| 110 | +func (e *errInvalidAccountID) Error() string { |
| 111 | + return fmt.Sprintf("invalid Account ID %q", e.AccountID) |
| 112 | +} |
| 113 | + |
| 114 | +func checkAccountID(v *viper.Viper) error { |
| 115 | + id := v.GetString(AWSAccountIDFlag) |
| 116 | + if matched, err := regexp.Match(`\d[12]`, []byte(id)); !matched || err != nil { |
| 117 | + return fmt.Errorf("%s must be a 12 digit number: %w", AWSAccountIDFlag, &errInvalidAccountID{AccountID: id}) |
| 118 | + } |
| 119 | + |
| 120 | + return nil |
| 121 | +} |
| 122 | + |
| 123 | +type errInvalidIAMUser struct { |
| 124 | + IAMUser string |
| 125 | +} |
| 126 | + |
| 127 | +func (e *errInvalidIAMUser) Error() string { |
| 128 | + return fmt.Sprintf("invalid output %q", e.IAMUser) |
| 129 | +} |
| 130 | + |
| 131 | +func checkIAMUser(v *viper.Viper) error { |
| 132 | + |
| 133 | + user := v.GetString(IAMUserFlag) |
| 134 | + if len(user) == 0 { |
| 135 | + return fmt.Errorf("%s is invalid: %w", IAMUserFlag, &errInvalidIAMUser{IAMUser: user}) |
| 136 | + } |
| 137 | + |
| 138 | + return nil |
| 139 | +} |
| 140 | + |
| 141 | +type errInvalidIAMRole struct { |
| 142 | + IAMRole string |
| 143 | +} |
| 144 | + |
| 145 | +func (e *errInvalidIAMRole) Error() string { |
| 146 | + return fmt.Sprintf("invalid output %q", e.IAMRole) |
| 147 | +} |
| 148 | + |
| 149 | +func checkIAMRole(v *viper.Viper) error { |
| 150 | + |
| 151 | + role := v.GetString(IAMRoleFlag) |
| 152 | + if len(role) == 0 { |
| 153 | + return fmt.Errorf("%s is invalid: %w", IAMRoleFlag, &errInvalidIAMRole{IAMRole: role}) |
| 154 | + } |
| 155 | + |
| 156 | + return nil |
| 157 | +} |
| 158 | + |
| 159 | +type errInvalidOutput struct { |
| 160 | + Output string |
| 161 | +} |
| 162 | + |
| 163 | +func (e *errInvalidOutput) Error() string { |
| 164 | + return fmt.Sprintf("invalid output %q", e.Output) |
| 165 | +} |
| 166 | + |
| 167 | +func checkOutput(v *viper.Viper) error { |
| 168 | + |
| 169 | + o := v.GetString(OutputFlag) |
| 170 | + outputTypes := []string{ |
| 171 | + "text", |
| 172 | + "json", |
| 173 | + "yaml", |
| 174 | + "table", |
| 175 | + } |
| 176 | + if len(o) > 0 && !stringSliceContains(outputTypes, o) { |
| 177 | + return fmt.Errorf("%s is invalid, expected one of %v: %w", OutputFlag, outputTypes, &errInvalidOutput{Output: o}) |
| 178 | + } |
| 179 | + |
| 180 | + return nil |
| 181 | +} |
0 commit comments