
Commit 176f071

Improve the readme of vineyard llm kv cache. (#1888)
Fixes #1886 Signed-off-by: Ye Cao <[email protected]>
1 parent 397d274 commit 176f071

File tree: 2 files changed (+379 / -27 lines)


modules/llm-cache/README.md

Lines changed: 379 additions & 0 deletions

# Vineyard LLM KV Cache

## Background

Large Language Models (LLMs) are popular for their ability to generate content and solve complex tasks. However, LLM inference can be costly due to extensive GPU usage and slow serving speeds, particularly in multi-turn conversations. With rising demand, optimizing LLM inference throughput in multi-turn dialogues and cutting costs is crucial.

Specifically, LLM inference consists of two phases: **Prefill** and **Decode**. The **Prefill** phase computes the KV Cache of the input tokens, and the **Decode** phase generates the output tokens based on the computed KV Cache. In multi-turn dialogues, the current input tokens are concatenated with the previous inputs and outputs and fed into the model as the new input for inference. The KV Cache of the previous tokens can be reused in the **Prefill** phase, which reduces the First Token Time (FTT) and improves the overall throughput.
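
To make the prefix-reuse idea concrete, here is a toy sketch (illustrative only, not part of the vineyard API) of how matching the cached prefix from a previous turn shrinks the number of tokens that still need to be prefilled:

```python
# Illustrative sketch: how a cached prefix reduces prefill work in a follow-up turn.
previous_turn = [1, 2, 3, 4, 5]           # tokens whose KV cache is already stored
current_input = [1, 2, 3, 4, 5, 6, 7, 8]  # previous turn + newly appended tokens

# Count how many leading tokens are already covered by the cache.
matched = 0
for cached, new in zip(previous_turn, current_input):
    if cached != new:
        break
    matched += 1

tokens_to_prefill = current_input[matched:]
print(f"prefill only {len(tokens_to_prefill)} of {len(current_input)} tokens")
```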

To address the above issues, we have integrated Vineyard into LLM inference scenarios. There are currently two implementation methods: **radix tree** + **vineyard blob** and **token chunk hash** + **distributed filesystem**.

## Design

### Radix Tree + Vineyard Blob

In this method, the tokens are organized as a radix tree and the KV tensors of these tokens are stored in Vineyard blobs (in memory). We also apply memory-optimization strategies, such as an LRU (Least Recently Used) cache and pruning, to reduce the memory usage of the radix tree.
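
The toy sketch below (a simplification for illustration, not the actual vineyard implementation; `ToyPrefixCache` is a hypothetical class) captures the idea: token prefixes index KV blobs, and an LRU policy bounds the in-memory footprint.

```python
# Toy sketch of "prefix tree + blob with LRU": a real radix tree shares common
# prefixes between entries; here a plain ordered dict keyed by the full prefix
# is used to keep the idea readable.
from collections import OrderedDict

class ToyPrefixCache:
    def __init__(self, capacity=4):
        self.capacity = capacity
        self.entries = OrderedDict()  # prefix tuple -> opaque KV blob handle

    def update(self, tokens, blob):
        key = tuple(tokens)
        self.entries[key] = blob
        self.entries.move_to_end(key)          # mark as most recently used
        if len(self.entries) > self.capacity:
            self.entries.popitem(last=False)   # evict the least recently used prefix

    def query(self, tokens):
        # Return the longest cached prefix of `tokens`, if any.
        for end in range(len(tokens), 0, -1):
            key = tuple(tokens[:end])
            if key in self.entries:
                self.entries.move_to_end(key)
                return list(key), self.entries[key]
        return [], None

cache = ToyPrefixCache()
cache.update([1, 2, 3, 4], "blob-a")
print(cache.query([1, 2, 3, 4, 5, 6]))  # -> ([1, 2, 3, 4], 'blob-a')
```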

### Token Chunk Hash + Distributed FileSystem

In this method, the tokens are split into chunks (e.g., 16 or 32 tokens per chunk), each chunk is hashed, and the KV tensors of these tokens are stored in a distributed filesystem under the hashed keys. Besides, we apply GC (Garbage Collection) strategies to reclaim stale KV tensors in the distributed filesystem.
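
As a rough illustration (the concrete hashing scheme used by vineyard may differ; `chunk_hashes` is a hypothetical helper), the snippet below maps a token sequence to per-chunk hash keys that could serve as storage paths:

```python
# Rough illustration: map a token sequence to per-chunk hash keys.
# The concrete hash function and file layout used by vineyard may differ.
import hashlib

def chunk_hashes(tokens, chunk_size=16):
    keys = []
    prefix = []
    # Only full chunks are hashed; a trailing partial chunk is not cached.
    for i in range(0, len(tokens) - len(tokens) % chunk_size, chunk_size):
        prefix.extend(tokens[i:i + chunk_size])
        # Hash the whole prefix so a chunk's key also encodes its preceding context.
        digest = hashlib.sha256(str(prefix).encode("utf-8")).hexdigest()
        keys.append(digest[:16])
    return keys

print(chunk_hashes(list(range(40))))  # two full chunks of 16 -> two keys
```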

### Comparison

In this section, we compare the two methods in terms of latency and suitable scenarios.

**Latency**: On a single machine, the `radix tree + vineyard blob` method is faster than the `token chunk hash + distributed filesystem` method since it stores the KV tensors in memory. In a distributed environment, however, the metadata synchronization of vineyard blobs through etcd becomes a bottleneck.

**Suitable Scenarios**: The main factor in choosing between the two methods is the deployment scale. If you only run LLM inference on a single machine, the `radix tree + vineyard blob` method is the better choice. If you run LLM inference in a distributed environment, the `token chunk hash + distributed filesystem` method is the better choice.

## Usage

We provide [C++](https://github.com/v6d-io/v6d/blob/main/modules/llm-cache/ds/kv_state_cache_manager.h) and [Python](https://github.com/v6d-io/v6d/blob/main/python/vineyard/llm/__init__.py) APIs for the Vineyard LLM KV Cache. Depending on your inference framework, you can use the corresponding API to integrate the Vineyard LLM KV Cache.

### C++ API

1. First, you need to fetch the required dependencies (git submodules).

```bash
$ cd v6d && git submodule update --init --recursive
```

2. Then, you can build the vineyard server and vineyard llm kv cache library.

```bash
$ mkdir build && cd build
$ cmake .. -DCMAKE_BUILD_TYPE=Release \
    -DBUILD_SHARED_LIBS=ON \
    -DUSE_STATIC_BOOST_LIBS=OFF \
    -DBUILD_VINEYARD_SERVER=ON \
    -DBUILD_VINEYARD_CLIENT=OFF \
    -DBUILD_VINEYARD_PYTHON_BINDINGS=OFF \
    -DBUILD_VINEYARD_PYPI_PACKAGES=OFF \
    -DBUILD_VINEYARD_LLM_CACHE=ON \
    -DBUILD_VINEYARD_BASIC=OFF \
    -DBUILD_VINEYARD_GRAPH=OFF \
    -DBUILD_VINEYARD_IO=OFF \
    -DBUILD_VINEYARD_HOSSEINMOEIN_DATAFRAME=OFF \
    -DBUILD_VINEYARD_TESTS=ON \
    -DBUILD_VINEYARD_TESTS_ALL=OFF \
    -DBUILD_VINEYARD_PROFILING=OFF
$ make -j
$ make vineyard_llm_cache_tests -j
```

After the build, you can check the `vineyardd` binary and the `libvineyard_llm_cache.so` library as follows.

```bash
$ ls build/bin
vineyardd
$ ls /usr/local/lib/libvineyard_llm_cache.so
/usr/local/lib/libvineyard_llm_cache.so
```

3. Run the vineyard llm kv cache test.

- First, build the vineyard llm kv cache test as follows.

```bash
$ cd build && make vineyard_llm_cache_tests -j
```

- Open a terminal to start the vineyard server.

```bash
$ ./build/bin/vineyardd --socket /tmp/vineyard_test.sock
```

Then open another terminal to run the vineyard llm kv cache test.

```bash
$ ./bin/kv_state_cache_test --client-num 1 --vineyard-ipc-sockets /tmp/vineyard_test.sock
```

For more information about how to use the C++ API, you can refer to the [C++ API implementation](https://github.com/v6d-io/v6d/blob/main/modules/llm-cache/ds/kv_state_cache_manager.cc) and the [related tests](https://github.com/v6d-io/v6d/tree/main/modules/llm-cache/tests).

### Python API

1. First, same as the C++ API, you need to fetch the required dependencies (git submodules).

```bash
$ cd v6d && git submodule update --init --recursive
```

2. Then, you can build the vineyard server and vineyard llm kv cache python library.

```bash
$ mkdir build && cd build
$ cmake .. -DCMAKE_BUILD_TYPE=Release \
    -DBUILD_SHARED_LIBS=ON \
    -DUSE_STATIC_BOOST_LIBS=OFF \
    -DBUILD_VINEYARD_SERVER=ON \
    -DBUILD_VINEYARD_CLIENT=OFF \
    -DBUILD_VINEYARD_PYTHON_BINDINGS=ON \
    -DBUILD_VINEYARD_PYPI_PACKAGES=OFF \
    -DBUILD_VINEYARD_LLM_CACHE=ON \
    -DBUILD_VINEYARD_BASIC=OFF \
    -DBUILD_VINEYARD_GRAPH=OFF \
    -DBUILD_VINEYARD_IO=OFF \
    -DBUILD_VINEYARD_HOSSEINMOEIN_DATAFRAME=OFF \
    -DBUILD_VINEYARD_TESTS=ON \
    -DBUILD_VINEYARD_TESTS_ALL=OFF \
    -DBUILD_VINEYARD_PROFILING=OFF
$ make -j
$ make vineyard_llm_python -j
```

3. After the build, you can run the vineyard llm kv cache test as follows.

**Radix Tree + Vineyard Blob**

- Open a terminal to run the vineyard server.

```bash
$ ./build/bin/vineyardd --socket /tmp/vineyard_test.sock
```

- Open another terminal to enable the vineyard llm kv cache python module.

```bash
$ export PYTHONPATH=/INPUT_YOUR_PATH_HERE/v6d/python:$PYTHONPATH
```

- Then you can run the following python code to test the vineyard llm kv cache.

```python
import numpy as np
import vineyard

from vineyard.llm import KVCache
from vineyard.llm import KVTensor
from vineyard.llm.config import FileCacheConfig
from vineyard.llm.config import VineyardCacheConfig

vineyard_cache_config = VineyardCacheConfig(
    socket="/tmp/vineyard_test.sock",
    block_size=5,
    sync_interval=3,
    llm_cache_sync_lock="llmCacheSyncLock",
    llm_cache_object_name="llm_cache_object",
    llm_ref_cnt_object_name="llm_refcnt_object",
)
cache = KVCache(
    cache_config=vineyard_cache_config,
    tensor_bytes=16,  # should be the same as the nbytes of the tensor
    cache_capacity=10,
    layer=2,
)

tokens = [1, 2, 3, 4]

kv_tensors_to_update = []
kv_tensors = []
for _ in range(len(tokens)):
    k_tensor = np.random.rand(2, 2).astype(np.float32)
    v_tensor = np.random.rand(2, 2).astype(np.float32)
    kv_tensors.append([(k_tensor, v_tensor) for _ in range(cache.layer)])
    kv_tensors_to_update.append(
        [
            (
                KVTensor(k_tensor.ctypes.data, k_tensor.nbytes),
                KVTensor(v_tensor.ctypes.data, v_tensor.nbytes),
            )
            for _ in range(cache.layer)
        ]
    )

# insert the token list and the related kv cache list
updated = cache.update(None, tokens, kv_tensors_to_update)
assert updated == len(tokens)

kv_tensors_to_query = []
kv_tensors_from_cache = []
for _ in range(len(tokens)):
    kv_tensors_to_query.append(
        [
            (
                KVTensor(0, 0),
                KVTensor(0, 0),
            )
            for _ in range(cache.layer)
        ]
    )

matched = cache.query(tokens, kv_tensors_to_query)
kv_tensors_from_cache = kv_tensors_to_query[:matched]
assert matched == len(tokens)

assert len(kv_tensors) == len(kv_tensors_from_cache)
for kv, kv_from_cache in zip(kv_tensors, kv_tensors_from_cache):
    assert len(kv) == len(kv_from_cache)
    for (k_tensor, v_tensor), (queried_k_tensor, queried_v_tensor) in zip(
        kv, kv_from_cache
    ):
        queried_k_tensor = np.frombuffer(
            queried_k_tensor,
            dtype=k_tensor.dtype,
        ).reshape(k_tensor.shape)
        queried_v_tensor = np.frombuffer(
            queried_v_tensor,
            dtype=v_tensor.dtype,
        ).reshape(v_tensor.shape)
        assert np.array_equal(k_tensor, queried_k_tensor)
        assert np.array_equal(v_tensor, queried_v_tensor)
```
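
Note that in this example `cache.update` is given `None` as the prefix since the whole token list is inserted at once, and `cache.query` fills the `KVTensor(0, 0)` placeholders for the first `matched` tokens, which are then wrapped with `np.frombuffer` and compared against the original tensors.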

**Token Chunk Hash + Distributed FileSystem**

Same as the previous step, you need to enable the vineyard llm kv cache python module.

```bash
$ export PYTHONPATH=/INPUT_YOUR_PATH_HERE/v6d/python:$PYTHONPATH
```

- Then you can run the following python code to test the vineyard llm kv cache.

```python
import numpy as np
import vineyard

from vineyard.llm import KVCache
from vineyard.llm import KVTensor
from vineyard.llm.config import FileCacheConfig
from vineyard.llm.config import VineyardCacheConfig

file_cache_config = FileCacheConfig(
    chunk_size=2,
    split_number=2,
    root="/tmp/vineyard/llm_cache",
)
cache = KVCache(
    cache_config=file_cache_config,
    tensor_bytes=16,  # should be the same as the nbytes of the tensor
    cache_capacity=10,
    layer=2,
)

tokens = [1, 2, 3, 4]
original_kv_tensors = []
for i in range(0, len(tokens), file_cache_config.chunk_size):
    kv_tensors_to_update = []
    k_tensor = np.random.rand(2, 2).astype(np.float32)
    v_tensor = np.random.rand(2, 2).astype(np.float32)
    for _ in range(file_cache_config.chunk_size):
        original_kv_tensors.append(
            [(k_tensor, v_tensor) for _ in range(cache.layer)]
        )
        kv_tensors_to_update.append(
            [
                (
                    KVTensor(k_tensor.ctypes.data, k_tensor.nbytes),
                    KVTensor(v_tensor.ctypes.data, v_tensor.nbytes),
                )
                for _ in range(cache.layer)
            ]
        )
    updated = cache.update(
        tokens[:i],
        tokens[i : i + file_cache_config.chunk_size],
        kv_tensors_to_update,
    )
    assert updated == file_cache_config.chunk_size

kv_tensors_from_cache = []
kv_tensors = []
for _ in range(len(tokens)):
    k_tensor = np.empty((2, 2), dtype=np.float32)
    v_tensor = np.empty((2, 2), dtype=np.float32)
    kv_tensors_from_cache.append([(k_tensor, v_tensor) for _ in range(cache.layer)])
    kv_tensors.append(
        [
            (
                KVTensor(k_tensor.ctypes.data, k_tensor.nbytes),
                KVTensor(v_tensor.ctypes.data, v_tensor.nbytes),
            )
            for _ in range(cache.layer)
        ]
    )
matched = cache.query(tokens, kv_tensors)
assert matched == len(tokens)

assert len(kv_tensors) == len(kv_tensors_from_cache)
for kv, kv_from_cache in zip(original_kv_tensors, kv_tensors_from_cache):
    assert len(kv) == len(kv_from_cache)
    for (k_tensor, v_tensor), (queried_k_tensor, queried_v_tensor) in zip(
        kv, kv_from_cache
    ):
        assert np.array_equal(k_tensor, queried_k_tensor)
        assert np.array_equal(v_tensor, queried_v_tensor)
```
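
In this example, `chunk_size=2` groups the four tokens into two chunks: each call to `cache.update` passes the tokens of the preceding chunks as the prefix together with the current chunk and its KV tensors, and a single `cache.query` later retrieves the KV tensors for all four tokens.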

After running the above code, you can check the KV tensor files under the directory `/tmp/vineyard/llm_cache` as follows.

```bash
$ ls /tmp/vineyard/llm_cache
44 c3 __temp
```

### Performance

We have conducted some performance tests on the `Token Chunk Hash + Distributed FileSystem` method. The test environment includes a local SSD and a distributed filesystem (DFS).

**Based on SSD**

The max read throughput of the SSD is around 3 GiB/s and the max write throughput is around 1.5 GiB/s. On this machine, the performance of the vineyard llm kv cache is as follows.

| query (token/s) | update (token/s) |
|-----------------|------------------|
| 605             | 324              |

The KV tensor size of a token is around 5 MB, so the throughput is as follows.

| query (MiB/s)   | update (MiB/s)   |
|-----------------|------------------|
| 605 * 5 = 3025  | 324 * 5 = 1620   |

**Based on DFS**

We use [Aliyun CPFS](https://www.aliyun.com/product/nas_cpfs) as the DFS in the benchmark test. The max write throughput of CPFS is around 20 GB/s, and the max read throughput is around 40 GB/s. Based on CPFS, we test the throughput of fio with multiple workers, each of which can be regarded as a CPFS client.

| worker | write (MiB/s) | read (MiB/s) | CPFS aggregate bandwidth (write/read, MiB/s) |
|--------|---------------|--------------|----------------------------------------------|
| 1      | 1315          | 2016         | 1315 / 2016                                  |
| 2      | 1175          | 1960         | 2360 / 3920                                  |
| 4      | 928           | 1780         | 3712 / 7120                                  |
| 8      | 895           | 1819         | 7160 / 14552                                 |
| 16     | 638           | 1609         | 10208 / 25744                                |
| 32     | 586           | 1308         | 18752 / 41856                                |

We test the vineyard llm kv cache with 32 workers, and the throughput of a single worker is as follows.

| query (token/s) | update (token/s) |
|-----------------|------------------|
| 375             | 252              |

As with the SSD test, the KV tensor size of a token is around 5 MB, so the throughput is as follows.

| query (MiB/s)   | update (MiB/s)   |
|-----------------|------------------|
| 375 * 5 = 1875  | 252 * 5 = 1260   |

### Conclusion

`Radix Tree + Vineyard Blob` is heavily affected by the synchronization of metadata through etcd, which becomes a bottleneck in a distributed environment. In the future, we can leverage RDMA to support fast remote read/write and reduce the metadata synchronization cost with a new architecture such as master-slave.

`Token Chunk Hash + Distributed Filesystem` can make full use of the bandwidth of the SSD and the DFS, which ensures that the overall inference throughput can be improved while meeting a lower SLO.

### Future work

- Support RDMA.
- Create multiple replicas of an object on different instances, so that read requests can be served concurrently.
- Implement a load balancer to balance the load across different vineyardd instances and the requests from clients.
