Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TPU Webhook] Fix KubeRay headless worker svc truncation bug #963

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions ray-on-gke/tpu/kuberay-tpu-webhook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"time"

ray "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
utils "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
admissionv1 "k8s.io/api/admission/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -244,16 +245,27 @@ func extractRayCluster(admissionReview *admissionv1.AdmissionReview) (*ray.RayCl
return &rayCluster, nil
}

// generateHeadlessServiceName returns the expected TPU headless service name for a RayCluster
func generateHeadlessServiceName(clusterName string) string {
serviceName := fmt.Sprintf("%s-%s", clusterName, headlessServiceSuffix)

// Apply the same truncation as in the RayCluster controller when generating the headless service
// name. This is to maintain the up-to 63 char compatibility guarantee for hostnames (RFC 1123).
serviceName = utils.CheckName(serviceName)
return serviceName
}

// genDNSHostnames returns list of DNS hostnames for TPU VM hosts as a string
func genDNSHostnames(numOfHosts int32, groupName string, clusterName string, namespace string, replicaIndex int) (string, error) {
if numOfHosts == 0 {
err := errors.New("workerGroupSpec NumOfHosts not set")
return "", err
}
headlessServiceName := generateHeadlessServiceName(clusterName)
hostNames := make([]string, numOfHosts)
// Host names will be of the form {WORKER_GROUP_NAME}-{REPLICA_INDEX}-{HOST_INDEX}.headless-worker-svc
// Host names will be of the form {WORKER_GROUP_NAME}-{REPLICA_INDEX}-{HOST_INDEX}.{CLUSTER_NAME}-headless-worker-svc
for j := 0; j < int(numOfHosts); j++ {
hostNames[j] = fmt.Sprintf("%s-%d-%d.%s-%s", groupName, replicaIndex, j, clusterName, headlessServiceSuffix)
hostNames[j] = fmt.Sprintf("%s-%d-%d.%s", groupName, replicaIndex, j, headlessServiceName)
}
klog.V(1).InfoS("genDNSHostnames", "RayCluster", namespace+"/"+clusterName, "NumOfHosts", numOfHosts, "Replica Index", replicaIndex)
return strings.Join(hostNames, ","), nil
Expand All @@ -268,7 +280,7 @@ func injectHostnames(clusterName string, hostNames string, envPath string, conta
Value: hostNames,
}
subdomainPatch["path"] = subdomainPath
subdomainPatch["value"] = fmt.Sprintf("%s-%s", clusterName, headlessServiceSuffix)
subdomainPatch["value"] = generateHeadlessServiceName(clusterName)
// create new EnvVar array if container.Env is empty, and append hostnames if not
if len(container.Env) == 0 {
hostNamesPatch["path"] = envPath
Expand Down Expand Up @@ -678,7 +690,7 @@ func (t *TPUWebhookServer) mutatePod(admissionReview *admissionv1.AdmissionRevie
return nil, err
}
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_HOSTNAMES", hostnames)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", clusterName+"-"+headlessServiceSuffix)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", generateHeadlessServiceName(clusterName))
injectHostnames(clusterName, hostnames, path, container, &patches)
}
// inject TPU_WORKER_ID
Expand Down
78 changes: 60 additions & 18 deletions ray-on-gke/tpu/kuberay-tpu-webhook/webhook_main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,26 +778,30 @@ func Test_ExtractRayCluster(t *testing.T) {

func Test_GenDNSHostnames(t *testing.T) {
tests := map[string]struct {
clusterName string
replicaIndex int
numOfHosts int32
expectedHostnames string
expectedError error
}{
"genDNSHostnames with NumOfHosts == 0": {
// a workergroup can't have NumOfHosts set to 0 so this should error out
clusterName: "test-cluster",
replicaIndex: 0,
numOfHosts: int32(0),
expectedError: errors.New("workerGroupSpec NumOfHosts not set"),
},
"genDNSHostnames with NumOfHosts == 1": {
// Single-host worker group, should return a single DNS hostname. This function will
// never be called for single-host groups, but we don't necessarily want it to error if it does.
clusterName: "test-cluster",
replicaIndex: 0,
numOfHosts: int32(1),
expectedHostnames: fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 0, 0, "test-cluster", headlessServiceSuffix),
},
"genDNSHostnames with NumOfHosts > 1": {
// multi-host worker group, should return a string list of DNS hostnames for the given replica
clusterName: "test-cluster",
replicaIndex: 1,
numOfHosts: int32(4),
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "test-cluster", headlessServiceSuffix),
Expand All @@ -806,12 +810,21 @@ func Test_GenDNSHostnames(t *testing.T) {
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 3, "test-cluster", headlessServiceSuffix),
}, ","),
},
"genDNSHostnames with long RayCluster name": {
// Multi-host worker group in a RayCluster with a name that will be truncated
clusterName: "long-raycluster-name-to-be-truncated",
replicaIndex: 1,
numOfHosts: int32(2),
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "aycluster-name-to-be-truncated", headlessServiceSuffix),
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 1, "aycluster-name-to-be-truncated", headlessServiceSuffix),
}, ","),
},
}

