
[Roadmap] Prefill and Decoding Disaggregation #4655

Open · 11 of 13 tasks
ByronHsu opened this issue Mar 21, 2025 · 27 comments


Comments

@ByronHsu
Collaborator

ByronHsu commented Mar 21, 2025

Design: SGLang PD Disaggregation (Open Source)

Progress

@zhyncs zhyncs pinned this issue Mar 22, 2025
@zhyncs zhyncs changed the title PD Tracker [Track] Prefill and Decoding Disaggregation Mar 22, 2025
@zhyncs zhyncs changed the title [Track] Prefill and Decoding Disaggregation [Roadmap] Prefill and Decoding Disaggregation Mar 22, 2025
@stmatengss
Collaborator

Good job, @ByronHsu! The Mooncake team will integrate the Mooncake Transfer Engine into PD disaggregation ASAP. The related PR will be available soon. Thanks.

@zhyncs
Member

zhyncs commented Mar 22, 2025

> Good job, @ByronHsu! The Mooncake team will integrate the Mooncake Transfer Engine into PD disaggregation ASAP. The related PR will be available soon. Thanks.

Thanks @stmatengss! Please let me know when it's ready as it is our top priority to complete it. We will ensure the review and merge process goes smoothly!

@ShangmingCai
Collaborator

> Good job, @ByronHsu! The Mooncake team will integrate the Mooncake Transfer Engine into PD disaggregation ASAP. The related PR will be available soon. Thanks.

+1. Will be on it ASAP. Cheers for the collaboration.

@HaiShaw
Collaborator

HaiShaw commented Mar 23, 2025

@ByronHsu @ShangmingCai AMD's support for this and for Mooncake will be fully available soon. Thanks.

@trevor-m
Collaborator

Hi @ByronHsu, I'll be working on the NVIDIA NIXL integration.

@thesues
Contributor

thesues commented Mar 26, 2025

Hi @ByronHsu, I have a question about your design.

Is the pre-allocated memory GPU memory or CPU memory? If it is GPU memory, it could use RDMA GPUDirect copies, but isn't the drawback that the decode worker may allocate too much GPU memory before computation starts?

@Luis-xu

Luis-xu commented Mar 27, 2025

> Hi @ByronHsu, I have a question about your design.
>
> Is the pre-allocated memory GPU memory or CPU memory? If it is GPU memory, it could use RDMA GPUDirect copies, but isn't the drawback that the decode worker may allocate too much GPU memory before computation starts?

@thesues I have the same question. I recently looked into Dynamo and found that its PD disaggregation implementation also retains a P2P transmission path. I haven't figured out the relationship between that path and the multi-level cache; perhaps different transmission paths are needed for different levels of KV cache, which in turn calls for a dedicated design in the upper-level scheduling.

@trevor-m
Collaborator

@ByronHsu I've started a WIP branch here for NIXL: trevor-m@6d862c5

@KivenChen
Contributor

I am currently experimenting heavily with the Dynamo integration. Does anyone share the same interest?

@ByronHsu
Collaborator Author

ByronHsu commented Mar 27, 2025

> Is the pre-allocated memory GPU memory or CPU memory? If it is GPU memory, it could use RDMA GPUDirect copies, but isn't the drawback that the decode worker may allocate too much GPU memory before computation starts?

Good point, @thesues! That could be the case if the decode node's memory is constrained. However, the existing design works well for us under reasonable QPS. I recently read this paper, https://arxiv.org/html/2501.14743v1, which suggests a pull-based model that might be worth a try, but it would require some modification to the current KV transfer interface.
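
For illustration, here is a minimal sketch of the difference between the current push-style hand-off and the pull-based model suggested in that paper. The engine object, KVHandle, and all method names are hypothetical; they are not SGLang's actual KV transfer API.

```python
# Hypothetical sketch only: "engine", KVHandle, and the method names below are
# illustrative and are not SGLang's actual KV-transfer interface.

from dataclasses import dataclass

@dataclass
class KVHandle:
    req_id: str
    ptr: int      # device pointer of the KV block on the prefill node
    length: int   # size in bytes

class PushTransfer:
    """Push model: decode pre-allocates, prefill writes KV as soon as it is ready."""
    def __init__(self, engine):
        self.engine = engine

    def on_prefill_done(self, handle: KVHandle, decode_addr: int):
        # Decode memory is reserved before decode computation starts, so on a
        # memory-constrained decode node these reservations can pile up.
        self.engine.write(src=handle.ptr, dst=decode_addr, length=handle.length)

class PullTransfer:
    """Pull model: decode fetches KV only when it schedules the request."""
    def __init__(self, engine):
        self.engine = engine

    def on_decode_scheduled(self, handle: KVHandle, local_addr: int):
        # Allocation happens lazily on the decode side, trading a little extra
        # latency at schedule time for less idle reserved GPU memory.
        self.engine.read(src=handle.ptr, dst=local_addr, length=handle.length)
```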

@libratiger
Contributor

@ByronHsu Whether we use the pull or the push model, this functionality could be hidden inside the transfer engine. The inference framework would simply place tokens into the transfer engine and consume them from it, essentially treating the transfer engine as a queue.

This could keep SGLang simpler, and we could offer an abstraction layer to hide the different transfer engines, as sketched below.
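
A minimal sketch of that abstraction, with purely illustrative names (this is not SGLang's or Mooncake's real interface):

```python
# Hypothetical sketch of the queue-like abstraction described above. The class and
# method names are illustrative; they do not correspond to SGLang's or Mooncake's
# actual interfaces.

from abc import ABC, abstractmethod

class KVTransferQueue(ABC):
    """The framework only puts KV blocks in and takes them out; whether the backend
    pushes or pulls, and whether it uses RDMA or TCP, stays hidden behind this API."""

    @abstractmethod
    def put(self, req_id: str, kv_block: bytes) -> None:
        """Called by the prefill worker once a request's KV cache is ready."""

    @abstractmethod
    def get(self, req_id: str) -> bytes:
        """Called by the decode worker when it schedules the request."""

class InProcessQueue(KVTransferQueue):
    """Trivial reference backend; a Mooncake- or NIXL-backed class would replace it."""
    def __init__(self):
        self._store: dict[str, bytes] = {}

    def put(self, req_id, kv_block):
        self._store[req_id] = kv_block

    def get(self, req_id):
        return self._store.pop(req_id)
```

A Mooncake- or NIXL-backed implementation would subclass the same interface, so the scheduler code would not need to change per backend.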

@libratiger
Contributor

The design doc shows very careful consideration; maybe we can stand on a higher abstraction layer so we can move forward faster?

For example, the scatter-gather elements (SGEs) in RDMA are useful, but scatter-gather can also be done over a common network transport; it is used widely in the Linux kernel.
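
As a concrete illustration, Python exposes the kernel's gather path directly: socket.sendmsg() takes a list of buffers (an iovec under the hood), which is the plain-TCP analogue of posting multiple SGEs in one RDMA work request. A small self-contained sketch:

```python
# Scatter-gather over a plain TCP/Unix socket: socket.sendmsg() hands the kernel a
# list of buffers, so the header and the KV payload chunks go out in one syscall
# without an intermediate copy.

import socket

def send_kv_chunks(sock: socket.socket, header: bytes, chunks: list[bytes]) -> int:
    # [header, chunk0, chunk1, ...] is gathered by the kernel, similar in spirit
    # to posting multiple SGEs in a single RDMA work request.
    return sock.sendmsg([header] + chunks)

if __name__ == "__main__":
    a, b = socket.socketpair()
    sent = send_kv_chunks(a, b"hdr:", [b"kv-chunk-0", b"kv-chunk-1"])
    print(sent, b.recv(1024))   # 24 b'hdr:kv-chunk-0kv-chunk-1'
```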

@Venkat2811

@ByronHsu

> Rust PD Load Balancer - call out for contribution

I'm interested. Will reach out in the SGL Slack.

@EstherBear

> @ByronHsu
>
> Rust PD Load Balancer - call out for contribution
>
> I'm interested. Will reach out in the SGL Slack.

Hi @Venkat2811, I'm also working on a PR for the Rust PD load balancer. Maybe we can work on it together?

@anapple-hub

Hi @ByronHsu, P/D disaggregation currently seems to lack support for dp-attention. For example, when DP is enabled, prepare_dp_attn_batch() is not invoked, and parameters like global_num_tokens_gpu are missing in get_dp_local_info(). Additionally, in event_loop_normal_disagg_decode, some ranks with batch_forward_mode set to extend will perform stream_output, while others in idle or decode modes will execute run_batch.
I noticed that enabling dp-attention is not listed in the current roadmap. Are there plans to support it in the future?

@lambda7xx

Can I join the PD task? @ByronHsu

@XucSh
Contributor

XucSh commented Apr 14, 2025

> PD + fault tolerance

Hi @ByronHsu, I'm interested in this part. Could I take it? Cc @stmatengss @ShangmingCai

@trevor-m
Collaborator

New NIXL transfer engine PR: #5477

@ZhongYingMatrix

Hi @ByronHsu, I notice that earlier designs described the KV transfer as layer-by-layer and chunk-by-chunk. Was there a particular consideration behind removing this part of the design?
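
For reference, here is a rough, purely hypothetical sketch of what such a layer-by-layer, chunk-by-chunk transfer loop looks like; the layers and kv_sender objects and their methods are illustrative names, not SGLang code.

```python
# Hypothetical sketch: the KV produced by layer i is streamed out while layer i+1
# is still computing, so the transfer overlaps with prefill instead of waiting for
# the whole forward pass to finish.

def prefill_with_layerwise_transfer(layers, hidden, kv_sender, chunk_tokens=512):
    for layer_id, layer in enumerate(layers):
        hidden, kv = layer(hidden)                  # compute this layer; kv holds [num_tokens, ...]
        for start in range(0, len(kv), chunk_tokens):
            chunk = kv[start:start + chunk_tokens]  # one chunk of this layer's KV
            kv_sender.send_async(layer_id=layer_id, offset=start, chunk=chunk)
    kv_sender.flush()                               # wait for all outstanding sends to complete
    return hidden
```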

@CSEEduanyu

I encountered this error while running in a container environment:
RuntimeError: Mooncake memory registration failed.
E0511 07:42:36.170303 8945 rdma_context.cpp:198] Failed to register memory 0x2e1cbac200: Bad address [14]

@ShangmingCai
Collaborator

> I encountered this error while running in a container environment: RuntimeError: Mooncake memory registration failed. E0511 07:42:36.170303 8945 rdma_context.cpp:198] Failed to register memory 0x2e1cbac200: Bad address [14]

@CSEEduanyu Does your environment support GDR (GPUDirect RDMA)? If not, your environment cannot run PD with SGLang, and Mooncake will report failures when registering your GPU memory.
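
A quick way to sanity-check this is to look for the peer-memory kernel module; the sketch below assumes GPUDirect RDMA is exposed via nvidia_peermem (or the older nv_peer_mem). Note that the gdrcopy tools exercise GDRCopy (gdrdrv), which is a different path from the peer-memory registration that the RDMA NIC needs.

```python
# Rough check only: inspects loaded kernel modules for the GPUDirect RDMA
# peer-memory driver. It is not a substitute for an actual RDMA transfer test.

def peer_memory_module_loaded(path: str = "/proc/modules") -> bool:
    with open(path) as f:
        loaded = {line.split()[0] for line in f}
    return bool(loaded & {"nvidia_peermem", "nv_peer_mem"})

if __name__ == "__main__":
    print("GPUDirect RDMA peer-memory module loaded:", peer_memory_module_loaded())
```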

@CSEEduanyu

> I encountered this error while running in a container environment: RuntimeError: Mooncake memory registration failed. E0511 07:42:36.170303 8945 rdma_context.cpp:198] Failed to register memory 0x2e1cbac200: Bad address [14]

> @CSEEduanyu Does your environment support GDR (GPUDirect RDMA)? If not, your environment cannot run PD with SGLang, and Mooncake will report failures when registering your GPU memory.

I can run gdrcopy_copybw:
GPU id:0; name: NVIDIA H800; Bus id: 0000:63:00
GPU id:1; name: NVIDIA H800; Bus id: 0000:67:00
GPU id:2; name: NVIDIA H800; Bus id: 0000:6b:00
GPU id:3; name: NVIDIA H800; Bus id: 0000:6f:00
GPU id:4; name: NVIDIA H800; Bus id: 0000:a3:00
GPU id:5; name: NVIDIA H800; Bus id: 0000:a7:00
GPU id:6; name: NVIDIA H800; Bus id: 0000:ab:00
GPU id:7; name: NVIDIA H800; Bus id: 0000:af:00
selecting device 0
testing size: 131072
rounded size: 131072
gpu alloc fn: cuMemAlloc
device ptr: 7fbfb7e00000
map_d_ptr: 0x7fc1e81e7000
info.va: 7fbfb7e00000
info.mapped_size: 131072
info.page_size: 65536
info.mapped: 1
info.wc_mapping: 1
page offset: 0
user-space pointer:0x7fc1e81e7000
writing test, size=131072 offset=0 num_iters=10000
write BW: 17884.9MB/s
reading test, size=131072 offset=0 num_iters=100
read BW: 669.866MB/s
unmapping buffer
unpinning buffer
closing gdrdrv

@ShangmingCai
Collaborator

@CSEEduanyu Please open an issue in the Mooncake repo and provide a detailed log to help us identify the root cause of your problem.

@pc-neo
Contributor

pc-neo commented May 16, 2025

@ch-wan @ByronHsu Hi, I am interested in the feature for making MTP compatible with PD disaggregation. Can I join you in developing this?

@CSEEduanyu

Is there a comprehensive benchmark to verify the improvement from PD disaggregation? @ByronHsu
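
In the meantime, one rough way to compare a unified deployment with a PD-disaggregated one is to measure time-to-first-token and end-to-end latency against each HTTP endpoint. The sketch below assumes the native /generate endpoint with streaming enabled; the two URLs are placeholders and the payload may need adjusting to whatever API you actually serve.

```python
# Rough probe: measure TTFT and total latency for one streamed request against an
# SGLang-style /generate endpoint. Endpoint shape and ports are assumptions.

import time
import requests

def probe(url, prompt, max_new_tokens=128):
    payload = {"text": prompt,
               "sampling_params": {"max_new_tokens": max_new_tokens, "temperature": 0},
               "stream": True}
    start = time.perf_counter()
    ttft = None
    with requests.post(f"{url}/generate", json=payload, stream=True) as resp:
        for line in resp.iter_lines():
            if line and ttft is None:
                ttft = time.perf_counter() - start   # first streamed chunk arrived
    total = time.perf_counter() - start
    return ttft, total

if __name__ == "__main__":
    for name, url in [("unified", "http://localhost:30000"),
                      ("pd-disagg (router)", "http://localhost:8000")]:
        ttft, total = probe(url, "Explain PD disaggregation in one paragraph.")
        print(f"{name}: TTFT={ttft:.3f}s, total={total:.3f}s")
```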

@wqlxx

wqlxx commented May 26, 2025

I ran PD disaggregation on an 8*L20 server, using Docker with the image lmsysorg/sglang:v0.4.6.post5-cu124:

docker run -it --device=/dev/infiniband/ --device=/dev/knem --ulimit memlock=-1 --shm-size 32g --cap-add CAP_SYS_ADMIN --privileged --gpus all --net host  -v /home/wq:/home/wq -v /home/models:/home/models --ipc=host --name sglang_mooncake_wq_0.4.6 lmsysorg/sglang:v0.4.6.post5-cu124

and got the error RuntimeError: Mooncake memory registration failed.

root@pod-hpc-01:/sgl-workspace# python3 -m sglang.launch_server --model-path /home/models/Qwen2.5-0.5B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_2 --disaggregation-transfer-backend mooncake
Cuda graph is disabled for prefill server
[2025-05-26 02:50:37] server_args=ServerArgs(model_path='/home/models/Qwen2.5-0.5B-Instruct', tokenizer_path='/home/models/Qwen2.5-0.5B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='/home/models/Qwen2.5-0.5B-Instruct', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='127.0.0.1', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=16384, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=160454688, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm=None, init_expert_location='trivial', enable_eplb=False, eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=None, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, n_share_experts_fusion=0, disable_chunked_prefix_cache=False, 
disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='prefill', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device='mlx5_2', pdlb_url=None)
[2025-05-26 02:50:41] Attention backend not set. Use flashinfer backend by default.
[2025-05-26 02:50:41] Init torch distributed begin.
[2025-05-26 02:50:41] Init torch distributed ends. mem usage=0.00 GB
[2025-05-26 02:50:41] init_expert_location from trivial
[2025-05-26 02:50:42] Load weight begin. avail mem=44.19 GB
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  7.17it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  7.16it/s]

[2025-05-26 02:50:42] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=43.21 GB, mem usage=0.98 GB.
[2025-05-26 02:50:42] KV Cache is allocated. #tokens: 3312237, K size: 18.95 GB, V size: 18.95 GB
[2025-05-26 02:50:42] Memory pool end. avail mem=4.70 GB
[2025-05-26 02:50:42] max_total_num_tokens=3312237, chunked_prefill_size=16384, max_prefill_tokens=16384, max_running_requests=4097, context_len=32768
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0526 02:50:43.498081  1739 transfer_engine.cpp:350] Metrics reporting is disabled (set MC_TE_METRIC=1 to enable)
I0526 02:50:43.498109  1739 transfer_engine.cpp:44] Transfer Engine starting. Server: 10.110.183.21, Metadata: P2PHANDSHAKE, ip_or_host_name: , rpc_port: 0
I0526 02:50:43.498136  1739 transfer_engine.cpp:100] Transfer Engine RPC using P2P handshake, listening on 10.110.183.21:16076
I0526 02:50:43.498211  1739 transfer_engine.cpp:112] Auto-discovering topology...
I0526 02:50:43.499500  1739 transfer_engine.cpp:127] Topology discovery complete. Found 1 HCAs.
I0526 02:50:43.504783  1739 rdma_context.cpp:416] Find best gid index: 3 on mlx5_2/
I0526 02:50:43.505787  1739 rdma_context.cpp:125] RDMA device: mlx5_2, LID: 0, GID: (GID_Index 3) 00:00:00:00:00:00:00:00:00:00:ff:ff:ac:10:00:15
E0526 02:50:43.916359  1739 rdma_context.cpp:203] Failed to register memory 0x7f11be000000: Bad address [14]
[2025-05-26 02:50:43] Mooncake memory registration failed.
[2025-05-26 02:50:43] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2297, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 463, in __init__
    self.init_disaggregation()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 616, in init_disaggregation
    self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
  File "/sgl-workspace/sglang/python/sglang/srt/disaggregation/prefill.py", line 87, in __init__
    self.kv_manager = self._init_kv_manager()
  File "/sgl-workspace/sglang/python/sglang/srt/disaggregation/prefill.py", line 125, in _init_kv_manager
    kv_manager = kv_manager_class(
  File "/sgl-workspace/sglang/python/sglang/srt/disaggregation/mooncake/conn.py", line 149, in __init__
    self.register_buffer_to_engine()
  File "/sgl-workspace/sglang/python/sglang/srt/disaggregation/mooncake/conn.py", line 176, in register_buffer_to_engine
    self.engine.register(kv_data_ptr, kv_data_len)
  File "/sgl-workspace/sglang/python/sglang/srt/disaggregation/mooncake/transfer_engine.py", line 36, in register
    raise RuntimeError("Mooncake memory registration failed.")
RuntimeError: Mooncake memory registration failed.

[2025-05-26 02:50:43] Received sigquit from a child process. It usually means the child failed.
Killed

My nvidia-smi topo output is:

root@pod-hpc-01:/home/wq/gdrcopy# nvidia-smi topo -m
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      PIX     PIX     PIX     SYS     SYS     SYS     SYS     PIX     PIX     SYS     SYS     SYS     SYS     0-55,112-167    0               N/A
GPU1    PIX      X      PIX     PIX     SYS     SYS     SYS     SYS     PIX     PIX     SYS     SYS     SYS     SYS     0-55,112-167    0               N/A
GPU2    PIX     PIX      X      PIX     SYS     SYS     SYS     SYS     PIX     PIX     SYS     SYS     SYS     SYS     0-55,112-167    0               N/A
GPU3    PIX     PIX     PIX      X      SYS     SYS     SYS     SYS     PIX     PIX     SYS     SYS     SYS     SYS     0-55,112-167    0               N/A
GPU4    SYS     SYS     SYS     SYS      X      PIX     PIX     PIX     SYS     SYS     PIX     PIX     SYS     SYS     56-111,168-223  1               N/A
GPU5    SYS     SYS     SYS     SYS     PIX      X      PIX     PIX     SYS     SYS     PIX     PIX     SYS     SYS     56-111,168-223  1               N/A
GPU6    SYS     SYS     SYS     SYS     PIX     PIX      X      PIX     SYS     SYS     PIX     PIX     SYS     SYS     56-111,168-223  1               N/A
GPU7    SYS     SYS     SYS     SYS     PIX     PIX     PIX      X      SYS     SYS     PIX     PIX     SYS     SYS     56-111,168-223  1               N/A
NIC0    PIX     PIX     PIX     PIX     SYS     SYS     SYS     SYS      X      PIX     SYS     SYS     SYS     SYS
NIC1    PIX     PIX     PIX     PIX     SYS     SYS     SYS     SYS     PIX      X      SYS     SYS     SYS     SYS
NIC2    SYS     SYS     SYS     SYS     PIX     PIX     PIX     PIX     SYS     SYS      X      PIX     SYS     SYS
NIC3    SYS     SYS     SYS     SYS     PIX     PIX     PIX     PIX     SYS     SYS     PIX      X      SYS     SYS
NIC4    SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      PIX
NIC5    SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     PIX      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5

I have tried setting export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False, but nothing happened.

@ShangmingCai
Collaborator

@wqlxx Try removing this config instead of setting it to False. Also, please make sure that the Docker container has sudo permissions and is launched with --privileged.
