Skip to content

Commit 2488697

Browse files
ryanaolearyArthurKamalov
authored andcommitted
[TPU Webhook] Fix KubeRay headless worker svc truncation bug (GoogleCloudPlatform#963)
Fix headless service truncation bug Signed-off-by: Ryan O'Leary <[email protected]>
1 parent 9895e42 commit 2488697

File tree

2 files changed

+75
-22
lines changed

2 files changed

+75
-22
lines changed

ray-on-gke/tpu/kuberay-tpu-webhook/main.go

+15-4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
"time"
3434

3535
ray "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
36+
utils "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
3637
admissionv1 "k8s.io/api/admission/v1"
3738
corev1 "k8s.io/api/core/v1"
3839
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -244,16 +245,26 @@ func extractRayCluster(admissionReview *admissionv1.AdmissionReview) (*ray.RayCl
244245
return &rayCluster, nil
245246
}
246247

248+
// generateHeadlessServiceName returns the expected TPU headless service name for a RayCluster
249+
func generateHeadlessServiceName(clusterName string) string {
250+
serviceName := fmt.Sprintf("%s-%s", clusterName, headlessServiceSuffix)
251+
252+
// Apply the same truncation as in the RayCluster controller when generating the headless service
253+
// name. This is to maintain the up-to 63 char compatibility guarantee for hostnames (RFC 1123).
254+
return utils.CheckName(serviceName)
255+
}
256+
247257
// genDNSHostnames returns list of DNS hostnames for TPU VM hosts as a string
248258
func genDNSHostnames(numOfHosts int32, groupName string, clusterName string, namespace string, replicaIndex int) (string, error) {
249259
if numOfHosts == 0 {
250260
err := errors.New("workerGroupSpec NumOfHosts not set")
251261
return "", err
252262
}
263+
headlessServiceName := generateHeadlessServiceName(clusterName)
253264
hostNames := make([]string, numOfHosts)
254-
// Host names will be of the form {WORKER_GROUP_NAME}-{REPLICA_INDEX}-{HOST_INDEX}.headless-worker-svc
265+
// Host names will be of the form {WORKER_GROUP_NAME}-{REPLICA_INDEX}-{HOST_INDEX}.{CLUSTER_NAME}-headless-worker-svc
255266
for j := 0; j < int(numOfHosts); j++ {
256-
hostNames[j] = fmt.Sprintf("%s-%d-%d.%s-%s", groupName, replicaIndex, j, clusterName, headlessServiceSuffix)
267+
hostNames[j] = fmt.Sprintf("%s-%d-%d.%s", groupName, replicaIndex, j, headlessServiceName)
257268
}
258269
klog.V(1).InfoS("genDNSHostnames", "RayCluster", namespace+"/"+clusterName, "NumOfHosts", numOfHosts, "Replica Index", replicaIndex)
259270
return strings.Join(hostNames, ","), nil
@@ -268,7 +279,7 @@ func injectHostnames(clusterName string, hostNames string, envPath string, conta
268279
Value: hostNames,
269280
}
270281
subdomainPatch["path"] = subdomainPath
271-
subdomainPatch["value"] = fmt.Sprintf("%s-%s", clusterName, headlessServiceSuffix)
282+
subdomainPatch["value"] = generateHeadlessServiceName(clusterName)
272283
// create new EnvVar array if container.Env is empty, and append hostnames if not
273284
if len(container.Env) == 0 {
274285
hostNamesPatch["path"] = envPath
@@ -678,7 +689,7 @@ func (t *TPUWebhookServer) mutatePod(admissionReview *admissionv1.AdmissionRevie
678689
return nil, err
679690
}
680691
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_HOSTNAMES", hostnames)
681-
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", clusterName+"-"+headlessServiceSuffix)
692+
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", generateHeadlessServiceName(clusterName))
682693
injectHostnames(clusterName, hostnames, path, container, &patches)
683694
}
684695
// inject TPU_WORKER_ID

ray-on-gke/tpu/kuberay-tpu-webhook/webhook_main_test.go

+60-18
Original file line numberDiff line numberDiff line change
@@ -778,26 +778,30 @@ func Test_ExtractRayCluster(t *testing.T) {
778778

779779
func Test_GenDNSHostnames(t *testing.T) {
780780
tests := map[string]struct {
781+
clusterName string
781782
replicaIndex int
782783
numOfHosts int32
783784
expectedHostnames string
784785
expectedError error
785786
}{
786787
"genDNSHostnames with NumOfHosts == 0": {
787788
// a workergroup can't have NumOfHosts set to 0 so this should error out
789+
clusterName: "test-cluster",
788790
replicaIndex: 0,
789791
numOfHosts: int32(0),
790792
expectedError: errors.New("workerGroupSpec NumOfHosts not set"),
791793
},
792794
"genDNSHostnames with NumOfHosts == 1": {
793795
// Single-host worker group, should return a single DNS hostname. This function will
794796
// never be called for single-host groups, but we don't necessarily want it to error if it does.
797+
clusterName: "test-cluster",
795798
replicaIndex: 0,
796799
numOfHosts: int32(1),
797800
expectedHostnames: fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 0, 0, "test-cluster", headlessServiceSuffix),
798801
},
799802
"genDNSHostnames with NumOfHosts > 1": {
800803
// multi-host worker group, should return a string list of DNS hostnames for the given replica
804+
clusterName: "test-cluster",
801805
replicaIndex: 1,
802806
numOfHosts: int32(4),
803807
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "test-cluster", headlessServiceSuffix),
@@ -806,12 +810,21 @@ func Test_GenDNSHostnames(t *testing.T) {
806810
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 3, "test-cluster", headlessServiceSuffix),
807811
}, ","),
808812
},
813+
"genDNSHostnames with long RayCluster name": {
814+
// Multi-host worker group in a RayCluster with a name that will be truncated
815+
clusterName: "long-raycluster-name-to-be-truncated",
816+
replicaIndex: 1,
817+
numOfHosts: int32(2),
818+
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "aycluster-name-to-be-truncated", headlessServiceSuffix),
819+
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 1, "aycluster-name-to-be-truncated", headlessServiceSuffix),
820+
}, ","),
821+
},
809822
}
810823

