Skip to content

Commit e8c1d45

Browse files
committed
refactor(misconf): make Rego scanner independent of config type
Signed-off-by: nikpivkin <[email protected]>
1 parent ffa3023 commit e8c1d45

File tree

15 files changed

+65
-134
lines changed

15 files changed

+65
-134
lines changed

pkg/iac/rego/build.go

+6-8
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,15 @@ import (
1313
"github.com/aquasecurity/trivy/pkg/iac/types"
1414
)
1515

16-
func BuildSchemaSetFromPolicies(policies map[string]*ast.Module, paths []string, fsys fs.FS, customSchemas map[string][]byte) (*ast.SchemaSet, bool, error) {
16+
func BuildSchemaSetFromPolicies(policies map[string]*ast.Module, paths []string, fsys fs.FS, customSchemas map[string][]byte) (*ast.SchemaSet, error) {
1717
schemaSet := ast.NewSchemaSet()
1818
schemaSet.Put(ast.MustParseRef("schema.input"), make(map[string]any)) // for backwards compat only
19-
var customFound bool
2019

2120
for _, policy := range policies {
2221
for _, annotation := range policy.Annotations {
2322
for _, ss := range annotation.Schemas {
2423
schemaName, err := ss.Schema.Ptr()
25-
if err != nil || schemaName == "input" {
24+
if err != nil || schemaName == "input" { // for backwards compat only
2625
continue
2726
}
2827

@@ -38,27 +37,26 @@ func BuildSchemaSetFromPolicies(policies map[string]*ast.Module, paths []string,
3837
} else {
3938
b, err := findSchemaInFS(paths, fsys, schemaName)
4039
if err != nil {
41-
return schemaSet, true, err
40+
return nil, err
4241
}
4342

4443
if b == nil {
45-
return nil, false, fmt.Errorf("could not find schema %q", schemaName)
44+
return nil, fmt.Errorf("could not find schema %q", schemaName)
4645
}
4746

4847
schema = b
4948
}
5049

5150
var rawSchema any
5251
if err := util.UnmarshalJSON(schema, &rawSchema); err != nil {
53-
return schemaSet, false, fmt.Errorf("could not parse schema %q: %w", schemaName, err)
52+
return schemaSet, fmt.Errorf("could not parse schema %q: %w", schemaName, err)
5453
}
55-
customFound = true
5654
schemaSet.Put(ss.Schema, rawSchema)
5755
}
5856
}
5957
}
6058

61-
return schemaSet, customFound, nil
59+
return schemaSet, nil
6260
}
6361

6462
// findSchemaInFS tries to find the schema anywhere in the specified FS

pkg/iac/rego/embed.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ var LoadAndRegister = sync.OnceFunc(func() {
3636
func RegisterRegoRules(modules map[string]*ast.Module) {
3737
ctx := context.TODO()
3838

39-
schemaSet, _, _ := BuildSchemaSetFromPolicies(modules, nil, nil, make(map[string][]byte))
39+
schemaSet, _ := BuildSchemaSetFromPolicies(modules, nil, nil, make(map[string][]byte))
4040

4141
compiler := ast.NewCompiler().
4242
WithSchemas(schemaSet).

pkg/iac/rego/load.go

+2-22
Original file line numberDiff line numberDiff line change
@@ -233,13 +233,10 @@ func (s *Scanner) prunePoliciesWithError(compiler *ast.Compiler) error {
233233

234234
func (s *Scanner) compilePolicies(srcFS fs.FS, paths []string) error {
235235

236-
schemaSet, custom, err := BuildSchemaSetFromPolicies(s.policies, paths, srcFS, s.customSchemas)
236+
schemaSet, err := BuildSchemaSetFromPolicies(s.policies, paths, srcFS, s.customSchemas)
237237
if err != nil {
238238
return err
239239
}
240-
if custom {
241-
s.inputSchema = nil // discard auto detected input schema in favor of check defined schema
242-
}
243240

244241
compiler := ast.NewCompiler().
245242
WithUseTypeCheckAnnotations(true).
@@ -259,18 +256,6 @@ func (s *Scanner) compilePolicies(srcFS fs.FS, paths []string) error {
259256
if err := s.filterModules(retriever); err != nil {
260257
return err
261258
}
262-
if s.inputSchema != nil {
263-
schemaSet := ast.NewSchemaSet()
264-
schemaSet.Put(ast.MustParseRef("schema.input"), s.inputSchema)
265-
compiler.WithSchemas(schemaSet)
266-
compiler.Compile(s.policies)
267-
if compiler.Failed() {
268-
if err := s.prunePoliciesWithError(compiler); err != nil {
269-
return err
270-
}
271-
return s.compilePolicies(srcFS, paths)
272-
}
273-
}
274259
s.compiler = compiler
275260
s.retriever = retriever
276261
return nil
@@ -307,12 +292,7 @@ func (s *Scanner) filterModules(retriever *MetadataRetriever) error {
307292
continue
308293
}
309294

310-
for _, selector := range meta.InputOptions.Selectors {
311-
if selector.Type == string(s.sourceType) {
312-
filtered[name] = module
313-
break
314-
}
315-
}
295+
filtered[name] = module
316296
}
317297

318298
s.policies = filtered

pkg/iac/rego/load_test.go

-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515

1616
checks "github.com/aquasecurity/trivy-checks"
1717
"github.com/aquasecurity/trivy/pkg/iac/rego"
18-
"github.com/aquasecurity/trivy/pkg/iac/types"
1918
"github.com/aquasecurity/trivy/pkg/log"
2019
)
2120

@@ -30,7 +29,6 @@ func Test_RegoScanning_WithSomeInvalidPolicies(t *testing.T) {
3029
var debugBuf bytes.Buffer
3130
slog.SetDefault(log.New(log.NewHandler(&debugBuf, nil)))
3231
scanner := rego.NewScanner(
33-
types.SourceDockerfile,
3432
rego.WithRegoErrorLimits(0),
3533
rego.WithPolicyDirs("."),
3634
)
@@ -44,7 +42,6 @@ func Test_RegoScanning_WithSomeInvalidPolicies(t *testing.T) {
4442
var debugBuf bytes.Buffer
4543
slog.SetDefault(log.New(log.NewHandler(&debugBuf, nil)))
4644
scanner := rego.NewScanner(
47-
types.SourceDockerfile,
4845
rego.WithRegoErrorLimits(1),
4946
rego.WithPolicyDirs("."),
5047
)
@@ -65,7 +62,6 @@ deny {
6562
input.evil == "foo bar"
6663
}`
6764
scanner := rego.NewScanner(
68-
types.SourceJSON,
6965
rego.WithPolicyDirs("."),
7066
rego.WithPolicyReader(strings.NewReader(check)),
7167
)
@@ -84,7 +80,6 @@ deny {
8480
input.evil == "foo bar"
8581
}`
8682
scanner := rego.NewScanner(
87-
types.SourceJSON,
8883
rego.WithPolicyDirs("."),
8984
rego.WithPolicyReader(strings.NewReader(check)),
9085
)
@@ -106,14 +101,12 @@ deny {
106101
input.evil == "foo bar"
107102
}`
108103
scanner := rego.NewScanner(
109-
types.SourceJSON,
110104
rego.WithPolicyDirs("."),
111105
rego.WithPolicyReader(strings.NewReader(check)),
112106
)
113107
err := scanner.LoadPolicies(fstest.MapFS{})
114108
require.NoError(t, err)
115109
})
116-
117110
}
118111

119112
func Test_FallbackToEmbedded(t *testing.T) {
@@ -195,7 +188,6 @@ deny {
195188
for _, tt := range tests {
196189
t.Run(tt.name, func(t *testing.T) {
197190
scanner := rego.NewScanner(
198-
types.SourceDockerfile,
199191
rego.WithRegoErrorLimits(0),
200192
rego.WithEmbeddedPolicies(false),
201193
rego.WithPolicyDirs("."),
@@ -255,7 +247,6 @@ deny {
255247
}
256248

257249
scanner := rego.NewScanner(
258-
types.SourceDockerfile,
259250
rego.WithEmbeddedPolicies(false),
260251
rego.WithPolicyDirs("."),
261252
)

pkg/iac/rego/scanner.go

+20-33
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616

1717
"github.com/aquasecurity/trivy/pkg/iac/framework"
1818
"github.com/aquasecurity/trivy/pkg/iac/providers"
19-
"github.com/aquasecurity/trivy/pkg/iac/rego/schemas"
2019
"github.com/aquasecurity/trivy/pkg/iac/scan"
2120
"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
2221
"github.com/aquasecurity/trivy/pkg/iac/types"
@@ -56,8 +55,6 @@ type Scanner struct {
5655
dataFS fs.FS
5756
dataDirs []string
5857
frameworks []framework.Framework
59-
inputSchema any // unmarshalled into this from a json schema document
60-
sourceType types.Source
6158
includeDeprecatedChecks bool
6259
includeEmbeddedPolicies bool
6360
includeEmbeddedLibraries bool
@@ -88,17 +85,11 @@ type DynamicMetadata struct {
8885
EndLine int
8986
}
9087

91-
func NewScanner(source types.Source, opts ...options.ScannerOption) *Scanner {
88+
func NewScanner(opts ...options.ScannerOption) *Scanner {
9289
LoadAndRegister()
9390

94-
schema, ok := schemas.SchemaMap[source]
95-
if !ok {
96-
schema = schemas.Anything
97-
}
98-
9991
s := &Scanner{
10092
regoErrorLimit: ast.CompileErrorLimitDefault,
101-
sourceType: source,
10293
ruleNamespaces: builtinNamespaces.Clone(),
10394
runtimeValues: addRuntimeValues(),
10495
logger: log.WithPrefix("rego"),
@@ -109,12 +100,6 @@ func NewScanner(source types.Source, opts ...options.ScannerOption) *Scanner {
109100
for _, opt := range opts {
110101
opt(s)
111102
}
112-
if schema != schemas.None {
113-
err := json.Unmarshal([]byte(schema), &s.inputSchema)
114-
if err != nil {
115-
panic(err)
116-
}
117-
}
118103
return s
119104
}
120105

@@ -130,12 +115,6 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, d
130115
rego.Trace(trace),
131116
}
132117

133-
if s.inputSchema != nil {
134-
schemaSet := ast.NewSchemaSet()
135-
schemaSet.Put(ast.MustParseRef("schema.input"), s.inputSchema)
136-
regoOptions = append(regoOptions, rego.Schemas(schemaSet))
137-
}
138-
139118
if input != nil {
140119
regoOptions = append(regoOptions, rego.ParsedInput(input))
141120
}
@@ -176,7 +155,7 @@ func GetInputsContents(inputs []Input) []any {
176155
return results
177156
}
178157

179-
func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results, error) {
158+
func (s *Scanner) ScanInput(ctx context.Context, sourceType types.Source, inputs ...Input) (scan.Results, error) {
180159

181160
s.logger.Debug("Scanning inputs", "count", len(inputs))
182161

@@ -210,11 +189,9 @@ func (s *Scanner) ScanInput(ctx context.Context, inputs ...Input) (scan.Results,
210189
continue // skip deprecated checks
211190
}
212191

213-
if isPolicyWithSubtype(s.sourceType) {
214-
// skip if check isn't relevant to what is being scanned
215-
if !isPolicyApplicable(staticMeta, inputs...) {
216-
continue
217-
}
192+
// skip if check isn't relevant to what is being scanned
193+
if !isPolicyApplicable(sourceType, staticMeta, inputs...) {
194+
continue
218195
}
219196

220197
if len(inputs) == 0 {
@@ -279,18 +256,28 @@ func checkSubtype(ii map[string]any, provider string, subTypes []SubType) bool {
279256
return false
280257
}
281258

282-
func isPolicyApplicable(staticMetadata *StaticMetadata, inputs ...Input) bool {
259+
func isPolicyApplicable(sourceType types.Source, staticMetadata *StaticMetadata, inputs ...Input) bool {
260+
if len(staticMetadata.InputOptions.Selectors) == 0 { // check always applies if no selectors
261+
return true
262+
}
263+
264+
for _, selector := range staticMetadata.InputOptions.Selectors {
265+
if selector.Type != string(sourceType) {
266+
return false
267+
}
268+
}
269+
270+
if !isPolicyWithSubtype(sourceType) {
271+
return true
272+
}
273+
283274
for _, input := range inputs {
284275
if ii, ok := input.Contents.(map[string]any); ok {
285276
for provider := range ii {
286277
if !supportedProviders.Contains(provider) {
287278
continue
288279
}
289280

290-
if len(staticMetadata.InputOptions.Selectors) == 0 { // check always applies if no selectors
291-
return true
292-
}
293-
294281
// check metadata for subtype
295282
for _, s := range staticMetadata.InputOptions.Selectors {
296283
if checkSubtype(ii, provider, s.Subtypes) {

0 commit comments

Comments
 (0)