Skip to content

Commit 55ad4cb

Browse files
committed
Fix missing ReplicaIndexLabel when using RunLauncherAsWorker
Signed-off-by: GonzaloSaez <[email protected]>
1 parent 5beeaf0 commit 55ad4cb

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

pkg/controller/mpi_job_controller.go

+20-4
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,7 @@ func newConfigMap(mpiJob *kubeflow.MPIJob, workerReplicas int32) *corev1.ConfigM
12771277
// note that pod.spec.dnsConfig also affect the svc resolution
12781278
// ref: https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/
12791279
// launcher can be reach with hostname or service name
1280-
if ptr.Deref(mpiJob.Spec.RunLauncherAsWorker, false) {
1280+
if runLauncherAsWorker(mpiJob) {
12811281
name := mpiJob.Name + launcherSuffix
12821282
switch mpiJob.Spec.MPIImplementation {
12831283
case kubeflow.MPIImplementationOpenMPI:
@@ -1325,7 +1325,7 @@ func updateDiscoverHostsInConfigMap(configMap *corev1.ConfigMap, mpiJob *kubeflo
13251325
buffer.WriteString("#!/bin/sh\n")
13261326

13271327
// We don't check if launcher is running here, launcher should always be there or the job failed
1328-
if ptr.Deref(mpiJob.Spec.RunLauncherAsWorker, false) {
1328+
if runLauncherAsWorker(mpiJob) {
13291329
name := mpiJob.Name + launcherSuffix
13301330
buffer.WriteString(fmt.Sprintf("echo %s.%s.%s.svc\n", name, mpiJob.Name, mpiJob.Namespace))
13311331
}
@@ -1408,6 +1408,19 @@ func workerName(mpiJob *kubeflow.MPIJob, index int) string {
14081408
return fmt.Sprintf("%s%s-%d", mpiJob.Name, workerSuffix, index)
14091409
}
14101410

1411+
func runLauncherAsWorker(mpiJob *kubeflow.MPIJob) bool {
1412+
return ptr.Deref(mpiJob.Spec.RunLauncherAsWorker, false)
1413+
}
1414+
1415+
func workerReplicaIndexLabel(mpiJob *kubeflow.MPIJob, index int) string {
1416+
// When running the launcher as a worker, some integrations such as Kueue's TAS, require all pods in the PodGroup
1417+
// to have a valid and unique index label. That's why we have to pad by one.
1418+
if runLauncherAsWorker(mpiJob) {
1419+
return strconv.Itoa(index + 1)
1420+
}
1421+
return strconv.Itoa(index)
1422+
}
1423+
14111424
// newWorker creates a new worker Pod for an MPIJob resource. It also
14121425
// sets the appropriate OwnerReferences on the resource so handleObject can
14131426
// discover the MPIJob resource that 'owns' it.
@@ -1423,7 +1436,7 @@ func (c *MPIJobController) newWorker(mpiJob *kubeflow.MPIJob, index int) *corev1
14231436
for key, value := range defaultLabels(mpiJob.Name, worker) {
14241437
podTemplate.Labels[key] = value
14251438
}
1426-
podTemplate.Labels[kubeflow.ReplicaIndexLabel] = strconv.Itoa(index)
1439+
podTemplate.Labels[kubeflow.ReplicaIndexLabel] = workerReplicaIndexLabel(mpiJob, index)
14271440
podTemplate.Spec.Hostname = name
14281441
podTemplate.Spec.Subdomain = mpiJob.Name // Matches job' Service name.
14291442
if podTemplate.Spec.HostNetwork {
@@ -1509,6 +1522,9 @@ func (c *MPIJobController) newLauncherPodTemplate(mpiJob *kubeflow.MPIJob) corev
15091522
if c.PodGroupCtrl != nil {
15101523
c.PodGroupCtrl.decoratePodTemplateSpec(podTemplate, mpiJob.Name)
15111524
}
1525+
if runLauncherAsWorker(mpiJob) {
1526+
podTemplate.Labels[kubeflow.ReplicaIndexLabel] = "0"
1527+
}
15121528
podTemplate.Spec.Hostname = launcherName
15131529
podTemplate.Spec.Subdomain = mpiJob.Name // Matches job' Service name.
15141530
if podTemplate.Spec.HostNetwork {
@@ -1535,7 +1551,7 @@ func (c *MPIJobController) newLauncherPodTemplate(mpiJob *kubeflow.MPIJob) corev
15351551
case kubeflow.MPIImplementationMPICH:
15361552
container.Env = append(container.Env, mpichEnvVars...)
15371553
}
1538-
if !ptr.Deref(mpiJob.Spec.RunLauncherAsWorker, false) {
1554+
if !runLauncherAsWorker(mpiJob) {
15391555
container.Env = append(container.Env,
15401556
// We overwrite these environment variables so that users will not
15411557
// be mistakenly using GPU resources for launcher due to potential

pkg/controller/mpi_job_controller_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,7 @@ func TestNewLauncherAndWorker(t *testing.T) {
13911391
kubeflow.OperatorNameLabel: kubeflow.OperatorName,
13921392
kubeflow.JobNameLabel: "foo",
13931393
kubeflow.JobRoleLabel: "launcher",
1394+
kubeflow.ReplicaIndexLabel: "0",
13941395
},
13951396
},
13961397
Spec: corev1.PodSpec{
@@ -1445,7 +1446,7 @@ func TestNewLauncherAndWorker(t *testing.T) {
14451446
kubeflow.OperatorNameLabel: kubeflow.OperatorName,
14461447
kubeflow.JobNameLabel: "foo",
14471448
kubeflow.JobRoleLabel: "worker",
1448-
kubeflow.ReplicaIndexLabel: "0",
1449+
kubeflow.ReplicaIndexLabel: "1",
14491450
},
14501451
},
14511452
Spec: corev1.PodSpec{

0 commit comments

Comments
 (0)