811824
// validate that genDNSHostnames correctly returns a string list of DNS addressable hostnames
812825
for name, tc := range tests {
813826
t.Run(name, func(t *testing.T) {
814-
hostnames, err := genDNSHostnames(tc.numOfHosts, "test-group", "test-cluster", "test-namespace", tc.replicaIndex)
827+
hostnames, err := genDNSHostnames(tc.numOfHosts, "test-group", tc.clusterName, "test-namespace", tc.replicaIndex)
815828
if err != nil {
816829
assert.Equal(t, tc.expectedError, err)
817830
} else {
@@ -823,21 +836,15 @@ func Test_GenDNSHostnames(t *testing.T) {
823836

824837
func Test_InjectHostnames(t *testing.T) {
825838
tests := map[string]struct {
826-
numOfHosts int
839+
clusterName string
827840
groupName string
828841
expectedSubdomain string
829842
expectedHostnames string
830843
}{
831-
"injectHostnames for single-host worker group": {
832-
// should create a patch to set the subdomain and a single TPU_WORKER_HOSTNAMES DNS hostname
833-
numOfHosts: 1,
834-
groupName: "test-group-name",
835-
expectedSubdomain: fmt.Sprintf("%s-%s", "test-cluster", headlessServiceSuffix),
836-
expectedHostnames: fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 0, 0, "test-cluster", headlessServiceSuffix),
837-
},
838844
"injectHostnames for multi-host worker group": {
839-
// should create a patch to set the subdomain and TPU_WORKER_HOSTNAMES for all hosts
840-
numOfHosts: 1,
845+
// Should create a patch to set the subdomain and TPU_WORKER_HOSTNAMES for all hosts.
846+
// This function is only called for multi-host TPU worker groups.
847+
clusterName: "test-cluster",
841848
groupName: "test-group-name",
842849
expectedSubdomain: fmt.Sprintf("%s-%s", "test-cluster", headlessServiceSuffix),
843850
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "test-cluster", headlessServiceSuffix),
@@ -846,21 +853,33 @@ func Test_InjectHostnames(t *testing.T) {
846853
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 3, "test-cluster", headlessServiceSuffix),
847854
}, ","),
848855
},
856+
"injectHostnames for multi-host worker group with truncated service name": {
857+
// Should create a patch to set the subdomain and TPU_WORKER_HOSTNAMES for all hosts, with the
858+
// correct subdomain truncated to match the created service name.
859+
clusterName: "extremely-long-test-raycluster-name",
860+
groupName: "test-group-name",
861+
expectedSubdomain: fmt.Sprintf("%s-%s", "mely-long-test-raycluster-name", headlessServiceSuffix),
862+
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "mely-long-test-raycluster-name", headlessServiceSuffix),
863+
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 1, "mely-long-test-raycluster-name", headlessServiceSuffix),
864+
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 2, "mely-long-test-raycluster-name", headlessServiceSuffix),
865+
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 3, "mely-long-test-raycluster-name", headlessServiceSuffix),
866+
}, ","),
867+
},
849868
}
850869

