Skip to content

Commit 0738475

Browse files
Add --merge-config flag to support merging with default configuration (#1075)
* add config merge support * fix indentation * Update cmd/polaris/root.go Co-authored-by: Andy Suderman <[email protected]> --------- Co-authored-by: Andy Suderman <[email protected]>
1 parent 9b5438d commit 0738475

File tree

5 files changed

+137
-17
lines changed

5 files changed

+137
-17
lines changed

cmd/polaris/root.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
)
2525

2626
var (
27+
mergeConfig bool
2728
configPath string
2829
disallowExemptions bool
2930
disallowConfigExemptions bool
@@ -42,6 +43,7 @@ var (
4243

4344
func init() {
4445
// Flags
46+
rootCmd.PersistentFlags().BoolVarP(&mergeConfig, "merge-config", "m", false, "If true, custom configuration will be merged with default configuration instead of replacing it.")
4547
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "Location of Polaris configuration file.")
4648
rootCmd.PersistentFlags().StringVarP(&kubeContext, "context", "x", "", "Set the kube context.")
4749
rootCmd.PersistentFlags().BoolVarP(&disallowExemptions, "disallow-exemptions", "", false, "Disallow any configured exemption.")
@@ -65,7 +67,7 @@ var rootCmd = &cobra.Command{
6567
logrus.SetLevel(parsedLevel)
6668
}
6769

68-
config, err = conf.ParseFile(configPath)
70+
config, err = conf.MergeConfigAndParseFile(configPath, mergeConfig)
6971
if err != nil {
7072
logrus.Errorf("Error parsing config at %s: %v", configPath, err)
7173
os.Exit(1)

pkg/config/config.go

+37-14
Original file line numberDiff line numberDiff line change
@@ -52,27 +52,50 @@ type Exemption struct {
5252
//go:embed default.yaml
5353
var defaultConfig []byte
5454

55-
// ParseFile parses config from a file.
56-
func ParseFile(path string) (Configuration, error) {
57-
var rawBytes []byte
55+
// MergeConfigAndParseFile parses config from a file.
56+
func MergeConfigAndParseFile(customConfigPath string, mergeConfig bool) (Configuration, error) {
57+
rawBytes, err := mergeConfigFile(customConfigPath, mergeConfig)
58+
if err != nil {
59+
return Configuration{}, err
60+
}
61+
62+
return Parse(rawBytes)
63+
}
64+
65+
func mergeConfigFile(customConfigPath string, mergeConfig bool) ([]byte, error) {
66+
if customConfigPath == "" {
67+
return defaultConfig, nil
68+
}
69+
70+
var customConfigContent []byte
5871
var err error
59-
if path == "" {
60-
rawBytes = defaultConfig
61-
} else if strings.HasPrefix(path, "https://") || strings.HasPrefix(path, "http://") {
72+
if strings.HasPrefix(customConfigPath, "https://") || strings.HasPrefix(customConfigPath, "http://") {
6273
// path is a url
63-
response, err2 := http.Get(path)
64-
if err2 != nil {
65-
return Configuration{}, err2
74+
response, err := http.Get(customConfigPath)
75+
if err != nil {
76+
return nil, err
77+
}
78+
customConfigContent, err = io.ReadAll(response.Body)
79+
if err != nil {
80+
return nil, err
6681
}
67-
rawBytes, err = io.ReadAll(response.Body)
6882
} else {
6983
// path is local
70-
rawBytes, err = os.ReadFile(path)
84+
customConfigContent, err = os.ReadFile(customConfigPath)
85+
if err != nil {
86+
return nil, err
87+
}
7188
}
72-
if err != nil {
73-
return Configuration{}, err
89+
90+
if mergeConfig {
91+
mergedConfig, err := mergeYaml(defaultConfig, customConfigContent)
92+
if err != nil {
93+
return nil, err
94+
}
95+
return mergedConfig, nil
7496
}
75-
return Parse(rawBytes)
97+
98+
return customConfigContent, nil
7699
}
77100

78101
// Parse parses config from a byte array.

pkg/config/config_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func TestConfigFromURL(t *testing.T) {
125125
}()
126126
time.Sleep(time.Second)
127127

128-
parsedConf, err = ParseFile("http://localhost:8081/exampleURL")
128+
parsedConf, err = MergeConfigAndParseFile("http://localhost:8081/exampleURL", false)
129129
assert.NoError(t, err, "Expected no error when parsing YAML from URL")
130130
if err := srv.Shutdown(context.TODO()); err != nil {
131131
panic(err)
@@ -136,7 +136,7 @@ func TestConfigFromURL(t *testing.T) {
136136

137137
func TestConfigNoServerError(t *testing.T) {
138138
var err error
139-
_, err = ParseFile("http://localhost:8081/exampleURL")
139+
_, err = MergeConfigAndParseFile("http://localhost:8081/exampleURL", false)
140140
assert.Error(t, err)
141141
assert.Regexp(t, regexp.MustCompile("connection refused"), err.Error())
142142
}

pkg/config/merger.go

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package config
2+
3+
import (
4+
"gopkg.in/yaml.v3" // do not change the yaml import
5+
)
6+
7+
func mergeYaml(defaultConfig, overridesConfig []byte) ([]byte, error) {
8+
var defaultData, overrideConfig map[string]any
9+
10+
err := yaml.Unmarshal([]byte(defaultConfig), &defaultData)
11+
if err != nil {
12+
return nil, err
13+
}
14+
15+
err = yaml.Unmarshal([]byte(overridesConfig), &overrideConfig)
16+
if err != nil {
17+
return nil, err
18+
}
19+
20+
mergedData := mergeYAMLMaps(defaultData, overrideConfig)
21+
22+
mergedConfig, err := yaml.Marshal(mergedData)
23+
if err != nil {
24+
return nil, err
25+
}
26+
27+
return mergedConfig, nil
28+
}
29+
30+
func mergeYAMLMaps(defaults, overrides map[string]any) map[string]any {
31+
for k, v := range overrides {
32+
if vMap, ok := v.(map[string]any); ok {
33+
// if the key exists in defaults and is a map, recursively merge
34+
if mv1, ok := defaults[k].(map[string]any); ok {
35+
defaults[k] = mergeYAMLMaps(mv1, vMap)
36+
} else {
37+
defaults[k] = vMap
38+
}
39+
} else {
40+
// add or overwrite the value in defaults
41+
defaults[k] = v
42+
}
43+
}
44+
return defaults
45+
}

pkg/config/merger_test.go

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package config
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
var defaults = `
10+
checks:
11+
deploymentMissingReplicas: warning
12+
priorityClassNotSet: warning
13+
tagNotSpecified: danger
14+
existing:
15+
sub:
16+
key: value
17+
`
18+
19+
var overrides = `
20+
checks:
21+
pullPolicyNotAlways: ignore
22+
tagNotSpecified: overrides
23+
existing:
24+
sub:
25+
key1: value1
26+
new: value
27+
new:
28+
key: value
29+
`
30+
31+
func TestMergeYaml(t *testing.T) {
32+
mergedContent, err := mergeYaml([]byte(defaults), []byte(overrides))
33+
assert.NoError(t, err)
34+
35+
expectedYAML := `checks:
36+
deploymentMissingReplicas: warning
37+
priorityClassNotSet: warning
38+
pullPolicyNotAlways: ignore
39+
tagNotSpecified: overrides
40+
existing:
41+
new: value
42+
sub:
43+
key: value
44+
key1: value1
45+
new:
46+
key: value
47+
`
48+
49+
assert.Equal(t, expectedYAML, string(mergedContent))
50+
}

0 commit comments

Comments
 (0)