DeepSeek is a novel family of open-weight sparse Mixture-of-Experts (MoE) models from DeepSeek AI. DeepSeek-V3 features advanced techniques, including Multi-Head Latent Attention (MLA), finer-grained and shared experts, Multi-Token Prediction (MTP), and FP8 mixed precision, designed for enhanced efficiency and performance. The currently supported models are DeepSeek V3 (671B) and DeepSeek V2-Lite (16B).
Please note:
- MTP and FP8 mixed precision are not supported yet.
- To leverage MLA with Flash Attention, ensure you have the latest JAX version.
- The provided TPU configurations are examples and not mandatory.
You can train from scratch to generate a new checkpoint. Here is an example command to run pre-training with V3 on v5p-256:
```bash
python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} run_name=matmul_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=deepseek3-671b ici_fsdp_parallelism=128 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V3 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False dataset_type=synthetic
```
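The command assumes `BASE_OUTPUT_DIRECTORY` is already set. A minimal, hypothetical setup (the bucket name is a placeholder):

```bash
# Hypothetical value; point this at your own writable GCS bucket.
export BASE_OUTPUT_DIRECTORY=gs://your-bucket/deepseek-runs   # logs and checkpoints land here
```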
To get started, follow the instructions at HuggingFace (V3, V2-Lite) to download the model. Currently, for V3, please convert it from FP8 to BF16 using the script here. Once downloaded and converted to BF16:
- run `convert_deepseek_ckpt.py` to convert the checkpoint to a MaxText-compatible Orbax format for training and fine-tuning (an example invocation is sketched after this list).
- run `convert_deepseek_unscanned_ckpt.py` to convert the checkpoint to an unscanned Orbax version for decoding.
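A minimal sketch of the scanned conversion is below. The module path and flag names (`--base_model_path`, `--maxtext_model_path`, `--model_size`) are assumptions, so check each script's `--help` output for the exact interface in your MaxText version; the unscanned conversion follows the same pattern with `convert_deepseek_unscanned_ckpt.py`.

```bash
# Sketch only: module path and flag names are assumptions; verify with --help.
# base_model_path:    HuggingFace weights already converted to BF16
# maxtext_model_path: where the MaxText-compatible Orbax (scanned) checkpoint is written
python3 -m MaxText.convert_deepseek_ckpt \
  --base_model_path /path/to/deepseek-v3-bf16 \
  --maxtext_model_path gs://your-bucket/deepseek-v3-orbax \
  --model_size deepseek3-671b
```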
After you have a MaxText-compatible checkpoint, you can fine-tune it with different datasets.
Here is an example command to run general fine-tuning with V3 on v5p-256:
```bash
python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=matmul_fine_tuning per_device_batch_size=4 model_name=deepseek3-671b steps=5 max_target_length=1024 async_checkpointing=false tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V3 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False enable_checkpointing=true ici_expert_parallelism=128 ici_fsdp_parallelism=1
```
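This command also expects the dataset location and the converted checkpoint from the previous step. Hypothetical values (the `/0/items` suffix is an assumption; use whatever path the conversion script reports):

```bash
# Hypothetical paths; substitute your own dataset and converted checkpoint locations.
export DATASET_PATH=gs://your-bucket/datasets
export CONVERTED_CHECKPOINT=gs://your-bucket/deepseek-v3-orbax/0/items
```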
Here is an example command to run supervised fine-tuning with V3 on v5p-256. Supervised fine-tuning currently works only with HuggingFace conversational datasets. You can customize the dataset path with the `hf_path` config and provide your access token with the `hf_access_token` config.
```bash
python3 -m MaxText.sft_trainer MaxText/configs/sft.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=matmul_supervised_fine_tuning per_device_batch_size=4 model_name=deepseek3-671b steps=5 max_target_length=1024 async_checkpointing=false tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V3 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False enable_checkpointing=true ici_expert_parallelism=128 ici_fsdp_parallelism=1 dataset_type=hf
```
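To point the run at your own conversational dataset, append the HuggingFace configs to the command above; the dataset name below is only illustrative:

```bash
# Extra flags to append to the sft_trainer command above
# (dataset name is illustrative; use any HuggingFace conversational dataset you have access to).
hf_path=HuggingFaceH4/ultrachat_200k hf_access_token=${HF_TOKEN}
```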
Here is an example command to run decoding with V3 on v5p-256 using an unscanned checkpoint for fast decoding. When decoding with a supervised fine-tuned checkpoint, format your prompt as `prompt='<user>your text</user> <assistant>'`.
```bash
python3 -m MaxText.decode MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=decode per_device_batch_size=1 enable_checkpointing=false model_name=deepseek3-671b max_prefill_predict_length=100 max_target_length=1024 tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V3 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False ici_tensor_parallelism=128 ici_fsdp_parallelism=1 prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " scan_layers=False
```
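For a supervised fine-tuned checkpoint, only the prompt flag changes; for example (the question text is arbitrary):

```bash
# Replace the prompt flag in the decode command above when loading an SFT checkpoint.
prompt='<user>Explain the attention function in one paragraph.</user> <assistant>'
```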
We support the following MoE matmul implementations:
- Dropless
  - MegaBlocks implementation with flags `sparse_matmul=True megablox=True`.
  - JAX `ragged_dot` implementation with flags `sparse_matmul=True megablox=False`.
  - General dense matmul implementation with flags `sparse_matmul=False capacity_factor=-1`.
- Dropping implementation with flag `sparse_matmul=False` and a reasonable `capacity_factor`, commonly from 1 to 1.25 (see the example after this list).
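For example, to switch one of the training commands above to the dropping implementation, replace its `sparse_matmul`/`megablox` flags; the `capacity_factor` value here is just one point in the commonly used 1 to 1.25 range:

```bash
# Replace megablox=False sparse_matmul=False in the training command with:
sparse_matmul=False capacity_factor=1.25
```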