851870
// check that a valid subdomain and TPU_WORKER_HOSTNAMES are injected into the Pod
852871
for name, tc := range tests {
853872
t.Run(name, func(t *testing.T) {
854-
testPod := getTestTPUWorker("test-cluster", "test-group", "test-namespace", "tpu-v4-podslice", "2x2x1", "4")
873+
testPod := getTestTPUWorker(tc.clusterName, tc.groupName, "test-namespace", "tpu-v4-podslice", "2x2x2", "4")
855874
expectedEnv := []corev1.EnvVar{corev1.EnvVar{Name: "TPU_WORKER_HOSTNAMES", Value: tc.expectedHostnames}}
856-
expectedPatches := []patch{}
857-
injectHostnames("test-cluster", tc.expectedHostnames, "/spec/containers/0/env", testPod.Spec.Containers[0], &expectedPatches)
875+
patches := []patch{}
876+
injectHostnames(tc.clusterName, tc.expectedHostnames, "/spec/containers/0/env", testPod.Spec.Containers[0], &patches)
858877
// check subdomain patch
859-
assert.Equal(t, "/spec/subdomain", expectedPatches[0]["path"])
860-
assert.Equal(t, tc.expectedSubdomain, expectedPatches[0]["value"])
878+
assert.Equal(t, "/spec/subdomain", patches[0]["path"])
879+
assert.Equal(t, tc.expectedSubdomain, patches[0]["value"])
861880
// check hostnames patch
862-
assert.Equal(t, "/spec/containers/0/env", expectedPatches[1]["path"])
863-
assert.Equal(t, expectedEnv, expectedPatches[1]["value"])
881+
assert.Equal(t, "/spec/containers/0/env", patches[1]["path"])
882+
assert.Equal(t, expectedEnv, patches[1]["value"])
864883
})
865884
}
866885
}
@@ -1464,3 +1483,26 @@ func Test_MutatePod(t *testing.T) {
14641483
})
14651484
}
14661485
}
1486+
1487+
func Test_GenerateHeadlessServiceName(t *testing.T) {
1488+
tests := map[string]struct {
1489+
testRayClusterName string
1490+
expectedServiceName string
1491+
}{
1492+
"RayCluster name + headless-worker-svc is less than 50 chars, no truncation": {
1493+
testRayClusterName: "test-raycluster", // 15 chars
1494+
expectedServiceName: "test-raycluster-headless-worker-svc", // 35 chars
1495+
},
1496+
"RayCluster name + headless-worker-svc is more than 50 chars, name is truncated": {
1497+
testRayClusterName: "extremely-long-test-raycluster-name", // 35 chars
1498+
expectedServiceName: "mely-long-test-raycluster-name-headless-worker-svc", // 50 chars
1499+
},
1500+
}
1501+
1502+
for name, tc := range tests {
1503+
t.Run(name, func(t *testing.T) {
1504+
serviceName := generateHeadlessServiceName(tc.testRayClusterName)
1505+
assert.Equal(t, tc.expectedServiceName, serviceName)
1506+
})
1507+
}
1508+
}

0 commit comments

Comments
 (0)