Configure autoscaling for LLM workloads on TPUs


This page shows how to set up autoscaling infrastructure by using the GKE Horizontal Pod Autoscaler (HPA) for the Gemma large language model (LLM) served with single-host JetStream.

To learn more about selecting metrics for autoscaling, see Best practices for autoscaling LLM workloads with TPUs on GKE.

Before you begin

Before you start, make sure you have performed the following tasks:

  • Enable the Google Kubernetes Engine API.
  • If you want to use the Google Cloud CLI for this task, install and then initialize the gcloud CLI. If you previously installed the gcloud CLI, get the latest version by running gcloud components update.

Autoscale using metrics

To direct autoscaling for your Pods, you can use either workload-specific performance metrics emitted by the JetStream inference server or TPU performance metrics.

To set up autoscaling with metrics, follow these steps:

  1. Export the metrics from the JetStream server to Cloud Monitoring by using Google Cloud Managed Service for Prometheus, which simplifies deploying and configuring your Prometheus collector. Google Cloud Managed Service for Prometheus is enabled by default in your GKE cluster; if it isn't, you can enable it manually.

    The following example manifests show how to set up PodMonitoring resources that direct Google Cloud Managed Service for Prometheus to scrape metrics from your Pods every 15 seconds.

    To scrape server metrics, use the following manifest, replacing PROMETHEUS_PORT with the port on which the JetStream server exposes its Prometheus metrics. Server metrics support scrape intervals as short as 5 seconds.

    apiVersion: monitoring.googleapis.com/v1
    kind: PodMonitoring
    metadata:
      name: jetstream-podmonitoring
    spec:
      selector:
        matchLabels:
          app: maxengine-server
      endpoints:
      - interval: 15s
        path: "/"
        port: PROMETHEUS_PORT
      targetLabels:
        metadata:
        - pod
        - container
        - node
    

    To scrape TPU metrics, use the following manifest. TPU (system) metrics support scrape intervals as short as 15 seconds.

    apiVersion: monitoring.googleapis.com/v1
    kind: PodMonitoring
    metadata:
      name: tpu-metrics-exporter
      namespace: kube-system
      labels:
        k8s-app: tpu-device-plugin
    spec:
      endpoints:
        - port: 2112
          interval: 15s
      selector:
        matchLabels:
          k8s-app: tpu-device-plugin
    
  2. Install a metrics adapter. The adapter makes the metrics that you exported to Cloud Monitoring visible to the HPA controller. For more details, see Horizontal pod autoscaling in the Google Cloud Managed Service for Prometheus documentation.

    Custom Metrics Stackdriver Adapter

    The Custom Metrics Stackdriver Adapter supports querying metrics from Google Cloud Managed Service for Prometheus, starting with version v0.13.1 of the adapter.

    To install the Custom Metrics Stackdriver Adapter, do the following:

    1. Set up managed collection in your cluster.

    2. Install the Custom Metrics Stackdriver Adapter in your cluster.

      kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/k8s-stackdriver/master/custom-metrics-stackdriver-adapter/deploy/production/adapter_new_resource_model.yaml
      
    3. If Workload Identity Federation for GKE is enabled on your cluster, you must also grant the Monitoring Viewer role to the Kubernetes service account that the adapter runs under. Replace PROJECT_ID with your project ID.

    export PROJECT_NUMBER=$(gcloud projects describe PROJECT_ID --format 'get(projectNumber)')
    gcloud projects add-iam-policy-binding projects/PROJECT_ID \
      --role roles/monitoring.viewer \
      --member=principal://iam.googleapis.com/projects/$PROJECT_NUMBER/locations/global/workloadIdentityPools/PROJECT_ID.svc.id.goog/subject/ns/custom-metrics/sa/custom-metrics-stackdriver-adapter
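    The principal in the --member flag follows a fixed pattern built from the project number, the project ID, the Kubernetes namespace, and the Kubernetes service account name. The following minimal shell sketch assembles that string, which can help if you grant roles to other Kubernetes service accounts in the same way. The function name wif_principal and the sample values are hypothetical; the pattern itself is copied from the command above.

```shell
# Hypothetical helper: build a Workload Identity Federation for GKE principal
# for a Kubernetes service account, using the same pattern as the --member flag above.
# Usage: wif_principal PROJECT_ID PROJECT_NUMBER NAMESPACE KSA_NAME
wif_principal() {
  local project_id="$1" project_number="$2" namespace="$3" ksa="$4"
  printf 'principal://iam.googleapis.com/projects/%s/locations/global/workloadIdentityPools/%s.svc.id.goog/subject/ns/%s/sa/%s\n' \
    "$project_number" "$project_id" "$namespace" "$ksa"
}

# Sample (hypothetical) project values for the adapter's service account.
wif_principal my-project 123456789 custom-metrics custom-metrics-stackdriver-adapter
```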
    

    Prometheus Adapter

    Be aware of these considerations when using prometheus-adapter to scale using Google Cloud Managed Service for Prometheus:

    • Route queries through the Prometheus frontend UI proxy, as you do when you query Google Cloud Managed Service for Prometheus by using the Prometheus API or UI. You deploy this frontend in a later step.
    • By default, the prometheus-url argument of the prometheus-adapter Deployment is set to --prometheus-url=http://frontend.default.svc:9090/, where default is the namespace where you deployed the frontend. If you deployed the frontend in another namespace, or exposed it through a Service with a different name (the manifest in this guide creates a headless Service named prometheus), configure this argument accordingly.
    • In the .seriesQuery field of the rules config, you can't use a regular expression (regex) matcher on a metric name. Instead, fully specify metric names.

    Because data can take slightly longer to become available in Google Cloud Managed Service for Prometheus than in upstream Prometheus, overly eager autoscaling logic can cause unwanted behavior. Although data freshness isn't guaranteed, data is typically available to query 3-7 seconds after it is sent to Google Cloud Managed Service for Prometheus, excluding any network latency.

    All queries issued by prometheus-adapter are global in scope. This means that if you have applications in two namespaces that emit identically named metrics, an HPA configuration using that metric scales using data from both applications. To avoid scaling using incorrect data, always use namespace or cluster filters in your PromQL.

    To set up an example HPA configuration using prometheus-adapter and managed collection, follow these steps:

    1. Set up managed collection in your cluster.
    2. Deploy the Prometheus frontend UI proxy in your cluster. Create the following manifest named prometheus-frontend.yaml:

        apiVersion: apps/v1
        kind: Deployment
        metadata:
          name: frontend
        spec:
          replicas: 2
          selector:
            matchLabels:
              app: frontend
          template:
            metadata:
              labels:
                app: frontend
            spec:
              automountServiceAccountToken: true
              affinity:
                nodeAffinity:
                  requiredDuringSchedulingIgnoredDuringExecution:
                    nodeSelectorTerms:
                    - matchExpressions:
                      - key: kubernetes.io/arch
                        operator: In
                        values:
                        - arm64
                        - amd64
                      - key: kubernetes.io/os
                        operator: In
                        values:
                        - linux
              containers:
              - name: frontend
                image: gke.gcr.io/prometheus-engine/frontend:v0.8.0-gke.4
                args:
                - "--web.listen-address=:9090"
                - "--query.project-id=PROJECT_ID"
                ports:
                - name: web
                  containerPort: 9090
                readinessProbe:
                  httpGet:
                    path: /-/ready
                    port: web
                securityContext:
                  allowPrivilegeEscalation: false
                  capabilities:
                    drop:
                    - all
                  privileged: false
                  runAsGroup: 1000
                  runAsNonRoot: true
                  runAsUser: 1000
                livenessProbe:
                  httpGet:
                    path: /-/healthy
                    port: web
        ---
        apiVersion: v1
        kind: Service
        metadata:
          name: prometheus
        spec:
          clusterIP: None
          selector:
            app: frontend
          ports:
          - name: web
            port: 9090
      

      Then, apply the manifest:

      kubectl apply -f prometheus-frontend.yaml
      
    3. Install prometheus-adapter in your cluster by using the prometheus-community/prometheus-adapter Helm chart. Create the following values.yaml file, replacing CLUSTER_NAME with the name of your cluster:

      rules:
        default: false
        external:
        - seriesQuery: 'jetstream_prefill_backlog_size'
          resources:
            template: <<.Resource>>
          name:
            matches: ""
            as: "jetstream_prefill_backlog_size"
          metricsQuery: avg(<<.Series>>{<<.LabelMatchers>>,cluster="CLUSTER_NAME"})
        - seriesQuery: 'jetstream_slots_used_percentage'
          resources:
            template: <<.Resource>>
          name:
            matches: ""
            as: "jetstream_slots_used_percentage"
          metricsQuery: avg(<<.Series>>{<<.LabelMatchers>>,cluster="CLUSTER_NAME"})
        - seriesQuery: 'memory_used'
          resources:
            template: <<.Resource>>
          name:
            matches: ""
            as: "memory_used_percentage"
          metricsQuery: avg(memory_used{cluster="CLUSTER_NAME",exported_namespace="default",container="jetstream-http"}) / avg(memory_total{cluster="CLUSTER_NAME",exported_namespace="default",container="jetstream-http"})
      

      Then, deploy the Helm chart with this values file:

      helm repo add prometheus-community https://prometheus-community.github.io/helm-charts && helm repo update && helm install example-release prometheus-community/prometheus-adapter -f values.yaml
      

    If you use Workload Identity Federation for GKE, you also need to configure and authorize a service account by running the following commands:

    1. First, create your in-cluster and Google Cloud service accounts:

      gcloud iam service-accounts create jetstream-iam-sa && kubectl create sa prom-frontend-sa
      
    2. Then, bind the two service accounts. Replace PROJECT_ID with your project ID:

      gcloud iam service-accounts add-iam-policy-binding \
        --role roles/iam.workloadIdentityUser \
        --member "serviceAccount:PROJECT_ID.svc.id.goog[default/prom-frontend-sa]" \
        jetstream-iam-sa@PROJECT_ID.iam.gserviceaccount.com \
      &&
      kubectl annotate serviceaccount \
        --namespace default \
        prom-frontend-sa \
        iam.gke.io/gcp-service-account=jetstream-iam-sa@PROJECT_ID.iam.gserviceaccount.com
      
    3. Next, give the Google Cloud service account the monitoring.viewer role:

      gcloud projects add-iam-policy-binding PROJECT_ID \
        --member=serviceAccount:jetstream-iam-sa@PROJECT_ID.iam.gserviceaccount.com \
        --role=roles/monitoring.viewer
      
    4. Finally, set the service account of your frontend Deployment to the new in-cluster service account:

      kubectl set serviceaccount deployment frontend prom-frontend-sa
      
  3. Set up the metric-based HPA resource. Deploy an HPA resource that is based on your preferred metric. For more details, see Horizontal pod autoscaling in the Google Cloud Managed Service for Prometheus documentation. The specific HPA configuration depends on the type of metric (server or TPU) and on which metrics adapter is installed.

    All HPA configurations require the following values:

    • MIN_REPLICAS: The minimum number of JetStream Pod replicas allowed. If you don't modify the JetStream Deployment manifest from the Deploy JetStream step, we recommend setting this to 1.
    • MAX_REPLICAS: The maximum number of JetStream Pod replicas allowed. The example JetStream Deployment requires 8 chips per replica, and the node pool contains 16 chips. To keep scale-up latency low, set this to 2 (16 chips / 8 chips per replica). Larger values can trigger the cluster autoscaler to create new nodes in the node pool, which increases scale-up latency.
    • TARGET: The target average value of this metric across all JetStream instances. For more information about how the replica count is determined from this value, see the Kubernetes documentation on horizontal Pod autoscaling.
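    The interaction between MIN_REPLICAS, MAX_REPLICAS, and TARGET can be sketched in shell. The following is a simplified model of the rule described in the Kubernetes documentation (desired replicas = ceil(current replicas * current average / target), skipped when the ratio is within the upstream default tolerance of 10%), clamped to the configured bounds. It ignores unready Pods, missing metrics, and stabilization windows, and the function names and sample values are hypothetical. It also shows the arithmetic behind the recommended MAX_REPLICAS of 2.

```shell
# Replicas that fit in the existing node pool without adding nodes.
# Usage: max_replicas_without_new_nodes CHIPS_IN_NODE_POOL CHIPS_PER_REPLICA
max_replicas_without_new_nodes() {
  echo $(( $1 / $2 ))
}

# Simplified HPA rule for one metric with an AverageValue target.
# Usage: desired_replicas CURRENT_REPLICAS CURRENT_AVERAGE TARGET MIN_REPLICAS MAX_REPLICAS
# TARGET must be greater than 0.
desired_replicas() {
  awk -v cur="$1" -v avg="$2" -v target="$3" -v min="$4" -v max="$5" 'BEGIN {
    ratio = avg / target
    diff = 1 - ratio
    if (diff < 0) diff = -diff
    if (diff <= 0.1) {
      desired = cur                          # within tolerance: keep current count
    } else {
      raw = cur * ratio
      desired = int(raw)
      if (desired < raw) desired = desired + 1   # round up (ceiling)
    }
    if (desired < min) desired = min         # clamp to MIN_REPLICAS
    if (desired > max) desired = max         # clamp to MAX_REPLICAS
    print desired
  }'
}

max_replicas_without_new_nodes 16 8   # prints 2
desired_replicas 1 20 10 1 2          # average 20 against TARGET 10: prints 2
```

    In this model, a load that would call for more replicas than MAX_REPLICAS is capped, which is how a MAX_REPLICAS of 2 keeps the Deployment inside the existing 16-chip node pool.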

    Custom Metrics Stackdriver Adapter

    The Custom Metrics Stackdriver Adapter supports scaling your workload with the average value, across all Pods, of an individual metric query from Google Cloud Managed Service for Prometheus. When you use the Custom Metrics Stackdriver Adapter, we recommend scaling with the jetstream_prefill_backlog_size and jetstream_slots_used_percentage server metrics and the memory_used TPU metric.

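    To create an HPA manifest for scaling with server metrics, create the following hpa.yaml file, where jetstream_METRIC is either jetstream_prefill_backlog_size or jetstream_slots_used_percentage: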

    apiVersion: autoscaling/v2
    kind: HorizontalPodAutoscaler
    metadata:
      name: jetstream-hpa
      namespace: default
    spec:
      scaleTargetRef:
        apiVersion: apps/v1
        kind: Deployment
        name: maxengine-server
      minReplicas: MIN_REPLICAS
      maxReplicas: MAX_REPLICAS
      metrics:
      - type: Pods
        pods:
          metric:
            name: prometheus.googleapis.com|jetstream_METRIC|gauge
          target:
            type: AverageValue
            averageValue: TARGET
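    The metric name in this manifest is the Cloud Monitoring metric type that Google Cloud Managed Service for Prometheus creates for a gauge (prometheus.googleapis.com/METRIC/gauge), with each slash replaced by | as the Custom Metrics Stackdriver Adapter expects. The following sketch prints the names for both recommended server metrics. The helper name is hypothetical, and it assumes the metrics are ingested as gauges, as in the manifest above.

```shell
# Hypothetical helper: print the metric name that the Custom Metrics Stackdriver
# Adapter uses for a metric ingested by Managed Service for Prometheus.
# Assumes the Cloud Monitoring type is prometheus.googleapis.com/NAME/KIND and
# that the adapter expects "|" in place of each "/".
# Usage: adapter_metric_name PROMETHEUS_METRIC_NAME [KIND]   (KIND defaults to gauge)
adapter_metric_name() {
  local prom_name="$1" kind="${2:-gauge}"
  printf 'prometheus.googleapis.com|%s|%s\n' "$prom_name" "$kind"
}

for metric in jetstream_prefill_backlog_size jetstream_slots_used_percentage; do
  adapter_metric_name "$metric"
done
```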
    

    When you use the Custom Metrics Stackdriver Adapter with TPU metrics, we recommend using only the memory_used metric for scaling. To create an HPA manifest for scaling with this metric, create the following hpa.yaml file:

    apiVersion: autoscaling/v2
    kind: HorizontalPodAutoscaler
    metadata:
      name: jetstream-hpa
      namespace: default
    spec:
      scaleTargetRef:
        apiVersion: apps/v1
        kind: Deployment
        name: maxengine-server
      minReplicas: MIN_REPLICAS
      maxReplicas: MAX_REPLICAS
      metrics:
      - type: External
        external:
          metric:
            name: prometheus.googleapis.com|memory_used|gauge
            selector:
              matchLabels:
                metric.labels.container: jetstream-http
                metric.labels.exported_namespace: default
          target:
            type: AverageValue
            averageValue: TARGET
    

    Prometheus Adapter

    The Prometheus Adapter supports scaling your workload with the value of PromQL queries from Google Cloud Managed Service for Prometheus. In the Helm values file, you defined the jetstream_prefill_backlog_size and jetstream_slots_used_percentage server metrics, which represent the average value across all Pods.

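    To create an HPA manifest for scaling with server metrics, create the following hpa.yaml file, where jetstream_METRIC is either jetstream_prefill_backlog_size or jetstream_slots_used_percentage: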

    apiVersion: autoscaling/v2
    kind: HorizontalPodAutoscaler
    metadata:
      name: jetstream-hpa
      namespace: default
    spec:
      scaleTargetRef:
        apiVersion: apps/v1
        kind: Deployment
        name: maxengine-server
      minReplicas: MIN_REPLICAS
      maxReplicas: MAX_REPLICAS
      metrics:
      - type: External
        external:
          metric:
            name: jetstream_METRIC
          target:
            type: AverageValue
            averageValue: TARGET
    

    To scale with TPU metrics, we recommend using only the memory_used_percentage metric defined in the prometheus-adapter Helm values file. memory_used_percentage is the name given to the following PromQL query, which reflects the current average memory used across all accelerators:

    avg(memory_used{cluster="CLUSTER_NAME",exported_namespace="default",container="jetstream-http"}) / avg(memory_total{cluster="CLUSTER_NAME",exported_namespace="default",container="jetstream-http"})
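    Because memory_used_percentage divides one average by another, its value is a fraction between 0 and 1 rather than a number between 0 and 100, which matters when you choose TARGET. The following sketch reproduces that arithmetic on hypothetical sample values, one "used total" pair per accelerator in any consistent unit. It only illustrates the calculation and does not query Cloud Monitoring; the function name is hypothetical.

```shell
# Hypothetical helper: compute avg(used) / avg(total) from "used total" lines on stdin.
# Exits with status 1 if there is no input or the total is zero.
memory_used_fraction() {
  awk '{ used += $1; total += $2; n++ }
       END {
         if (n == 0 || total == 0) exit 1
         printf "%.3f\n", (used / n) / (total / n)
       }'
}

# Two accelerators using 8 and 12 units out of 16 each: (10 / 16) prints 0.625
printf '%s\n' "8 16" "12 16" | memory_used_fraction
```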
    

    To create an HPA manifest for scaling with memory_used_percentage, create the following hpa.yaml file:

    apiVersion: autoscaling/v2
    kind: HorizontalPodAutoscaler
    metadata:
      name: jetstream-hpa
      namespace: default
    spec:
      scaleTargetRef:
        apiVersion: apps/v1
        kind: Deployment
        name: maxengine-server
      minReplicas: MIN_REPLICAS
      maxReplicas: MAX_REPLICAS
      metrics:
      - type: External
        external:
          metric:
            name: memory_used_percentage
          target:
            type: AverageValue
            averageValue: TARGET
    

Scale using multiple metrics

You can also configure scaling based on multiple metrics. To learn how the HPA determines the replica count when you use multiple metrics, see the Kubernetes documentation on autoscaling. To build this type of HPA manifest, collect all entries from the spec.metrics field of each HPA resource into a single HPA resource. The following snippet shows an example of how you can bundle the HPA resources:

apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: jetstream-hpa-multiple-metrics
  namespace: default
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: maxengine-server
  minReplicas: MIN_REPLICAS
  maxReplicas: MAX_REPLICAS
  metrics:
  - type: External
    external:
      metric:
        name: jetstream_METRIC
      target:
        type: AverageValue
        averageValue: JETSTREAM_METRIC_TARGET
  - type: External
    external:
      metric:
        name: memory_used_percentage
      target:
        type: AverageValue
        averageValue: EXTERNAL_METRIC_TARGET
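When an HPA has several entries in spec.metrics, the Kubernetes documentation describes computing a proposed replica count for each metric and scaling to the largest proposal. The following shell sketch shows only that final step, still bounded by MIN_REPLICAS and MAX_REPLICAS. The function name and proposal values are hypothetical, and each proposal is assumed to have been computed per metric beforehand.

```shell
# Pick the largest per-metric proposal, then keep it within [MIN_REPLICAS, MAX_REPLICAS].
# Usage: combine_proposals MIN_REPLICAS MAX_REPLICAS PROPOSAL [PROPOSAL...]
combine_proposals() {
  local min="$1" max="$2"
  shift 2
  local best="$min"   # starting at MIN_REPLICAS also enforces the lower bound
  local p
  for p in "$@"; do
    if (( p > best )); then
      best="$p"
    fi
  done
  if (( best > max )); then
    best="$max"
  fi
  echo "$best"
}

# Hypothetical proposals: 2 from jetstream_METRIC, 3 from memory_used_percentage.
combine_proposals 1 2 2 3   # prints 2: the larger proposal (3) is capped at MAX_REPLICAS
```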

Monitor and test autoscaling

You can observe how your JetStream workloads scale based on your HPA configuration.

To observe the replica count in real-time, run the following command:

kubectl get hpa --watch

The output from this command should be similar to the following:

NAME            REFERENCE                     TARGETS      MINPODS   MAXPODS   REPLICAS   AGE
jetstream-hpa   Deployment/maxengine-server   0/10 (avg)   1         2         1          1m

To test how your HPA scales, run the following command, which sends a burst of 100 concurrent requests to the model endpoint. The burst exhausts the available decode slots and causes a backlog of requests in the prefill queue, which triggers the HPA to increase the size of the model deployment.

seq 100 | xargs -P 100 -I {} curl --request POST --header "Content-type: application/json" -s localhost:8000/generate --data '{ "prompt": "Can you provide a comprehensive and detailed overview of the history and development of artificial intelligence.", "max_tokens": 200 }'

What's next