@@ -18,10 +18,15 @@ limitations under the License.
18
18
package modelSource
19
19
20
20
import (
21
+ "strconv"
22
+ "strings"
21
23
"testing"
22
24
25
+ "github.com/google/go-cmp/cmp"
26
+ "github.com/google/go-cmp/cmp/cmpopts"
23
27
"github.com/stretchr/testify/assert"
24
28
corev1 "k8s.io/api/core/v1"
29
+ "k8s.io/utils/ptr"
25
30
26
31
"github.com/inftyai/llmaz/pkg"
27
32
)
@@ -32,12 +37,11 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
32
37
allowPatterns := []string {"*.gguf" , "*.json" }
33
38
ignorePatterns := []string {"*.tmp" }
34
39
35
- testCases := []struct {
36
- name string
37
- provider * ModelHubProvider
38
- index int
39
- expectMainModel bool
40
- expectEnvContains []string
40
+ tests := []struct {
41
+ name string
42
+ provider * ModelHubProvider
43
+ index int
44
+ expectMainModel bool
41
45
}{
42
46
{
43
47
name : "inject full modelhub with fileName, revision, allow/ignore" ,
@@ -52,11 +56,6 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
52
56
},
53
57
index : 0 ,
54
58
expectMainModel : true ,
55
- expectEnvContains : []string {
56
- "MODEL_SOURCE_TYPE" , "MODEL_ID" , "MODEL_HUB_NAME" , "MODEL_FILENAME" ,
57
- "REVISION" , "MODEL_ALLOW_PATTERNS" , "MODEL_IGNORE_PATTERNS" ,
58
- "HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" ,
59
- },
60
59
},
61
60
{
62
61
name : "inject with index > 0 skips volume/container mount" ,
@@ -67,15 +66,15 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
67
66
},
68
67
index : 1 ,
69
68
expectMainModel : false ,
70
- expectEnvContains : []string {
71
- "MODEL_SOURCE_TYPE" , "MODEL_ID" , "MODEL_HUB_NAME" ,
72
- "HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" ,
73
- },
74
69
},
75
70
}
76
71
77
- for _ , tc := range testCases {
78
- t .Run (tc .name , func (t * testing.T ) {
72
+ envSortOpt := cmpopts .SortSlices (func (a , b corev1.EnvVar ) bool {
73
+ return a .Name < b .Name
74
+ })
75
+
76
+ for _ , tt := range tests {
77
+ t .Run (tt .name , func (t * testing.T ) {
79
78
template := & corev1.PodTemplateSpec {
80
79
Spec : corev1.PodSpec {
81
80
Containers : []corev1.Container {
@@ -89,57 +88,68 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
89
88
},
90
89
}
91
90
92
- tc .provider .InjectModelLoader (template , tc .index )
91
+ tt .provider .InjectModelLoader (template , tt .index )
93
92
94
93
assert .Len (t , template .Spec .InitContainers , 1 )
95
94
initContainer := template .Spec .InitContainers [0 ]
95
+
96
96
expectedName := MODEL_LOADER_CONTAINER_NAME
97
- if tc .index != 0 {
98
- expectedName += "-" + string ( rune ( '0' + tc .index ) )
97
+ if tt .index != 0 {
98
+ expectedName += "-" + strconv . Itoa ( tt .index )
99
99
}
100
100
assert .Equal (t , expectedName , initContainer .Name )
101
101
assert .Equal (t , pkg .LOADER_IMAGE , initContainer .Image )
102
102
103
- // Check env vars exist
104
- for _ , key := range tc .expectEnvContains {
105
- found := false
106
- for _ , env := range initContainer .Env {
107
- if env .Name == key {
108
- found = true
109
- break
110
- }
111
- }
112
- assert .True (t , found , "expected env %s not found" , key )
103
+ wantEnv := buildExpectedEnv (tt .provider )
104
+ if diff := cmp .Diff (wantEnv , initContainer .Env , envSortOpt ); diff != "" {
105
+ t .Errorf ("InitContainer.Env mismatch (-want +got):\n %s" , diff )
113
106
}
107
+ })
108
+ }
109
+ }
114
110
115
- // Main model should inject volume & container mount
116
- if tc .expectMainModel {
117
- // Volume should be present
118
- foundVol := false
119
- for _ , v := range template .Spec .Volumes {
120
- if v .Name == MODEL_VOLUME_NAME {
121
- foundVol = true
122
- break
123
- }
124
- }
125
- assert .True (t , foundVol , "volume not injected" )
126
-
127
- // Runner container mount should exist
128
- foundMount := false
129
- for _ , m := range template .Spec .Containers [0 ].VolumeMounts {
130
- if m .Name == MODEL_VOLUME_NAME && m .ReadOnly && m .MountPath == CONTAINER_MODEL_PATH {
131
- foundMount = true
132
- }
133
- }
134
- assert .True (t , foundMount , "volume mount not injected to runner" )
135
- } else {
136
- // No volumes or mounts should be injected
137
- assert .Empty (t , template .Spec .Volumes )
138
- assert .Empty (t , template .Spec .Containers [0 ].VolumeMounts )
139
- }
111
+ func buildExpectedEnv (p * ModelHubProvider ) []corev1.EnvVar {
112
+ envs := make ([]corev1.EnvVar , 0 , 10 )
113
+
114
+ envs = append (envs , corev1.EnvVar {Name : "HTTP_PROXY" , Value : "http://1.1.1.1" })
115
+
116
+ envs = append (envs ,
117
+ corev1.EnvVar {Name : "MODEL_SOURCE_TYPE" , Value : MODEL_SOURCE_MODELHUB },
118
+ corev1.EnvVar {Name : "MODEL_ID" , Value : p .modelID },
119
+ corev1.EnvVar {Name : "MODEL_HUB_NAME" , Value : p .modelHub },
120
+ )
140
121
141
- // Should always carry over container env
142
- assert .Contains (t , initContainer .Env , corev1.EnvVar {Name : "HTTP_PROXY" , Value : "http://1.1.1.1" })
122
+ if p .fileName != nil {
123
+ envs = append (envs , corev1.EnvVar {Name : "MODEL_FILENAME" , Value : * p .fileName })
124
+ }
125
+ if p .modelRevision != nil {
126
+ envs = append (envs , corev1.EnvVar {Name : "REVISION" , Value : * p .modelRevision })
127
+ }
128
+ if p .modelAllowPatterns != nil {
129
+ envs = append (envs , corev1.EnvVar {
130
+ Name : "MODEL_ALLOW_PATTERNS" ,
131
+ Value : strings .Join (p .modelAllowPatterns , "," ),
132
+ })
133
+ }
134
+ if p .modelIgnorePatterns != nil {
135
+ envs = append (envs , corev1.EnvVar {
136
+ Name : "MODEL_IGNORE_PATTERNS" ,
137
+ Value : strings .Join (p .modelIgnorePatterns , "," ),
138
+ })
139
+ }
140
+
141
+ for _ , tokenName := range []string {"HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" } {
142
+ envs = append (envs , corev1.EnvVar {
143
+ Name : tokenName ,
144
+ ValueFrom : & corev1.EnvVarSource {
145
+ SecretKeyRef : & corev1.SecretKeySelector {
146
+ LocalObjectReference : corev1.LocalObjectReference {Name : MODELHUB_SECRET_NAME },
147
+ Key : HUGGING_FACE_TOKEN_KEY ,
148
+ Optional : ptr .To (true ),
149
+ },
150
+ },
143
151
})
144
152
}
153
+
154
+ return envs
145
155
}
0 commit comments