// validate that genDNSHostnames correctly returns a string list of DNS addressable hostnames
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
hostnames, err := genDNSHostnames(tc.numOfHosts, "test-group", "test-cluster", "test-namespace", tc.replicaIndex)
hostnames, err := genDNSHostnames(tc.numOfHosts, "test-group", tc.clusterName, "test-namespace", tc.replicaIndex)
if err != nil {
assert.Equal(t, tc.expectedError, err)
} else {
Expand All @@ -823,21 +836,15 @@ func Test_GenDNSHostnames(t *testing.T) {

func Test_InjectHostnames(t *testing.T) {
tests := map[string]struct {
numOfHosts int
clusterName string
groupName string
expectedSubdomain string
expectedHostnames string
}{
"injectHostnames for single-host worker group": {
// should create a patch to set the subdomain and a single TPU_WORKER_HOSTNAMES DNS hostname
numOfHosts: 1,
groupName: "test-group-name",
expectedSubdomain: fmt.Sprintf("%s-%s", "test-cluster", headlessServiceSuffix),
expectedHostnames: fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 0, 0, "test-cluster", headlessServiceSuffix),
},
"injectHostnames for multi-host worker group": {
// should create a patch to set the subdomain and TPU_WORKER_HOSTNAMES for all hosts
numOfHosts: 1,
// Should create a patch to set the subdomain and TPU_WORKER_HOSTNAMES for all hosts.
// This function is only called for multi-host TPU worker groups.
clusterName: "test-cluster",
groupName: "test-group-name",
expectedSubdomain: fmt.Sprintf("%s-%s", "test-cluster", headlessServiceSuffix),
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "test-cluster", headlessServiceSuffix),
Expand All @@ -846,21 +853,33 @@ func Test_InjectHostnames(t *testing.T) {
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 3, "test-cluster", headlessServiceSuffix),
}, ","),
},
"injectHostnames for multi-host worker group with truncated service name": {
// Should create a patch to set the subdomain and TPU_WORKER_HOSTNAMES for all hosts, with the
// correct subdomain truncated to match the created service name.
clusterName: "extremely-long-test-raycluster-name",
groupName: "test-group-name",
expectedSubdomain: fmt.Sprintf("%s-%s", "mely-long-test-raycluster-name", headlessServiceSuffix),
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "mely-long-test-raycluster-name", headlessServiceSuffix),
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 1, "mely-long-test-raycluster-name", headlessServiceSuffix),
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 2, "mely-long-test-raycluster-name", headlessServiceSuffix),
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 3, "mely-long-test-raycluster-name", headlessServiceSuffix),
}, ","),
},
}

// check that a valid subdomain and TPU_WORKER_HOSTNAMES are injected into the Pod
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
testPod := getTestTPUWorker("test-cluster", "test-group", "test-namespace", "tpu-v4-podslice", "2x2x1", "4")
testPod := getTestTPUWorker(tc.clusterName, tc.groupName, "test-namespace", "tpu-v4-podslice", "2x2x2", "4")
expectedEnv := []corev1.EnvVar{corev1.EnvVar{Name: "TPU_WORKER_HOSTNAMES", Value: tc.expectedHostnames}}
expectedPatches := []patch{}
injectHostnames("test-cluster", tc.expectedHostnames, "/spec/containers/0/env", testPod.Spec.Containers[0], &expectedPatches)
patches := []patch{}
injectHostnames(tc.clusterName, tc.expectedHostnames, "/spec/containers/0/env", testPod.Spec.Containers[0], &patches)
// check subdomain patch
assert.Equal(t, "/spec/subdomain", expectedPatches[0]["path"])
assert.Equal(t, tc.expectedSubdomain, expectedPatches[0]["value"])
assert.Equal(t, "/spec/subdomain", patches[0]["path"])
assert.Equal(t, tc.expectedSubdomain, patches[0]["value"])
// check hostnames patch
assert.Equal(t, "/spec/containers/0/env", expectedPatches[1]["path"])
assert.Equal(t, expectedEnv, expectedPatches[1]["value"])
assert.Equal(t, "/spec/containers/0/env", patches[1]["path"])
assert.Equal(t, expectedEnv, patches[1]["value"])
})
}
}
Expand Down Expand Up @@ -1464,3 +1483,26 @@ func Test_MutatePod(t *testing.T) {
})
}
}

func Test_GenerateHeadlessServiceName(t *testing.T) {
tests := map[string]struct {
testRayClusterName string
expectedServiceName string
}{
"RayCluster name + headless-worker-svc is less than 50 chars, no truncation": {
testRayClusterName: "test-raycluster", // 15 chars
expectedServiceName: "test-raycluster-headless-worker-svc", // 35 chars
},
"RayCluster name + headless-worker-svc is more than 50 chars, name is truncated": {
testRayClusterName: "extremely-long-test-raycluster-name", // 35 chars
expectedServiceName: "mely-long-test-raycluster-name-headless-worker-svc", // 50 chars
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
serviceName := generateHeadlessServiceName(tc.testRayClusterName)
assert.Equal(t, tc.expectedServiceName, serviceName)
})
}
}