
Commit 54531da

TPU Metrics PodMonitoring (#761)

* first commit
* terraform fmt
* remove default
* more descriptive metric_scrape_interval comment
* more descriptive comments
* object for metrics config
* update readme
* update tutorial readme
1 parent c166bfa


8 files changed: +88 -20 lines changed


modules/jetstream-maxtext-deployment/README.md

+16-4
@@ -6,6 +6,8 @@ Assure the following environment variables are set:
 - MODEL_NAME: The name of your LLM (as of the writing of this README valid options are "gemma-7b", "llama2-7b", "llama2-13b")
 - PARAMETERS_PATH: Where to find the parameters for your LLM (if using the checkpoint-converter it will be "gs:\/\/$BUCKET_NAME\/final\/unscanned\/gemma_7b-it\/0\/checkpoints\/0\/items" where $BUCKET_NAME is the same one used in the checkpoint-converter)
 - (optional) METRICS_PORT: Port to emit custom metrics on
+- (optional) SERVER_METRICS_SCRAPE_INTERVAL: How often to scrape Jetstream server metrics
+- (optional) SYSTEM_METRICS_SCRAPE_INTERVAL: How often to scrape TPU system metrics
 - (optional) TPU_TOPOLOGY: Topology of TPU chips used by jetstream (default: "2x4")
 - (optional) TPU_TYPE: Type of TPUs used (default: "tpu-v5-lite-podslice")
 - (optional) TPU_CHIP_COUNT: Number of TPU chips requested, can be obtained by algebraically evaluating TPU_TOPOLOGY

@@ -53,11 +55,21 @@ cat ./templates/deployment.yaml.tftpl >> "$JETSTREAM_MANIFEST"
 PODMONITORING_MANIFEST=$(mktemp)
 cat ./templates/podmonitoring.yaml.tftpl >> "$PODMONITORING_MANIFEST"

-if [ "$METRICS_PORT" != "" ]; then
-  cat $PODMONITORING_MANIFEST | sed "s/\${metrics_port}/$METRICS_PORT/g" >> "$PODMONITORING_MANIFEST"
-  cat $JETSTREAM_MANIFEST | sed "s/\${metrics_port_arg}/prometheus_port=$METRICS_PORT/g" >> "$JETSTREAM_MANIFEST"
-
+PODMONITORING_TPU_MANIFEST=$(mktemp)
+cat ./templates/podmonitoring-tpu.yaml.tftpl >> "$PODMONITORING_TPU_MANIFEST"
+
+if [ "$SYSTEM_METRICS_SCRAPE_INTERVAL" != "" ]; then
+  cat $PODMONITORING_TPU_MANIFEST \
+    | sed "s/\${metrics_scrape_interval}/$SYSTEM_METRICS_SCRAPE_INTERVAL/g" >> "$PODMONITORING_TPU_MANIFEST"
+  cat $PODMONITORING_TPU_MANIFEST | kubectl apply -f -
+fi
+
+if [ "$METRICS_PORT" != "" ] && [ "$SERVER_METRICS_SCRAPE_INTERVAL" != "" ]; then
+  cat $PODMONITORING_MANIFEST \
+    | sed "s/\${metrics_port}/$METRICS_PORT/g" \
+    | sed "s/\${metrics_scrape_interval}/$SERVER_METRICS_SCRAPE_INTERVAL/g" >> "$PODMONITORING_MANIFEST"
 cat $PODMONITORING_MANIFEST | kubectl apply -f -
+cat $JETSTREAM_MANIFEST | sed "s/\${metrics_port_arg}/prometheus_port=$METRICS_PORT/g" >> "$JETSTREAM_MANIFEST"
 else
 cat $JETSTREAM_MANIFEST | sed "s/\${metrics_port_arg}//g" >> "$JETSTREAM_MANIFEST"
 fi

modules/jetstream-maxtext-deployment/main.tf

+12-4
@@ -18,6 +18,7 @@ locals {
   deployment_template               = "${path.module}/templates/deployment.yaml.tftpl"
   service_template                  = "${path.module}/templates/service.yaml.tftpl"
   podmonitoring_template            = "${path.module}/templates/podmonitoring.yaml.tftpl"
+  podmonitoring_tpu_template        = "${path.module}/templates/podmonitoring-tpu.yaml.tftpl"
   cmsa_jetstream_hpa_template       = "${path.module}/templates/custom-metrics-stackdriver-adapter/hpa.jetstream.yaml.tftpl"
   prometheus_jetstream_hpa_template = "${path.module}/templates/prometheus-adapter/hpa.jetstream.yaml.tftpl"
 }

@@ -30,7 +31,7 @@ resource "kubernetes_manifest" "jetstream-deployment" {
     model_name               = var.maxengine_deployment_settings.model_name
     tokenizer                = strcontains(var.maxengine_deployment_settings.model_name, "gemma") ? "assets/tokenizer.gemma" : (strcontains(var.maxengine_deployment_settings.model_name, "llama") ? "assets/tokenizer.llama2" : "")
     load_parameters_path_arg = var.maxengine_deployment_settings.parameters_path
-    metrics_port_arg         = var.maxengine_deployment_settings.metrics_port != null ? format("prometheus_port=%d", var.maxengine_deployment_settings.metrics_port) : "",
+    metrics_port_arg         = try(format("prometheus_port=%d", var.maxengine_deployment_settings.metrics.server.port), ""),
     tpu-topology             = var.maxengine_deployment_settings.accelerator_selectors.topology
     tpu-type                 = var.maxengine_deployment_settings.accelerator_selectors.accelerator
     tpu-chip-count           = var.maxengine_deployment_settings.accelerator_selectors.chip_count

@@ -43,10 +44,17 @@ resource "kubernetes_manifest" "jetstream-service" {
 }

 resource "kubernetes_manifest" "jetstream-podmonitoring" {
-  count = var.maxengine_deployment_settings.metrics_port != null ? 1 : 0
+  count = try(var.maxengine_deployment_settings.metrics.server != null ? 1 : 0, 0)
   manifest = yamldecode(templatefile(local.podmonitoring_template, {
-    metrics_port            = var.maxengine_deployment_settings.metrics_port != null ? var.maxengine_deployment_settings.metrics_port : "",
-    metrics_scrape_interval = var.maxengine_deployment_settings.metrics_scrape_interval
+    metrics_port            = try(var.maxengine_deployment_settings.metrics.server.port, 0),
+    metrics_scrape_interval = try(var.maxengine_deployment_settings.metrics.server.scrape_interval, 0),
+  }))
+}
+
+resource "kubernetes_manifest" "jetstream-podmonitoring-tpu" {
+  count = try(var.maxengine_deployment_settings.metrics.system != null ? 1 : 0, 0)
+  manifest = yamldecode(templatefile(local.podmonitoring_tpu_template, {
+    metrics_scrape_interval = try(var.maxengine_deployment_settings.metrics.system.scrape_interval, 0)
   }))
 }
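
The count expressions above wrap their null checks in try() because the metrics object is itself optional: when metrics is null, evaluating metrics.server or metrics.system is an error rather than a null result, and try() turns that error into the 0 fallback so the resource is simply not created. A minimal standalone sketch of the same pattern, using a hypothetical variable and local that are not part of this module:

variable "metrics" {
  type = object({
    server = optional(object({ port = number, scrape_interval = number }))
    system = optional(object({ scrape_interval = number }))
  })
  default = null // the whole metrics object may be omitted
}

locals {
  // 1 only when metrics and metrics.system are both set; otherwise try() falls back to 0.
  podmonitoring_tpu_count = try(var.metrics.system != null ? 1 : 0, 0)
}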

modules/jetstream-maxtext-deployment/templates/deployment.yaml.tftpl

+1
@@ -38,6 +38,7 @@ spec:
             - load_parameters_path=${load_parameters_path_arg}
             - ${metrics_port_arg}
           ports:
+            - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
             - containerPort: 9000
           resources:
             requests:
modules/jetstream-maxtext-deployment/templates/podmonitoring-tpu.yaml.tftpl

+14

@@ -0,0 +1,14 @@
+apiVersion: monitoring.googleapis.com/v1
+kind: PodMonitoring
+metadata:
+  name: tpu-metrics-exporter
+  namespace: kube-system
+  labels:
+    k8s-app: tpu-device-plugin
+spec:
+  endpoints:
+  - port: 2112
+    interval: ${metrics_scrape_interval}s
+  selector:
+    matchLabels:
+      k8s-app: tpu-device-plugin
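
The module fills ${metrics_scrape_interval} from var.maxengine_deployment_settings.metrics.system.scrape_interval (see main.tf above), so the value is a plain number of seconds and the template appends the "s" unit itself. As a rough way to inspect the rendered manifest without applying it, the template can be rendered directly; this output name is made up, and the sketch assumes it is evaluated from the module directory:

output "podmonitoring_tpu_rendered" {
  // Renders the TPU PodMonitoring template with a 30-second scrape interval.
  value = templatefile("${path.module}/templates/podmonitoring-tpu.yaml.tftpl", {
    metrics_scrape_interval = 30
  })
}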

modules/jetstream-maxtext-deployment/variables.tf

+22-4
@@ -29,10 +29,18 @@ variable "maxengine_deployment_settings" {
     maxengine_server_image      = optional(string, "us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:v0.2.2")
     jetstream_http_server_image = optional(string, "us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2")

-    model_name              = string // Name of your LLM (for example: "gemma-7b")
-    parameters_path         = string // Path to the paramters for your model
-    metrics_port            = optional(number) // Emit Jetstream metrics on this port of each container
-    metrics_scrape_interval = optional(number) // Interval for scraping metrics (default: 10s)
+    model_name      = string // Name of your LLM (for example: "gemma-7b")
+    parameters_path = string // Path to the parameters for your model
+
+    metrics = optional(object({   // Settings for metrics server
+      server = optional(object({  // Settings for Jetstream server metrics
+        port            = number
+        scrape_interval = number
+      }))
+      system = optional(object({  // Settings for TPU metrics
+        scrape_interval = number
+      }))
+    }))

     accelerator_selectors = object({
       topology = string

@@ -45,6 +53,16 @@ variable "maxengine_deployment_settings" {
     condition     = contains(["gemma-7b", "llama2-7b", "llama2-13b"], var.maxengine_deployment_settings.model_name)
     error_message = "model_name must be one of \"gemma-7b\", \"llama2-7b\", or \"llama2-13b\""
   }
+
+  validation {
+    condition     = try(var.maxengine_deployment_settings.metrics.server.scrape_interval >= 5, true)
+    error_message = "Server metrics scrape interval may not be shorter than 5s"
+  }
+
+  validation {
+    condition     = try(var.maxengine_deployment_settings.metrics.system.scrape_interval >= 15, true)
+    error_message = "TPU system metrics scrape interval may not be shorter than 15s"
+  }
 }

 variable "hpa_config" {
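
Both new validation blocks use try(..., true) so they pass whenever metrics or the relevant sub-object is left unset, and only reject intervals that are present but too short. A hedged example of a maxengine_deployment_settings value that satisfies both checks; the bucket path is a placeholder, the accelerator values follow the sample tfvars in this commit, and chip_count is inferred from the 2x4 topology:

maxengine_deployment_settings = {
  model_name      = "gemma-7b"
  parameters_path = "gs://YOUR_BUCKET/path/to/checkpoint/items" // placeholder path

  metrics = {
    server = {
      port            = 9100
      scrape_interval = 10 // seconds; must be >= 5
    }
    system = {
      scrape_interval = 30 // seconds; must be >= 15
    }
  }

  accelerator_selectors = {
    topology    = "2x4"
    accelerator = "tpu-v5-lite-podslice"
    chip_count  = 8 // 2x4 topology -> 8 chips
  }
}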

tutorials-and-examples/inference-servers/jetstream/maxtext/single-host-inference/README.md

+4-2
@@ -137,8 +137,10 @@ For deploying autoscaling components via terraform, a few more variables to be set

 ```
 maxengine_deployment_settings = {
-  metrics_port = <same as above>
-  metrics_scrape_interval
+  metrics = {
+    port: <same as above> # which port will we scrape server metrics from
+    scrape_interval: 5s # how often do we scrape
+  }
 }

 hpa_config = {
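
The module's variable schema (see the variables.tf diffs in this commit) nests the server settings one level deeper than this README snippet, so against that schema the same settings would be written roughly as follows; the port value is a placeholder and the interval is in seconds:

maxengine_deployment_settings = {
  // ...model, parameter and accelerator settings as elsewhere in this tutorial...
  metrics = {
    server = {
      port            = 9100 // same port the Jetstream container emits metrics on
      scrape_interval = 5    // seconds; the smallest interval the module validation accepts
    }
  }
}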

tutorials-and-examples/inference-servers/jetstream/maxtext/single-host-inference/terraform/sample-terraform.tfvars

+7-2
@@ -1,6 +1,11 @@
 maxengine_deployment_settings = {
-  metrics_port            = 9100
-  metrics_scrape_interval = 10
+  metrics = {
+    server = {
+      port = 9100
+      scrape_interval : 10
+    }
+  }
+
   accelerator_selectors = {
     topology    = "2x4"
     accelerator = "tpu-v5-lite-podslice"
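
The sample enables only Jetstream server metrics. To also scrape TPU system metrics through the new jetstream-podmonitoring-tpu resource, a system block could be added alongside server; a sketch, with 15 chosen as the smallest interval the module's validation accepts:

  metrics = {
    server = {
      port            = 9100
      scrape_interval = 10
    }
    system = {
      scrape_interval = 15 // seconds
    }
  }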

tutorials-and-examples/inference-servers/jetstream/maxtext/single-host-inference/terraform/variables.tf

+12-4
@@ -54,10 +54,18 @@ variable "maxengine_deployment_settings" {
     maxengine_server_image      = optional(string)
     jetstream_http_server_image = optional(string)

-    model_name              = string // Name of your LLM (for example: "gemma-7b")
-    parameters_path         = string // Path to the parameters for your model
-    metrics_port            = optional(number) // Emit Jetstream metrics on this port of each container
-    metrics_scrape_interval = optional(number) // Interval for scraping metrics (default: 10s)
+    model_name      = string // Name of your LLM (for example: "gemma-7b")
+    parameters_path = string // Path to the parameters for your model
+
+    metrics = optional(object({   // Settings for metrics server
+      server = optional(object({  // Settings for Jetstream server metrics
+        port            = number
+        scrape_interval = number
+      }))
+      system = optional(object({  // Settings for TPU metrics
+        scrape_interval = number
+      }))
+    }))

     accelerator_selectors = object({
       topology = string
