
Commit aaf2e6b: [model] fix kv cache (#7564)

Parent: 9deece1

16 files changed: +122 -64 lines

README.md (+2 -2)

@@ -204,7 +204,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 [23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage.

-[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode. Use `dataset_shards` to enable parallel preprocessing with streaming.
+[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.

 [23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.

@@ -412,7 +412,7 @@ huggingface-cli login
 | CUDA | 11.6 | 12.2 |
 | deepspeed | 0.10.0 | 0.16.4 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
-| vllm | 0.4.3 | 0.7.3 |
+| vllm | 0.4.3 | 0.8.2 |
 | flash-attn | 2.3.0 | 2.7.2 |

 ### Hardware Requirement

README_zh.md (+2 -2)

@@ -206,7 +206,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc

 [23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详细用法请参照 [examples](examples/README_zh.md)。

-[23/07/31] 我们支持了**数据流式加载**。请使用 `streaming: true` 和 `max_steps: 10000` 参数来流式加载数据集。使用 `dataset_shards` 来开启多进程加载。
+[23/07/31] 我们支持了**数据流式加载**。请使用 `streaming: true` 和 `max_steps: 10000` 参数来流式加载数据集。

 [23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。

@@ -414,7 +414,7 @@ huggingface-cli login
 | CUDA | 11.6 | 12.2 |
 | deepspeed | 0.10.0 | 0.16.4 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
-| vllm | 0.4.3 | 0.7.3 |
+| vllm | 0.4.3 | 0.8.2 |
 | flash-attn | 2.3.0 | 2.7.2 |

 ### 硬件依赖

examples/accelerate/fsdp_config.yaml (+4 -4)

@@ -7,16 +7,16 @@ fsdp_config:
   fsdp_backward_prefetch: BACKWARD_PRE
   fsdp_forward_prefetch: false
   fsdp_cpu_ram_efficient_loading: true
-  fsdp_offload_params: true # offload may affect training speed
+  fsdp_offload_params: false
   fsdp_sharding_strategy: FULL_SHARD
   fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_sync_module_states: true
   fsdp_use_orig_params: true
 machine_rank: 0
 main_training_function: main
-mixed_precision: bf16 # or fp16
-num_machines: 1 # the number of nodes
-num_processes: 2 # the number of GPUs in all nodes
+mixed_precision: bf16  # or fp16
+num_machines: 1  # the number of nodes
+num_processes: 2  # the number of GPUs in all nodes
 rdzv_backend: static
 same_network: true
 tpu_env: []
New file under examples/accelerate/: FSDP config with parameter offload enabled (+25)
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+fsdp_config:
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_forward_prefetch: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_offload_params: true  # offload may affect training speed
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16  # or fp16
+num_machines: 1  # the number of nodes
+num_processes: 2  # the number of GPUs in all nodes
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false

scripts/vllm_infer.py (+1 -1)

@@ -56,7 +56,7 @@ def vllm_infer(

     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
-    check_version("vllm>=0.4.3,<=0.7.3")
+    check_version("vllm>=0.4.3,<=0.8.2")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")


setup.py (+1 -1)

@@ -53,7 +53,7 @@ def get_console_scripts() -> list[str]:
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.8.1"],
+    "vllm": ["vllm>=0.4.3,<=0.8.2"],
     "sglang": ["sglang[srt]>=0.4.4", "transformers==4.48.3"],
     "galore": ["galore-torch"],
     "apollo": ["apollo-torch"],

src/llamafactory/data/loader.py (+4 -6)

@@ -101,12 +101,10 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=data_args.streaming and not data_args.dataset_shards,  # only set to True when user specified streaming but do not want dataset to be sharded
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()
-        if data_args.streaming and data_args.dataset_shards:
-            dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)

     elif dataset_attr.load_from == "om_hub":
         check_version("openmind>=0.8.0", mandatory=True)
@@ -135,10 +133,10 @@ def _load_single_dataset(
             token=model_args.hf_hub_token,
             num_proc=data_args.preprocessing_num_workers,
             trust_remote_code=model_args.trust_remote_code,
-            streaming=data_args.streaming and not data_args.dataset_shards,
+            streaming=data_args.streaming and dataset_attr.load_from != "file",
         )
-        if data_args.streaming and data_args.dataset_shards:
-            dataset = dataset.to_iterable_dataset(num_shards=data_args.dataset_shards)
+        if data_args.streaming and dataset_attr.load_from == "file":
+            dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)

     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
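The pattern this change relies on can be sketched outside the trainer (an illustration, not code from the commit; the file path and worker count are made up): a local file is loaded eagerly and then converted into a sharded IterableDataset, so streaming datasets built from files can still be processed by several dataloader workers.

from datasets import load_dataset

# Eager load of a local JSON file, as loader.py now does when load_from == "file".
dataset = load_dataset("json", data_files="data/alpaca_en_demo.json", split="train")

# Convert to an IterableDataset with one shard per worker, mirroring
# to_iterable_dataset(num_shards=training_args.dataloader_num_workers) above.
num_workers = 4  # stands in for training_args.dataloader_num_workers
iterable_dataset = dataset.to_iterable_dataset(num_shards=num_workers)

print(next(iter(iterable_dataset)))  # streams the first example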

src/llamafactory/data/mm_plugin.py (+13 -4)

@@ -1186,25 +1186,32 @@ def process_messages(
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        else:
+            mm_inputs = {}
+
         num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
         use_audio_in_video = getattr(processor, "use_audio_in_video", False)

         # get length or size from mm_inputs
         if "feature_attention_mask" in mm_inputs:
             input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
             audio_lengths = (input_lengths - 2) // 2 + 1
+
         if mm_inputs.get("image_grid_thw", None) is not None:
             image_grid_thw = mm_inputs["image_grid_thw"]
             merge_length = processor.omni_processor.merge_size**2
+
         if mm_inputs.get("video_grid_thw", None) is not None:
             video_grid_thw = mm_inputs["video_grid_thw"]
             merge_length = processor.omni_processor.merge_size**2

         if use_audio_in_video:
-            assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`"
-            assert mm_inputs.get("video_grid_thw", None) is not None, (
-                "video_grid_thw should be exist when use_audio_in_video is `True`"
-            )
+            if audio_lengths is None:
+                raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
+
+            if not mm_inputs.get("video_grid_thw", None):
+                raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
+
         positions_list = []
         for i, message in enumerate(messages):  # get multimodal index when use_audio
             positions = []
@@ -1216,6 +1223,7 @@
                     break
                 positions.append((pos, special_token))
                 start = pos + len(special_token)
+
             positions_list.append(positions.sort(key=lambda x: x[0]))

         for message in messages:
@@ -1278,6 +1286,7 @@
                 content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                 num_audio_tokens += 1
                 num_video_tokens += 1
+
             message["content"] = content

         if len(audios) != num_audio_tokens:

src/llamafactory/extras/misc.py (+9 -2)

@@ -17,6 +17,7 @@

 import gc
 import os
+import socket
 from typing import TYPE_CHECKING, Any, Literal, Union

 import torch
@@ -278,10 +279,16 @@ def use_ray() -> bool:

 def find_available_port() -> int:
     """Find an available port on the local machine."""
-    import socket
-
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.bind(("", 0))
     port = sock.getsockname()[1]
     sock.close()
     return port
+
+
+def fix_proxy(ipv6_enabled: bool) -> None:
+    """Fix proxy settings for gradio ui."""
+    os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
+    if ipv6_enabled:
+        for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
+            os.environ.pop(name, None)
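A small usage sketch of these helpers (illustrative values only; assumes the package is importable as llamafactory):

import os

from llamafactory.extras.misc import find_available_port, fix_proxy

os.environ["http_proxy"] = "http://proxy.example:8080"  # illustrative proxy setting

fix_proxy(ipv6_enabled=True)
print(os.environ["no_proxy"])      # localhost,127.0.0.1,0.0.0.0
print("http_proxy" in os.environ)  # False: proxy variables are dropped when IPv6 is enabled

print(find_available_port())       # e.g. 49152, an OS-assigned free port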

src/llamafactory/hparams/data_args.py (-4)

@@ -83,10 +83,6 @@ class DataArguments:
         default=None,
         metadata={"help": "The number of processes to use for the pre-processing."},
     )
-    dataset_shards: Optional[int] = field(
-        default=None,
-        metadata={"help": "The number of shards to split the dataset into. Only used in streaming mode. This should be set to the same as dataloader_num_workers. Not setting this while streaming data will cause the dataset to be non-sharded and thus only can be processed using one worker."},
-    )
     max_samples: Optional[int] = field(
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},

src/llamafactory/hparams/parser.py (+1 -1)

@@ -135,7 +135,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.1")
+        check_version("vllm>=0.4.3,<=0.8.2")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.4")

src/llamafactory/model/adapter.py (+3 -4)

@@ -18,7 +18,6 @@
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import is_fsdp_enabled

 from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
@@ -277,14 +276,14 @@ def init_adapter(

     # cast trainable parameters to float32 if:
     # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
-    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
     cast_trainable_params_to_fp32 = False
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
         logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
-    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+    elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
+        logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
     else:
         logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True
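As a side note, "upcasting trainable params to float32" boils down to the loop sketched below (a generic illustration, not code from adapter.py; the tiny Linear layer stands in for a half-precision model):

import torch

model = torch.nn.Linear(8, 8).to(torch.bfloat16)  # stand-in for a bf16/fp16 model

# Cast only the trainable parameters to float32 for more stable optimizer updates.
for param in filter(lambda p: p.requires_grad, model.parameters()):
    param.data = param.data.to(torch.float32)

print({name: param.dtype for name, param in model.named_parameters()})  # all torch.float32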
src/llamafactory/model/model_utils/kv_cache.py (new file, +44)
@@ -0,0 +1,44 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...extras import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+    from ...hparams import ModelArguments
+
+
+def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable:
+        setattr(config, "use_cache", model_args.use_cache)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "use_cache", model_args.use_cache)
+
+        if model_args.use_cache:
+            logger.info_rank0("KV cache is enabled for faster generation.")
+        else:
+            logger.info_rank0("KV cache is disabled.")
+    else:
+        setattr(config, "use_cache", False)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "use_cache", False)
+
+        logger.info_rank0("KV cache is disabled during training.")
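A brief usage sketch of the new helper, assuming the module path above and that the package is importable as llamafactory; the SimpleNamespace is a hypothetical stand-in for ModelArguments, of which only use_cache is read here:

from types import SimpleNamespace

from transformers import GPT2Config

from llamafactory.model.model_utils.kv_cache import configure_kv_cache

config = GPT2Config()                         # use_cache defaults to True
model_args = SimpleNamespace(use_cache=True)  # hypothetical stand-in for ModelArguments

configure_kv_cache(config, model_args, is_trainable=True)
assert config.use_cache is False              # the KV cache is always off during training

configure_kv_cache(config, model_args, is_trainable=False)
assert config.use_cache is True               # at inference time it follows model_args.use_cache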

src/llamafactory/model/patcher.py (+2 -11)

@@ -27,6 +27,7 @@
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
+from .model_utils.kv_cache import configure_kv_cache
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
@@ -102,23 +103,13 @@ def patch_config(
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
     configure_packing(model_args, is_trainable)
-
-    if model_args.use_cache and not is_trainable:
-        setattr(config, "use_cache", True)
-        logger.info_rank0("Using KV cache for faster generation.")
-
-    if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache:
-        text_config = config.text_config
-        setattr(text_config, "use_cache", False)
+    configure_kv_cache(config, model_args, is_trainable)

     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)

-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
-        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
-
     if getattr(config, "model_type", None) == "minicpmo":
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)

src/llamafactory/webui/interface.py (+8 -21)

@@ -14,10 +14,8 @@

 import os
 import platform
-import httpx

-
-from ..extras.misc import is_env_enabled
+from ..extras.misc import fix_proxy, is_env_enabled
 from ..extras.packages import is_gradio_available
 from .common import save_config
 from .components import (
@@ -74,8 +72,9 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":

 def create_web_demo() -> "gr.Blocks":
     engine = Engine(pure_chat=True)
+    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

-    with gr.Blocks(title="Web Demo", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Factory Web Demo ({hostname})", css=CSS) as demo:
         lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], scale=1)
         engine.manager.add_elems("top", dict(lang=lang))

@@ -90,30 +89,18 @@ def create_web_demo() -> "gr.Blocks":


 def run_web_ui() -> None:
-    os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
     gradio_share = is_env_enabled("GRADIO_SHARE")
     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
-    httpx.HTTPCORE_OPTS = {"trust_env": False}
-
-    try:
-        demo = create_ui().queue()
-        demo.launch(
-            share=gradio_share,
-            server_name=server_name,
-            inbrowser=True,
-            prevent_thread_lock=False,
-            show_error=True,
-            quiet=True,
-            favicon_path=None
-        )
-    except Exception as e:
-        print(f"Error launching web UI: {str(e)}")
-        raise
+    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+    fix_proxy(ipv6_enabled=gradio_ipv6)
+    create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)


 def run_web_demo() -> None:
     gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
     gradio_share = is_env_enabled("GRADIO_SHARE")
     server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
+    print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+    fix_proxy(ipv6_enabled=gradio_ipv6)
     create_web_demo().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
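A hedged sketch of the slimmed-down launch path: set the Gradio-related environment variables that run_web_ui() reads, then call it (assumes llamafactory is installed; the values are examples):

import os

from llamafactory.webui.interface import run_web_ui

# Environment read inside run_web_ui() via is_env_enabled() / os.getenv(); values are examples.
os.environ["GRADIO_SHARE"] = "0"                # no public share link
os.environ["GRADIO_IPV6"] = "0"                 # keep the IPv4 default server name
os.environ["GRADIO_SERVER_NAME"] = "127.0.0.1"  # bind to localhost only

run_web_ui()  # prints the visit hint, calls fix_proxy(), then launches the Gradio app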
