
Commit 67db925

Add vllm quickstart (#10978)

* temp
* add doc
* finish
* done
* fix
* add initial docker readme
* temp
* done fixing vllm_quickstart
* done
* remove not used file
* add
* fix

1 parent 56cb992 commit 67db925

4 files changed: +270, -0 lines changed

.gitignore

+4

@@ -50,3 +50,7 @@ __pycache__
target
build
dist
+
+# For readthedocs
+docs/readthedocs/requirements-doc.txt
+docs/readthedocs/_build/*

docs/readthedocs/source/_templates/sidebar_quicklinks.html

+3

@@ -55,6 +55,9 @@
<li>
  <a href="doc/LLM/Quickstart/fastchat_quickstart.html">Run IPEX-LLM Serving with FastChat</a>
</li>
+<li>
+  <a href="doc/LLM/Quickstart/vLLM_quickstart.html">Run IPEX-LLM Serving with vLLM</a>
+</li>
<li>
  <a href="doc/LLM/Quickstart/axolotl_quickstart.html">Finetune LLM with Axolotl on Intel GPU</a>
</li>

docs/readthedocs/source/doc/LLM/Quickstart/index.rst

+1

@@ -24,6 +24,7 @@ This section includes efficient guide to show you how to:
* `Run Ollama with IPEX-LLM on Intel GPU <./ollama_quickstart.html>`_
* `Run Llama 3 on Intel GPU using llama.cpp and ollama with IPEX-LLM <./llama3_llamacpp_ollama_quickstart.html>`_
* `Run IPEX-LLM Serving with FastChat <./fastchat_quickstart.html>`_
+* `Run IPEX-LLM Serving with vLLM on Intel GPU <./vLLM_quickstart.html>`_
* `Finetune LLM with Axolotl on Intel GPU <./axolotl_quickstart.html>`_
* `Run IPEX-LLM serving on Multiple Intel GPUs using DeepSpeed AutoTP and FastApi <./deepspeed_autotp_fastapi_quickstart.html>`_

docs/readthedocs/source/doc/LLM/Quickstart/vLLM_quickstart.md

+262

@@ -0,0 +1,262 @@
# Serving using IPEX-LLM and vLLM on Intel GPU

vLLM is a fast and easy-to-use library for LLM inference and serving. You can find detailed information at their [homepage](https://github.com/vllm-project/vllm).

IPEX-LLM can be integrated into vLLM so that users can use `IPEX-LLM` to boost the performance of the vLLM engine on Intel **GPUs** *(e.g., local PC with discrete GPU such as Arc, Flex and Max)*.

## Quick Start

This quickstart guide walks you through installing and running `vLLM` with `ipex-llm`.

### 1. Install IPEX-LLM for vLLM

IPEX-LLM's support for `vLLM` is currently only available on Linux.

Visit [Install IPEX-LLM on Linux with Intel GPU](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html) and follow the instructions in section [Install Prerequisites](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-prerequisites) to install the prerequisites needed for running code on Intel GPUs.

Then, follow the instructions in section [Install ipex-llm](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-ipex-llm) to install `ipex-llm[xpu]` and set up the recommended runtime configurations.

**After the installation, you should have created a conda environment, named `ipex-vllm` for instance, for running `vLLM` commands with IPEX-LLM.**
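
For reference, that environment-creation step typically looks like the sketch below. The Python version and the extra index URL here follow the linked installation guide and may change, so treat that guide as authoritative:

```bash
# Create and activate the conda environment used throughout this quickstart
conda create -n ipex-vllm python=3.11 -y
conda activate ipex-vllm

# Install ipex-llm with Intel GPU (XPU) support; the index URL below is taken
# from the IPEX-LLM installation guide and may differ for your region/release
pip install --pre --upgrade "ipex-llm[xpu]" \
  --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
```
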
### 2. Install vLLM

Currently, we maintain a specific branch of vLLM, which only works on Intel GPUs.

Activate the `ipex-vllm` conda environment and install vLLM by executing the commands below.

```bash
conda activate ipex-vllm
source /opt/intel/oneapi/setvars.sh
git clone -b sycl_xpu https://github.com/analytics-zoo/vllm.git
cd vllm
pip install -r requirements-xpu.txt
pip install --no-deps xformers
VLLM_BUILD_XPU_OPS=1 pip install --no-build-isolation -v -e .
pip install outlines==0.0.34 --no-deps
pip install interegular cloudpickle diskcache joblib lark nest-asyncio numba scipy
# For Qwen model support
pip install transformers_stream_generator einops tiktoken
```

**Now you are all set to use vLLM with IPEX-LLM.**
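
As an optional sanity check, you can confirm that the freshly built package imports correctly. This sketch assumes the environment from above is still active; it only imports the module and does not load a model, and the printed version string depends on the branch you built:

```bash
conda activate ipex-vllm
source /opt/intel/oneapi/setvars.sh
# Import-only check of the vLLM build; no model is loaded
python -c "import vllm; print(vllm.__version__)"
```
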

## 3. Offline inference/Service

### Offline inference

To run offline inference using vLLM for a quick impression, use the following example.

```eval_rst
.. note::

   Please modify the MODEL_PATH in offline_inference.py to use your chosen model.
   You can try modifying load_in_low_bit to different values in **[sym_int4, fp8, fp16]** to use different quantization dtypes.
```

```bash
#!/bin/bash
wget https://raw.githubusercontent.com/intel-analytics/ipex-llm/main/python/llm/example/GPU/vLLM-Serving/offline_inference.py
python offline_inference.py
```

For instructions on how to change the `load_in_low_bit` value in `offline_inference.py`, check the following example:

```python
llm = LLM(model="YOUR_MODEL",
          device="xpu",
          dtype="float16",
          enforce_eager=True,
          # Simply change here for the desired load_in_low_bit value
          load_in_low_bit="sym_int4",
          tensor_parallel_size=1,
          trust_remote_code=True)
```

The result of executing the `Baichuan2-7B-Chat` model with the `sym_int4` low-bit format is shown as follows:

```
Prompt: 'Hello, my name is', Generated text: ' [Your Name] and I am a [Your Job Title] at [Your'
Prompt: 'The president of the United States is', Generated text: ' the head of state and head of government in the United States. The president leads'
Prompt: 'The capital of France is', Generated text: ' Paris.\nThe capital of France is Paris.'
Prompt: 'The future of AI is', Generated text: " bright, but it's not without challenges. As AI continues to evolve,"
```

### Service

```eval_rst
.. note::

   Because the kernels are JIT-compiled, we recommend sending a few warmup requests before using the service, so that you get the best performance.
```
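
Once the service is up (the start command is shown below), a short loop like the following sketch is enough to trigger the kernel compilation; the model name is a placeholder and must match the name you serve under:

```bash
# Send a few tiny requests to warm up the JIT-compiled kernels before real traffic
for i in 1 2 3; do
  curl -s http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "YOUR_MODEL", "prompt": "Hello", "max_tokens": 8}' > /dev/null
done
```
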

To fully utilize the continuous batching feature of `vLLM`, you can send requests to the service using `curl` or other similar methods. The requests sent to the engine will be batched at the token level. Queries will be executed in the same `forward` step of the LLM and removed when they finish, instead of waiting for all sequences to finish.

For vLLM, you can start the service using the following command:

```bash
#!/bin/bash
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"

# You may need to adjust the value of
# --max-model-len, --max-num-batched-tokens, --max-num-seqs
# to acquire the best performance

python -m ipex_llm.vllm.entrypoints.openai.api_server \
  --served-model-name $served_model_name \
  --port 8000 \
  --model $model \
  --trust-remote-code \
  --gpu-memory-utilization 0.75 \
  --device xpu \
  --dtype float16 \
  --enforce-eager \
  --load-in-low-bit sym_int4 \
  --max-model-len 4096 \
  --max-num-batched-tokens 10240 \
  --max-num-seqs 12 \
  --tensor-parallel-size 1
```

You can tune the service using these four arguments:

1. `--gpu-memory-utilization`: The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory utilization. If unspecified, the default value of 0.9 is used.
2. `--max-model-len`: Model context length. If unspecified, it is automatically derived from the model config.
3. `--max-num-batched-tokens`: Maximum number of batched tokens per iteration.
4. `--max-num-seqs`: Maximum number of sequences per iteration. Default: 256.

If the service has booted successfully, the console will display messages similar to the following:

<a href="https://llm-assets.readthedocs.io/en/latest/_images/start-vllm-service.png" target="_blank">
  <img src="https://llm-assets.readthedocs.io/en/latest/_images/start-vllm-service.png" width=100% />
</a>

After the service has booted successfully, you can send a test request using `curl`. Here, `YOUR_MODEL` should be set equal to `$served_model_name` in your booting script, e.g. `Qwen1.5`.

```bash
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "YOUR_MODEL",
        "prompt": "San Francisco is a",
        "max_tokens": 128,
        "temperature": 0
      }' | jq '.choices[0].text'
```

Below is an example output using `Qwen1.5-7B-Chat` with the low-bit format `sym_int4`:

<a href="https://llm-assets.readthedocs.io/en/latest/_images/vllm-curl-result.png" target="_blank">
  <img src="https://llm-assets.readthedocs.io/en/latest/_images/vllm-curl-result.png" width=100% />
</a>
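
Since the server exposes an OpenAI-compatible `/v1/completions` endpoint, the same request can also be sent from Python. Below is a minimal sketch using the `requests` library; the model name and prompt are placeholders, and the host/port assume the start command above:

```python
import requests

# Same request as the curl example above; "YOUR_MODEL" must match the
# $served_model_name used when starting the api_server.
payload = {
    "model": "YOUR_MODEL",
    "prompt": "San Francisco is a",
    "max_tokens": 128,
    "temperature": 0,
}

response = requests.post(
    "http://localhost:8000/v1/completions",
    json=payload,
    timeout=300,
)
response.raise_for_status()
print(response.json()["choices"][0]["text"])
```
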

```eval_rst
.. tip::

   If your local LLM is running on Intel Arc™ A-Series Graphics with Linux OS (Kernel 6.2), it is recommended to additionally set the following environment variable for optimal performance before starting the service:

   .. code-block:: bash

      export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```

## 4. About Tensor parallel

> Note: We recommend using docker for tensor parallel deployment. Check our serving docker image `intelanalytics/ipex-llm-serving-xpu`.
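
A containerized deployment would look roughly like the sketch below. This is only an illustration: the mount path is a placeholder, and the exact flags (device mapping, memory limits, image tag) should be taken from the image's own documentation rather than from here:

```bash
# Hypothetical invocation of the serving image; consult the image README for
# the authoritative flags and tags
docker pull intelanalytics/ipex-llm-serving-xpu:latest
sudo docker run -itd \
  --net=host \
  --device=/dev/dri \
  --memory="32G" \
  --shm-size="16g" \
  -v /path/to/models:/llm/models \
  --name=ipex-llm-serving \
  intelanalytics/ipex-llm-serving-xpu:latest
```
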

We also support tensor parallel by using multiple Intel GPU cards. To enable tensor parallel, you will need to install `libfabric-dev` in your environment. On Ubuntu, you can install it by:

```bash
sudo apt-get install libfabric-dev
```
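
Before launching, you may want to confirm that all cards are visible to the SYCL runtime. A quick check, assuming the oneAPI environment has been sourced as in the earlier steps:

```bash
# List SYCL devices; each Intel GPU should appear as a Level Zero GPU entry
source /opt/intel/oneapi/setvars.sh
sycl-ls
```
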

To deploy your model across multiple cards, simply change the value of `--tensor-parallel-size` to the desired value.

For instance, if you have two Arc A770 cards in your environment, then you can set this value to 2. Some oneCCL environment variable settings are also needed; check the following example:

```bash
#!/bin/bash
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"

# CCL needed environment variables
export CCL_WORKER_COUNT=2
export FI_PROVIDER=shm
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
export CCL_ATL_SHM=1
# You may need to adjust the value of
# --max-model-len, --max-num-batched-tokens, --max-num-seqs
# to acquire the best performance

python -m ipex_llm.vllm.entrypoints.openai.api_server \
  --served-model-name $served_model_name \
  --port 8000 \
  --model $model \
  --trust-remote-code \
  --gpu-memory-utilization 0.75 \
  --device xpu \
  --dtype float16 \
  --enforce-eager \
  --load-in-low-bit sym_int4 \
  --max-model-len 4096 \
  --max-num-batched-tokens 10240 \
  --max-num-seqs 12 \
  --tensor-parallel-size 2
```

If the service has booted successfully, you should see output similar to the following figure:

<a href="https://llm-assets.readthedocs.io/en/latest/_images/start-vllm-service.png" target="_blank">
  <img src="https://llm-assets.readthedocs.io/en/latest/_images/start-vllm-service.png" width=100% />
</a>
## 5. Performing benchmark

To perform benchmarking, you can use the **benchmark_throughput** script originally provided by the vLLM repo.

```bash
conda activate ipex-vllm

source /opt/intel/oneapi/setvars.sh

wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

wget https://raw.githubusercontent.com/intel-analytics/ipex-llm/main/docker/llm/serving/xpu/docker/benchmark_vllm_throughput.py -O benchmark_throughput.py

export MODEL="YOUR_MODEL"

# You can change load-in-low-bit from values in [sym_int4, fp8, fp16]

python3 ./benchmark_throughput.py \
  --backend vllm \
  --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
  --model $MODEL \
  --num-prompts 1000 \
  --seed 42 \
  --trust-remote-code \
  --enforce-eager \
  --dtype float16 \
  --device xpu \
  --load-in-low-bit sym_int4 \
  --gpu-memory-utilization 0.85
```

The following figure shows the result of benchmarking `Llama-2-7b-chat-hf` using 50 prompts:

<a href="https://llm-assets.readthedocs.io/en/latest/_images/vllm-benchmark-result.png" target="_blank">
  <img src="https://llm-assets.readthedocs.io/en/latest/_images/vllm-benchmark-result.png" width=100% />
</a>

```eval_rst
.. tip::

   To find the best config that fits your workload, you may need to start the service and use tools like `wrk` or `jmeter` to perform a stress test.
```
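
If you prefer to stay in Python rather than use `wrk` or `jmeter`, a minimal concurrent-load sketch against the OpenAI-compatible endpoint could look like the following; the model name, prompt, request count, and concurrency are placeholders to adapt to your deployment:

```python
import concurrent.futures
import time

import requests

URL = "http://localhost:8000/v1/completions"
PAYLOAD = {"model": "YOUR_MODEL", "prompt": "San Francisco is a",
           "max_tokens": 128, "temperature": 0}

def send_one(_):
    # One completion request; returns the end-to-end latency in seconds
    start = time.time()
    r = requests.post(URL, json=PAYLOAD, timeout=600)
    r.raise_for_status()
    return time.time() - start

# Issue 32 requests with 8 concurrent workers and report the average latency
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
    latencies = list(pool.map(send_one, range(32)))

print(f"avg latency: {sum(latencies) / len(latencies):.2f}s over {len(latencies)} requests")
```
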
