@@ -18,10 +18,15 @@ limitations under the License.
18
18
package modelSource
19
19
20
20
import (
21
+ "strconv"
22
+ "strings"
21
23
"testing"
22
24
25
+ "github.com/google/go-cmp/cmp"
26
+ "github.com/google/go-cmp/cmp/cmpopts"
23
27
"github.com/stretchr/testify/assert"
24
28
corev1 "k8s.io/api/core/v1"
29
+ "k8s.io/utils/ptr"
25
30
26
31
"github.com/inftyai/llmaz/pkg"
27
32
)
@@ -32,12 +37,11 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
32
37
allowPatterns := []string {"*.gguf" , "*.json" }
33
38
ignorePatterns := []string {"*.tmp" }
34
39
35
- testCases := []struct {
36
- name string
37
- provider * ModelHubProvider
38
- index int
39
- expectMainModel bool
40
- expectEnvContains []string
40
+ tests := []struct {
41
+ name string
42
+ provider * ModelHubProvider
43
+ index int
44
+ expectMainModel bool
41
45
}{
42
46
{
43
47
name : "inject full modelhub with fileName, revision, allow/ignore" ,
@@ -52,11 +56,6 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
52
56
},
53
57
index : 0 ,
54
58
expectMainModel : true ,
55
- expectEnvContains : []string {
56
- "MODEL_SOURCE_TYPE" , "MODEL_ID" , "MODEL_HUB_NAME" , "MODEL_FILENAME" ,
57
- "REVISION" , "MODEL_ALLOW_PATTERNS" , "MODEL_IGNORE_PATTERNS" ,
58
- "HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" ,
59
- },
60
59
},
61
60
{
62
61
name : "inject with index > 0 skips volume/container mount" ,
@@ -67,15 +66,16 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
67
66
},
68
67
index : 1 ,
69
68
expectMainModel : false ,
70
- expectEnvContains : []string {
71
- "MODEL_SOURCE_TYPE" , "MODEL_ID" , "MODEL_HUB_NAME" ,
72
- "HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" ,
73
- },
74
69
},
75
70
}
76
71
77
- for _ , tc := range testCases {
78
- t .Run (tc .name , func (t * testing.T ) {
72
+ envSortOpt := cmpopts .SortSlices (func (a , b corev1.EnvVar ) bool {
73
+ return a .Name < b .Name
74
+ })
75
+
76
+ for _ , tt := range tests {
77
+ // tt := tt
78
+ t .Run (tt .name , func (t * testing.T ) {
79
79
template := & corev1.PodTemplateSpec {
80
80
Spec : corev1.PodSpec {
81
81
Containers : []corev1.Container {
@@ -89,57 +89,94 @@ func Test_ModelHubProvider_InjectModelLoader(t *testing.T) {
89
89
},
90
90
}
91
91
92
- tc .provider .InjectModelLoader (template , tc .index )
92
+ tt .provider .InjectModelLoader (template , tt .index )
93
93
94
94
assert .Len (t , template .Spec .InitContainers , 1 )
95
95
initContainer := template .Spec .InitContainers [0 ]
96
+
96
97
expectedName := MODEL_LOADER_CONTAINER_NAME
97
- if tc .index != 0 {
98
- expectedName += "-" + string ( rune ( '0' + tc .index ) )
98
+ if tt .index != 0 {
99
+ expectedName += "-" + strconv . Itoa ( tt .index )
99
100
}
100
101
assert .Equal (t , expectedName , initContainer .Name )
101
102
assert .Equal (t , pkg .LOADER_IMAGE , initContainer .Image )
102
103
103
- // Check env vars exist
104
- for _ , key := range tc .expectEnvContains {
105
- found := false
106
- for _ , env := range initContainer .Env {
107
- if env .Name == key {
108
- found = true
109
- break
110
- }
111
- }
112
- assert .True (t , found , "expected env %s not found" , key )
104
+ wantEnv := buildExpectedEnv (tt .provider )
105
+ if diff := cmp .Diff (wantEnv , initContainer .Env , envSortOpt ); diff != "" {
106
+ t .Errorf ("InitContainer.Env mismatch (-want +got):\n %s" , diff )
113
107
}
114
108
115
- // Main model should inject volume & container mount
116
- if tc .expectMainModel {
117
- // Volume should be present
118
- foundVol := false
119
- for _ , v := range template .Spec .Volumes {
120
- if v .Name == MODEL_VOLUME_NAME {
121
- foundVol = true
122
- break
123
- }
124
- }
125
- assert .True (t , foundVol , "volume not injected" )
126
-
127
- // Runner container mount should exist
128
- foundMount := false
129
- for _ , m := range template .Spec .Containers [0 ].VolumeMounts {
130
- if m .Name == MODEL_VOLUME_NAME && m .ReadOnly && m .MountPath == CONTAINER_MODEL_PATH {
131
- foundMount = true
132
- }
133
- }
134
- assert .True (t , foundMount , "volume mount not injected to runner" )
109
+ if tt .expectMainModel {
110
+ assert .True (t , hasVolume (template .Spec .Volumes , MODEL_VOLUME_NAME ), "model volume missing" )
111
+ assert .True (t , hasMount (template .Spec .Containers [0 ].VolumeMounts , MODEL_VOLUME_NAME ), "runner volumeMount missing" )
135
112
} else {
136
- // No volumes or mounts should be injected
137
- assert .Empty (t , template .Spec .Volumes )
138
- assert .Empty (t , template .Spec .Containers [0 ].VolumeMounts )
113
+ assert .Empty (t , template .Spec .Volumes , "unexpected volumes for sub-model" )
114
+ assert .Empty (t , template .Spec .Containers [0 ].VolumeMounts , "unexpected mounts for sub-model" )
139
115
}
116
+ })
117
+ }
118
+ }
119
+
120
+ func buildExpectedEnv (p * ModelHubProvider ) []corev1.EnvVar {
121
+ envs := make ([]corev1.EnvVar , 0 , 10 )
122
+
123
+ envs = append (envs , corev1.EnvVar {Name : "HTTP_PROXY" , Value : "http://1.1.1.1" })
124
+
125
+ envs = append (envs ,
126
+ corev1.EnvVar {Name : "MODEL_SOURCE_TYPE" , Value : MODEL_SOURCE_MODELHUB },
127
+ corev1.EnvVar {Name : "MODEL_ID" , Value : p .modelID },
128
+ corev1.EnvVar {Name : "MODEL_HUB_NAME" , Value : p .modelHub },
129
+ )
130
+
131
+ if p .fileName != nil {
132
+ envs = append (envs , corev1.EnvVar {Name : "MODEL_FILENAME" , Value : * p .fileName })
133
+ }
134
+ if p .modelRevision != nil {
135
+ envs = append (envs , corev1.EnvVar {Name : "REVISION" , Value : * p .modelRevision })
136
+ }
137
+ if p .modelAllowPatterns != nil {
138
+ envs = append (envs , corev1.EnvVar {
139
+ Name : "MODEL_ALLOW_PATTERNS" ,
140
+ Value : strings .Join (p .modelAllowPatterns , "," ),
141
+ })
142
+ }
143
+ if p .modelIgnorePatterns != nil {
144
+ envs = append (envs , corev1.EnvVar {
145
+ Name : "MODEL_IGNORE_PATTERNS" ,
146
+ Value : strings .Join (p .modelIgnorePatterns , "," ),
147
+ })
148
+ }
140
149
141
- // Should always carry over container env
142
- assert .Contains (t , initContainer .Env , corev1.EnvVar {Name : "HTTP_PROXY" , Value : "http://1.1.1.1" })
150
+ for _ , tokenName := range []string {"HUGGING_FACE_HUB_TOKEN" , "HF_TOKEN" } {
151
+ envs = append (envs , corev1.EnvVar {
152
+ Name : tokenName ,
153
+ ValueFrom : & corev1.EnvVarSource {
154
+ SecretKeyRef : & corev1.SecretKeySelector {
155
+ LocalObjectReference : corev1.LocalObjectReference {Name : MODELHUB_SECRET_NAME },
156
+ Key : HUGGINGFACE_TOKEN_KEY ,
157
+ Optional : ptr .To (true ),
158
+ },
159
+ },
143
160
})
144
161
}
162
+
163
+ return envs
164
+ }
165
+
166
+ func hasVolume (vols []corev1.Volume , name string ) bool {
167
+ for _ , v := range vols {
168
+ if v .Name == name {
169
+ return true
170
+ }
171
+ }
172
+ return false
173
+ }
174
+
175
+ func hasMount (mounts []corev1.VolumeMount , name string ) bool {
176
+ for _ , m := range mounts {
177
+ if m .Name == name && m .ReadOnly && m .MountPath == CONTAINER_MODEL_PATH {
178
+ return true
179
+ }
180
+ }
181
+ return false
145
182
}
0 commit comments