Skip to content

Commit d1d937a

Browse files
test: added get headers builtin tests (#434)
* test: added get headers builtin tests * refactor: comments+indent
1 parent f197630 commit d1d937a

File tree

2 files changed

+104
-11
lines changed

2 files changed

+104
-11
lines changed

custom_builtins/get_header.go

+13-11
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,17 @@ var GetHeaderFunction = rego.Function2(
4040
Name: GetHeaderDecl.Name,
4141
Decl: GetHeaderDecl.Decl,
4242
},
43-
func(_ rego.BuiltinContext, a, b *ast.Term) (*ast.Term, error) {
44-
var headerKey string
45-
var headers http.Header
46-
if err := ast.As(a.Value, &headerKey); err != nil {
47-
return nil, err
48-
}
49-
if err := ast.As(b.Value, &headers); err != nil {
50-
return nil, err
51-
}
52-
return ast.StringTerm(headers.Get(headerKey)), nil
53-
},
43+
getHeaderDefinition,
5444
)
45+
46+
func getHeaderDefinition(_ rego.BuiltinContext, a, b *ast.Term) (*ast.Term, error) {
47+
var headerKey string
48+
var headers http.Header
49+
if err := ast.As(a.Value, &headerKey); err != nil {
50+
return nil, err
51+
}
52+
if err := ast.As(b.Value, &headers); err != nil {
53+
return nil, err
54+
}
55+
return ast.StringTerm(headers.Get(headerKey)), nil
56+
}

custom_builtins/get_header_test.go

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright 2025 Mia srl
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package custom_builtins
16+
17+
import (
18+
"encoding/json"
19+
"net/textproto"
20+
"testing"
21+
22+
"github.com/open-policy-agent/opa/ast"
23+
"github.com/open-policy-agent/opa/rego"
24+
"github.com/stretchr/testify/require"
25+
)
26+
27+
func TestGetHeader(t *testing.T) {
28+
29+
buildMapTerm := func(t *testing.T, data map[string][]string) *ast.Term {
30+
t.Helper()
31+
32+
terms := make([][2]*ast.Term, 0, len(data))
33+
34+
for key, values := range data {
35+
36+
valueTermArray := make([]*ast.Term, len(values))
37+
for i, val := range values {
38+
valueTermArray[i] = ast.StringTerm(val)
39+
}
40+
41+
terms = append(terms, [2]*ast.Term{
42+
// NOTE: textproto.CanonicalMIMEHeaderKey is used to canonicalize
43+
// the key in the same way the net/http package does.
44+
ast.StringTerm(textproto.CanonicalMIMEHeaderKey(key)),
45+
{Value: ast.NewArray(valueTermArray...)},
46+
})
47+
}
48+
49+
return &ast.Term{Value: ast.NewObject(terms...)}
50+
}
51+
52+
t.Run("GetHeader", func(t *testing.T) {
53+
t.Run("returns correct value when header exists", func(t *testing.T) {
54+
foundTerm, err := getHeaderDefinition(
55+
rego.BuiltinContext{},
56+
ast.StringTerm("X-My-Header"), // search with canonicalized key
57+
buildMapTerm(t, map[string][]string{"X-My-Header": {"value"}}),
58+
)
59+
require.NoError(t, err)
60+
require.Equal(t, ast.StringTerm("value"), foundTerm)
61+
})
62+
63+
t.Run("returns correct value when header exists case-insensitive", func(t *testing.T) {
64+
foundTerm, err := getHeaderDefinition(
65+
rego.BuiltinContext{},
66+
ast.StringTerm("x-my-header"), // search with lower case header key
67+
buildMapTerm(t, map[string][]string{"X-My-Header": {"value"}}),
68+
)
69+
require.NoError(t, err)
70+
require.Equal(t, ast.StringTerm("value"), foundTerm)
71+
})
72+
73+
t.Run("returns error on invalid key", func(t *testing.T) {
74+
_, err := getHeaderDefinition(
75+
rego.BuiltinContext{},
76+
ast.NumberTerm(json.Number("42")),
77+
buildMapTerm(t, map[string][]string{}),
78+
)
79+
require.Error(t, err)
80+
})
81+
82+
t.Run("returns error on invalid headers map", func(t *testing.T) {
83+
_, err := getHeaderDefinition(
84+
rego.BuiltinContext{},
85+
ast.StringTerm("x-my-header"),
86+
ast.BooleanTerm(false),
87+
)
88+
require.Error(t, err)
89+
})
90+
})
91+
}

0 commit comments

Comments
 (0)