diff --git a/controllers/lmes/config.go b/controllers/lmes/config.go new file mode 100644 index 00000000..53b47002 --- /dev/null +++ b/controllers/lmes/config.go @@ -0,0 +1,108 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package lmes + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "time" + + "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" +) + +var options *serviceOptions = &serviceOptions{ + DriverImage: DefaultDriverImage, + PodImage: DefaultPodImage, + PodCheckingInterval: DefaultPodCheckingInterval, + ImagePullPolicy: DefaultImagePullPolicy, + MaxBatchSize: DefaultMaxBatchSize, + DetectDevice: DefaultDetectDevice, + DefaultBatchSize: DefaultBatchSize, +} + +type serviceOptions struct { + PodImage string + DriverImage string + PodCheckingInterval time.Duration + ImagePullPolicy corev1.PullPolicy + MaxBatchSize int + DefaultBatchSize int + DetectDevice bool +} + +func constructOptionsFromConfigMap(log *logr.Logger, configmap *corev1.ConfigMap) error { + + rv := reflect.ValueOf(options).Elem() + var msgs []string + + for idx, cap := 0, rv.NumField(); idx < cap; idx++ { + frv := rv.Field(idx) + fname := rv.Type().Field(idx).Name + configKey, ok := optionKeys[fname] + if !ok { + continue + } + + if v, found := configmap.Data[configKey]; found { + var err error + switch frv.Type().Name() { + case "string": + frv.SetString(v) + case "bool": + val, err := strconv.ParseBool(v) + if err != nil { + val = DefaultDetectDevice + msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", optionKeys[fname], val)) + } + frv.SetBool(val) + case "int": + var intVal int + intVal, err = strconv.Atoi(v) + if err == nil { + frv.SetInt(int64(intVal)) + } + case "Duration": + var d time.Duration + d, err = time.ParseDuration(v) + if err == nil { + frv.Set(reflect.ValueOf(d)) + } + case "PullPolicy": + if p, found := pullPolicyMap[corev1.PullPolicy(v)]; found { + frv.Set(reflect.ValueOf(p)) + } else { + err = fmt.Errorf("invalid PullPolicy") + } + default: + return fmt.Errorf("can not handle the config %v, type: %v", optionKeys[fname], frv.Type().Name()) + } + + if err != nil { + msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", optionKeys[fname], v)) + } + } + } + + if len(msgs) > 0 && log != nil { + log.Error(fmt.Errorf("some settings in the configmap are invalid"), strings.Join(msgs, "\n")) + } + + return nil +} diff --git a/controllers/lmes/driver/driver_test.go b/controllers/lmes/driver/driver_test.go index b74d039d..4d0485f8 100644 --- a/controllers/lmes/driver/driver_test.go +++ b/controllers/lmes/driver/driver_test.go @@ -54,7 +54,7 @@ func genRandomSocketPath() string { return p } -func runDirverAndWait4Complete(t *testing.T, driver Driver, returnError bool) (progressMsgs []string, results string) { +func runDriverAndWait4Complete(t *testing.T, driver Driver, returnError bool) (progressMsgs []string, results string) { go func() { if returnError { assert.NotNil(t, driver.Run()) @@ -88,7 +88,7 @@ func Test_Driver(t *testing.T) { }) assert.Nil(t, err) - runDirverAndWait4Complete(t, driver, false) + runDriverAndWait4Complete(t, driver, false) assert.Nil(t, driver.Shutdown()) assert.Nil(t, os.Remove("./stderr.log")) @@ -105,7 +105,7 @@ func Test_Wait4Shutdown(t *testing.T) { }) assert.Nil(t, err) - runDirverAndWait4Complete(t, driver, false) + runDriverAndWait4Complete(t, driver, false) // can still get the status even the user program finishes time.Sleep(time.Second * 3) @@ -132,7 +132,7 @@ func Test_ProgressUpdate(t *testing.T) { }) assert.Nil(t, err) - msgs, _ := runDirverAndWait4Complete(t, driver, false) + msgs, _ := runDriverAndWait4Complete(t, driver, false) assert.Equal(t, []string{ "initializing the evaluation job", @@ -156,7 +156,7 @@ func Test_DetectDeviceError(t *testing.T) { }) assert.Nil(t, err) - msgs, _ := runDirverAndWait4Complete(t, driver, true) + msgs, _ := runDriverAndWait4Complete(t, driver, true) assert.Equal(t, []string{ "failed to detect available device(s): exit status 1", }, msgs) @@ -216,7 +216,7 @@ func Test_TaskRecipes(t *testing.T) { }) assert.Nil(t, err) - msgs, _ := runDirverAndWait4Complete(t, driver, false) + msgs, _ := runDriverAndWait4Complete(t, driver, false) assert.Equal(t, []string{ "initializing the evaluation job", @@ -264,7 +264,7 @@ func Test_CustomCards(t *testing.T) { os.Mkdir("cards", 0750) - msgs, _ := runDirverAndWait4Complete(t, driver, false) + msgs, _ := runDriverAndWait4Complete(t, driver, false) assert.Equal(t, []string{ "initializing the evaluation job", @@ -303,7 +303,7 @@ func Test_ProgramError(t *testing.T) { }) assert.Nil(t, err) - msgs, _ := runDirverAndWait4Complete(t, driver, true) + msgs, _ := runDriverAndWait4Complete(t, driver, true) assert.Equal(t, []string{ "initializing the evaluation job", diff --git a/controllers/lmes/lmevaljob_controller.go b/controllers/lmes/lmevaljob_controller.go index 32400369..8f72c593 100644 --- a/controllers/lmes/lmevaljob_controller.go +++ b/controllers/lmes/lmevaljob_controller.go @@ -21,9 +21,7 @@ import ( "context" "fmt" "maps" - "reflect" "slices" - "strconv" "strings" "sync" "time" @@ -78,7 +76,7 @@ var ( // maintain a list of key-time pair data. // provide a function to add the key and update the time -// atomitcally and return a reconcile requeue event +// atomically and return a reconcile requeue event // if needed. type syncedMap4Reconciler struct { data map[string]time.Time @@ -92,22 +90,11 @@ type LMEvalJobReconciler struct { Recorder record.EventRecorder ConfigMap string Namespace string - options *ServiceOptions restConfig *rest.Config restClient rest.Interface pullingJobs *syncedMap4Reconciler } -type ServiceOptions struct { - PodImage string - DriverImage string - PodCheckingInterval time.Duration - ImagePullPolicy corev1.PullPolicy - MaxBatchSize int - DefaultBatchSize int - DetectDevice bool -} - // The registered function to set up LMES controller func ControllerSetUp(mgr manager.Manager, ns, configmap string, recorder record.EventRecorder) error { clientset, err := kubernetes.NewForConfig(mgr.GetConfig()) @@ -189,10 +176,10 @@ func (r *LMEvalJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) ( case lmesv1alpha1.ScheduledJobState: // the job's pod has been created and the driver hasn't updated the state yet // let's check the pod status and detect pod failure if there is - // TODO: need a timeout/retry mechanism here to transite to other states + // TODO: need a timeout/retry mechanism here to transit to other states return r.checkScheduledPod(ctx, log, job) case lmesv1alpha1.RunningJobState: - // TODO: need a timeout/retry mechanism here to transite to other states + // TODO: need a timeout/retry mechanism here to transit to other states return r.checkScheduledPod(ctx, log, job) case lmesv1alpha1.CompleteJobState: return r.handleComplete(ctx, log, job) @@ -209,6 +196,7 @@ func (r *LMEvalJobReconciler) SetupWithManager(mgr ctrl.Manager) error { // Add a runnable to retrieve the settings from the specified configmap if err := mgr.Add(manager.RunnableFunc(func(ctx context.Context) error { var cm corev1.ConfigMap + log := log.FromContext(ctx) if err := r.Get( ctx, types.NamespacedName{Namespace: r.Namespace, Name: r.ConfigMap}, @@ -221,8 +209,7 @@ func (r *LMEvalJobReconciler) SetupWithManager(mgr ctrl.Manager) error { return err } - - if err := r.constructOptionsFromConfigMap(ctx, &cm); err != nil { + if err := constructOptionsFromConfigMap(&log, &cm); err != nil { return err } @@ -324,82 +311,11 @@ func (r *LMEvalJobReconciler) remoteCommand(ctx context.Context, job *lmesv1alph return outBuff.Bytes(), errBuf.Bytes(), nil } -func (r *LMEvalJobReconciler) constructOptionsFromConfigMap( - ctx context.Context, configmap *corev1.ConfigMap) error { - r.options = &ServiceOptions{ - DriverImage: DefaultDriverImage, - PodImage: DefaultPodImage, - PodCheckingInterval: DefaultPodCheckingInterval, - ImagePullPolicy: DefaultImagePullPolicy, - MaxBatchSize: DefaultMaxBatchSize, - DetectDevice: DefaultDetectDevice, - DefaultBatchSize: DefaultBatchSize, - } - - log := log.FromContext(ctx) - rv := reflect.ValueOf(r.options).Elem() - var msgs []string - - for idx, cap := 0, rv.NumField(); idx < cap; idx++ { - frv := rv.Field(idx) - fname := rv.Type().Field(idx).Name - configKey, ok := optionKeys[fname] - if !ok { - continue - } - - if v, found := configmap.Data[configKey]; found { - var err error - switch frv.Type().Name() { - case "string": - frv.SetString(v) - case "bool": - val, err := strconv.ParseBool(v) - if err != nil { - val = DefaultDetectDevice - msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", optionKeys[fname], val)) - } - frv.SetBool(val) - case "int": - var intVal int - intVal, err = strconv.Atoi(v) - if err == nil { - frv.SetInt(int64(intVal)) - } - case "Duration": - var d time.Duration - d, err = time.ParseDuration(v) - if err == nil { - frv.Set(reflect.ValueOf(d)) - } - case "PullPolicy": - if p, found := pullPolicyMap[corev1.PullPolicy(v)]; found { - frv.Set(reflect.ValueOf(p)) - } else { - err = fmt.Errorf("invalid PullPolicy") - } - default: - return fmt.Errorf("can not handle the config %v, type: %v", optionKeys[fname], frv.Type().Name()) - } - - if err != nil { - msgs = append(msgs, fmt.Sprintf("invalid setting for %v: %v, use default setting instead", optionKeys[fname], v)) - } - } - } - - if len(msgs) > 0 { - log.Error(fmt.Errorf("some settings in the configmap are invalid"), strings.Join(msgs, "\n")) - } - - return nil -} - func (r *LMEvalJobReconciler) handleDeletion(ctx context.Context, job *lmesv1alpha1.LMEvalJob, log logr.Logger) (reconcile.Result, error) { defer r.pullingJobs.remove(string(job.GetUID())) if controllerutil.ContainsFinalizer(job, lmesv1alpha1.FinalizerName) { - // delete the correspondling pod if needed + // delete the corresponding pod if needed // remove our finalizer from the list and update it. if job.Status.State != lmesv1alpha1.CompleteJobState || job.Status.Reason != lmesv1alpha1.CancelledReason { @@ -436,7 +352,7 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger, job.Name, job.Namespace)) // Since finalizers were updated. Need to fetch the new LMEvalJob - // End the current reconsile and get revisioned job in next reconsile + // End the current reconcile and get revisioned job in next reconcile return ctrl.Result{}, nil } @@ -456,7 +372,7 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger, // construct a new pod and create a pod for the job currentTime := v1.Now() - pod := r.createPod(job, log) + pod := createPod(options, job, log) if err := r.Create(ctx, pod, &client.CreateOptions{}); err != nil { // Failed to create the pod. Mark the status as complete with failed job.Status.State = lmesv1alpha1.CompleteJobState @@ -483,7 +399,7 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger, job.Namespace)) log.Info("Successfully create a Pod for the Job") // Check the pod after the config interval - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil } func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Logger, job *lmesv1alpha1.LMEvalJob) (ctrl.Result, error) { @@ -498,7 +414,7 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo log.Error(err, "unable to update LMEvalJob status", "state", job.Status.State) return ctrl.Result{}, err } - r.Recorder.Event(job, "Warning", "PodMising", + r.Recorder.Event(job, "Warning", "PodMissing", fmt.Sprintf("the pod for the LMEvalJob %s in namespace %s is gone", job.Name, job.Namespace)) @@ -508,7 +424,7 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo if mainIdx := getContainerByName(&pod.Status, "main"); mainIdx == -1 { // waiting for the main container to be up - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil } else if podFailed, msg := isContainerFailed(&pod.Status.ContainerStatuses[mainIdx]); podFailed { job.Status.State = lmesv1alpha1.CompleteJobState job.Status.Reason = lmesv1alpha1.FailedReason @@ -519,7 +435,7 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo log.Info("detect an error on the job's pod. marked the job as done", "name", job.Name) return ctrl.Result{}, err } else if pod.Status.ContainerStatuses[mainIdx].State.Running == nil { - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil } // pull status from the driver @@ -530,7 +446,7 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo if err != nil { log.Error(err, "unable to retrieve the status from the job's pod. retry after the pulling interval") } - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil } func (r *LMEvalJobReconciler) getPod(ctx context.Context, job *lmesv1alpha1.LMEvalJob) (*corev1.Pod, error) { @@ -579,7 +495,7 @@ func (r *LMEvalJobReconciler) handleComplete(ctx context.Context, log logr.Logge // send shutdown command if the main container is running if err := r.shutdownDriver(ctx, job); err != nil { log.Error(err, "failed to shutdown the job pod. retry after the pulling interval") - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), nil } } } else { @@ -617,8 +533,8 @@ func (r *LMEvalJobReconciler) handleCancel(ctx context.Context, log logr.Logger, job.Status.Reason = lmesv1alpha1.CancelledReason if err := r.deleteJobPod(ctx, job); err != nil { // leave the state as is and retry again - log.Error(err, "failed to delete pod. scheduled a retry", "interval", r.options.PodCheckingInterval.String()) - return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), err + log.Error(err, "failed to delete pod. scheduled a retry", "interval", options.PodCheckingInterval.String()) + return r.pullingJobs.addOrUpdate(string(job.GetUID()), options.PodCheckingInterval), err } } @@ -658,7 +574,7 @@ func (r *LMEvalJobReconciler) validateCustomCard(job *lmesv1alpha1.LMEvalJob, lo return nil } -func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod { +func createPod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod { var allowPrivilegeEscalation = false var runAsNonRootUser = true var ownerRefController = true @@ -712,8 +628,8 @@ func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Lo InitContainers: []corev1.Container{ { Name: "driver", - Image: r.options.DriverImage, - ImagePullPolicy: r.options.ImagePullPolicy, + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Command: []string{DriverPath, "--copy", DestDriverPath}, SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, @@ -735,11 +651,11 @@ func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Lo Containers: []corev1.Container{ { Name: "main", - Image: r.options.PodImage, - ImagePullPolicy: r.options.ImagePullPolicy, + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Env: envVars, - Command: r.generateCmd(job), - Args: r.generateArgs(job, log), + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -793,13 +709,13 @@ func getResources(resources *corev1.ResourceRequirements) *corev1.ResourceRequir // Merge the map based on the filters. If the names in the `src` map contains any prefixes // in the prefixFilters list, those KV will be discarded, otherwise, KV will be merge into // `dest` map. -func mergeMapWithFilters(dest, src map[string]string, prefixFitlers []string, log logr.Logger) { - if len(prefixFitlers) == 0 { +func mergeMapWithFilters(dest, src map[string]string, prefixFilters []string, log logr.Logger) { + if len(prefixFilters) == 0 { // Fast path if the labelFilterPrefix is empty. maps.Copy(dest, src) } else { for k, v := range src { - if slices.ContainsFunc(prefixFitlers, func(prefix string) bool { + if slices.ContainsFunc(prefixFilters, func(prefix string) bool { return strings.HasPrefix(k, prefix) }) { log.Info("the label is not propagated to the pod", k, v) @@ -810,7 +726,7 @@ func mergeMapWithFilters(dest, src map[string]string, prefixFitlers []string, lo } } -func (r *LMEvalJobReconciler) generateArgs(job *lmesv1alpha1.LMEvalJob, log logr.Logger) []string { +func generateArgs(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) []string { if job == nil { return nil } @@ -844,13 +760,13 @@ func (r *LMEvalJobReconciler) generateArgs(job *lmesv1alpha1.LMEvalJob, log logr cmds = append(cmds, "--log_samples") } // --batch_size - var batchSize = r.options.DefaultBatchSize + var batchSize = svcOpts.DefaultBatchSize if job.Spec.BatchSize != nil && *job.Spec.BatchSize > 0 { batchSize = *job.Spec.BatchSize } // This could be done in the webhook if it's enabled. - if batchSize > r.options.MaxBatchSize { - batchSize = r.options.MaxBatchSize + if batchSize > svcOpts.MaxBatchSize { + batchSize = svcOpts.MaxBatchSize log.Info("batchSize is greater than max-batch-size of the controller's configuration, use the max-batch-size instead") } cmds = append(cmds, "--batch_size", fmt.Sprintf("%d", batchSize)) @@ -864,13 +780,13 @@ func concatTasks(tasks lmesv1alpha1.TaskList) []string { } recipesName := make([]string, len(tasks.TaskRecipes)) for i := range tasks.TaskRecipes { - // assign internal userd task name + // assign internal used task name recipesName[i] = fmt.Sprintf("%s_%d", driver.TaskRecipePrefix, i) } return append(tasks.TaskNames, recipesName...) } -func (r *LMEvalJobReconciler) generateCmd(job *lmesv1alpha1.LMEvalJob) []string { +func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob) []string { if job == nil { return nil } @@ -879,7 +795,7 @@ func (r *LMEvalJobReconciler) generateCmd(job *lmesv1alpha1.LMEvalJob) []string "--output-path", "/opt/app-root/src/output", } - if r.options.DetectDevice { + if svcOpts.DetectDevice { cmds = append(cmds, "--detect-device") } diff --git a/controllers/lmes/lmevaljob_controller_test.go b/controllers/lmes/lmevaljob_controller_test.go index b04b21d1..daed836d 100644 --- a/controllers/lmes/lmevaljob_controller_test.go +++ b/controllers/lmes/lmevaljob_controller_test.go @@ -18,6 +18,7 @@ package lmes import ( "context" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -38,14 +39,12 @@ var ( func Test_SimplePod(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, } + var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ Name: "test", @@ -92,8 +91,8 @@ func Test_SimplePod(t *testing.T) { InitContainers: []corev1.Container{ { Name: "driver", - Image: lmevalRec.options.DriverImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Command: []string{DriverPath, "--copy", DestDriverPath}, SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, @@ -115,10 +114,10 @@ func Test_SimplePod(t *testing.T) { Containers: []corev1.Container{ { Name: "main", - Image: lmevalRec.options.PodImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, - Command: lmevalRec.generateCmd(job), - Args: lmevalRec.generateArgs(job, log), + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -153,20 +152,17 @@ func Test_SimplePod(t *testing.T) { }, } - newPod := lmevalRec.createPod(job, log) + newPod := createPod(svcOpts, job, log) assert.Equal(t, expect, newPod) } func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ @@ -210,7 +206,7 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { }, Volumes: []corev1.Volume{ { - Name: "addtionalVolume", + Name: "additionalVolume", VolumeSource: corev1.VolumeSource{ PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ ClaimName: "mypvc", @@ -254,8 +250,8 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { InitContainers: []corev1.Container{ { Name: "driver", - Image: lmevalRec.options.DriverImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Command: []string{DriverPath, "--copy", DestDriverPath}, SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, @@ -277,10 +273,10 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { Containers: []corev1.Container{ { Name: "main", - Image: lmevalRec.options.PodImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, - Command: lmevalRec.generateCmd(job), - Args: lmevalRec.generateArgs(job, log), + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -320,7 +316,7 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { }, }, { - Name: "addtionalVolume", + Name: "additionalVolume", VolumeSource: corev1.VolumeSource{ PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ ClaimName: "mypvc", @@ -333,7 +329,7 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { }, } - newPod := lmevalRec.createPod(job, log) + newPod := createPod(svcOpts, job, log) assert.Equal(t, expect, newPod) @@ -348,19 +344,16 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { "custom/annotation1": "annotation1", } - newPod = lmevalRec.createPod(job, log) + newPod = createPod(svcOpts, job, log) assert.Equal(t, expect, newPod) } func Test_EnvSecretsPod(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ @@ -425,8 +418,8 @@ func Test_EnvSecretsPod(t *testing.T) { InitContainers: []corev1.Container{ { Name: "driver", - Image: lmevalRec.options.DriverImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Command: []string{DriverPath, "--copy", DestDriverPath}, SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, @@ -448,8 +441,8 @@ func Test_EnvSecretsPod(t *testing.T) { Containers: []corev1.Container{ { Name: "main", - Image: lmevalRec.options.PodImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Env: []corev1.EnvVar{ { Name: "my_env", @@ -463,8 +456,8 @@ func Test_EnvSecretsPod(t *testing.T) { }, }, }, - Command: lmevalRec.generateCmd(job), - Args: lmevalRec.generateArgs(job, log), + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -499,20 +492,17 @@ func Test_EnvSecretsPod(t *testing.T) { }, } - newPod := lmevalRec.createPod(job, log) + newPod := createPod(svcOpts, job, log) // maybe only verify the envs: Containers[0].Env assert.Equal(t, expect, newPod) } func Test_FileSecretsPod(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ @@ -587,8 +577,8 @@ func Test_FileSecretsPod(t *testing.T) { InitContainers: []corev1.Container{ { Name: "driver", - Image: lmevalRec.options.DriverImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, Command: []string{DriverPath, "--copy", DestDriverPath}, SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, @@ -610,10 +600,10 @@ func Test_FileSecretsPod(t *testing.T) { Containers: []corev1.Container{ { Name: "main", - Image: lmevalRec.options.PodImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, - Command: lmevalRec.generateCmd(job), - Args: lmevalRec.generateArgs(job, log), + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -667,22 +657,19 @@ func Test_FileSecretsPod(t *testing.T) { }, } - newPod := lmevalRec.createPod(job, log) + newPod := createPod(svcOpts, job, log) // maybe only verify the envs: Containers[0].Env assert.Equal(t, expect, newPod) } func Test_GenerateArgBatchSize(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - MaxBatchSize: 24, - DefaultBatchSize: 8, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, + MaxBatchSize: 20, + DefaultBatchSize: 4, } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ @@ -708,16 +695,16 @@ func Test_GenerateArgBatchSize(t *testing.T) { // no batchSize in the job, use default batchSize assert.Equal(t, []string{ "sh", "-ec", - "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size 8", - }, lmevalRec.generateArgs(job, log)) + "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size " + strconv.Itoa(svcOpts.DefaultBatchSize), + }, generateArgs(svcOpts, job, log)) // exceed the max-batch-size, use max-batch-size var biggerBatchSize = 30 job.Spec.BatchSize = &biggerBatchSize assert.Equal(t, []string{ "sh", "-ec", - "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size 24", - }, lmevalRec.generateArgs(job, log)) + "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size " + strconv.Itoa(svcOpts.MaxBatchSize), + }, generateArgs(svcOpts, job, log)) // normal batchSize var normalBatchSize = 16 @@ -725,20 +712,17 @@ func Test_GenerateArgBatchSize(t *testing.T) { assert.Equal(t, []string{ "sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --include_path /opt/app-root/src/my_tasks --batch_size 16", - }, lmevalRec.generateArgs(job, log)) + }, generateArgs(svcOpts, job, log)) } func Test_GenerateArgCmdTaskRecipes(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - DefaultBatchSize: DefaultBatchSize, - MaxBatchSize: DefaultMaxBatchSize, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, + MaxBatchSize: options.MaxBatchSize, + DefaultBatchSize: options.DefaultBatchSize, } var format = "unitxt.format" var numDemos = 5 @@ -778,14 +762,14 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { assert.Equal(t, []string{ "sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size 8", - }, lmevalRec.generateArgs(job, log)) + }, generateArgs(svcOpts, job, log)) assert.Equal(t, []string{ "/opt/app-root/src/bin/driver", "--output-path", "/opt/app-root/src/output", "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", - }, lmevalRec.generateCmd(job)) + }, generateCmd(svcOpts, job)) job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, lmesv1alpha1.TaskRecipe{ @@ -803,7 +787,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { assert.Equal(t, []string{ "sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0,tr_1 --include_path /opt/app-root/src/my_tasks --batch_size 8", - }, lmevalRec.generateArgs(job, log)) + }, generateArgs(svcOpts, job, log)) assert.Equal(t, []string{ "/opt/app-root/src/bin/driver", @@ -811,20 +795,17 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", - }, lmevalRec.generateCmd(job)) + }, generateCmd(svcOpts, job)) } func Test_GenerateArgCmdCustomCard(t *testing.T) { log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - DefaultBatchSize: DefaultBatchSize, - MaxBatchSize: DefaultMaxBatchSize, - }, + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, + MaxBatchSize: options.MaxBatchSize, + DefaultBatchSize: options.DefaultBatchSize, } var format = "unitxt.format" var numDemos = 5 @@ -849,7 +830,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { TaskRecipes: []lmesv1alpha1.TaskRecipe{ { Card: lmesv1alpha1.Card{ - Custom: `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "deutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + Custom: `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, }, Template: "unitxt.template", Format: &format, @@ -865,26 +846,21 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { assert.Equal(t, []string{ "sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2,tr_0 --include_path /opt/app-root/src/my_tasks --batch_size 8", - }, lmevalRec.generateArgs(job, log)) + }, generateArgs(svcOpts, job, log)) assert.Equal(t, []string{ "/opt/app-root/src/bin/driver", "--output-path", "/opt/app-root/src/output", "--task-recipe", "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", - "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "deutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, + "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, "--", - }, lmevalRec.generateCmd(job)) + }, generateCmd(svcOpts, job)) } func Test_CustomCardValidation(t *testing.T) { log := log.FromContext(context.Background()) lmevalRec := LMEvalJobReconciler{ Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - }, } var job = &lmesv1alpha1.LMEvalJob{ ObjectMeta: metav1.ObjectMeta{ @@ -934,7 +910,7 @@ func Test_CustomCardValidation(t *testing.T) { "__type__": "set", "fields": { "source_language": "english", - "target_language": "deutch" + "target_language": "dutch" } } ], @@ -967,7 +943,7 @@ func Test_CustomCardValidation(t *testing.T) { "__type__": "set", "fields": { "source_language": "english", - "target_language": "deutch" + "target_language": "dutch" } } ],