Skip to content

Commit a1b582e

Browse files
committed
Merge remote-tracking branch 'origin/main' into main_w8a8_fp8
2 parents 8c3dc13 + 27acf63 commit a1b582e

29 files changed

+543
-185
lines changed

.github/workflows/pr-test-sgl-kernel.yml

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,55 @@ jobs:
3030
clangFormatVersion: 16
3131
style: file
3232

33+
build-wheels:
34+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
35+
runs-on: ubuntu-latest
36+
strategy:
37+
matrix:
38+
python-version: ['3.10']
39+
cuda-version: ['12.4']
40+
41+
steps:
42+
- uses: actions/checkout@v4
43+
with:
44+
submodules: 'recursive'
45+
46+
- name: Set up Python ${{ matrix.python-version }}
47+
uses: actions/setup-python@v5
48+
with:
49+
python-version: ${{ matrix.python-version }}
50+
51+
- name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }}
52+
run: |
53+
cd sgl-kernel
54+
chmod +x ./build.sh
55+
./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}"
56+
57+
- name: Upload artifacts
58+
uses: actions/upload-artifact@v4
59+
with:
60+
name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }}
61+
path: sgl-kernel/dist/*
62+
3363
unit-test:
3464
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
65+
needs: build-wheels
3566
runs-on: 1-gpu-runner
3667
steps:
3768
- uses: actions/checkout@v4
3869

70+
- name: Download artifacts
71+
uses: actions/download-artifact@v4
72+
with:
73+
path: sgl-kernel/dist/
74+
merge-multiple: true
75+
pattern: wheel-*
76+
3977
- name: Install
4078
run: |
4179
pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1
4280
pip3 uninstall sgl-kernel -y || true
43-
find . -name index.lock -delete
44-
cd sgl-kernel
45-
git submodule deinit --all --force && git submodule sync --recursive && git submodule update --init --force --recursive
46-
pip3 install .
81+
pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
4782
pip3 list | grep sgl-kernel
4883
4984
- name: Run test

.github/workflows/pr-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ jobs:
4343
4444
- name: Run test
4545
timeout-minutes: 10
46+
env:
47+
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
4648
run: |
4749
cd test/lang
4850
python3 run_suite.py --suite per-commit
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
name: Release SGLang Kernel Wheel (cu118)
2+
3+
on:
4+
workflow_dispatch:
5+
inputs:
6+
tag_name:
7+
type: string
8+
push:
9+
branches:
10+
- main
11+
paths:
12+
- sgl-kernel/version.py
13+
14+
jobs:
15+
build-wheels:
16+
if: github.repository == 'sgl-project/sglang'
17+
runs-on: ubuntu-latest
18+
strategy:
19+
matrix:
20+
python-version: ['3.9', '3.10', '3.11', '3.12']
21+
cuda-version: ['11.8']
22+
23+
steps:
24+
- uses: actions/checkout@v4
25+
with:
26+
submodules: 'recursive'
27+
28+
- name: Set up Python ${{ matrix.python-version }}
29+
uses: actions/setup-python@v5
30+
with:
31+
python-version: ${{ matrix.python-version }}
32+
33+
- name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }}
34+
run: |
35+
cd sgl-kernel
36+
chmod +x ./build.sh
37+
./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}"
38+
39+
- name: Upload artifacts
40+
uses: actions/upload-artifact@v4
41+
with:
42+
name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }}
43+
path: sgl-kernel/dist/*
44+
45+
release:
46+
needs: build-wheels
47+
runs-on: ubuntu-latest
48+
steps:
49+
- uses: actions/checkout@v4
50+
51+
- name: Download artifacts
52+
uses: actions/download-artifact@v4
53+
with:
54+
path: sgl-kernel/dist/
55+
merge-multiple: true
56+
pattern: wheel-*
57+
58+
- name: Set tag name
59+
id: set_tag_name
60+
run: |
61+
if [ -z "${{ inputs.tag_name }}" ]; then
62+
TAG_NAME="v$(cat sgl-kernel/version.py | cut -d'"' -f2)"
63+
echo "tag_name=$TAG_NAME" >> $GITHUB_OUTPUT
64+
else
65+
echo "tag_name=${{ inputs.tag_name }}" >> $GITHUB_OUTPUT
66+
fi
67+
68+
- name: Release
69+
uses: softprops/action-gh-release@v2
70+
with:
71+
tag_name: ${{ steps.set_tag_name.outputs.tag_name }}
72+
repository: sgl-project/whl
73+
token: ${{ secrets.WHL_TOKEN }}
74+
files: |
75+
sgl-kernel/dist/*
76+
77+
- name: Clone wheel index
78+
run: git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl
79+
env:
80+
WHL_TOKEN: ${{ secrets.WHL_TOKEN }}
81+
82+
- name: Update wheel index
83+
run: python3 scripts/update_kernel_whl_index.py
84+
85+
- name: Push wheel index
86+
run: |
87+
cd sgl-whl
88+
git config --local user.name "github-actions[bot]"
89+
git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com"
90+
git add -A
91+
git commit -m "update whl index"
92+
git push

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
[submodule "sgl-kernel/3rdparty/flashinfer"]
88
path = sgl-kernel/3rdparty/flashinfer
99
url = https://github.com/flashinfer-ai/flashinfer.git
10+
[submodule "sgl-kernel/3rdparty/turbomind"]
11+
path = sgl-kernel/3rdparty/turbomind
12+
url = https://github.com/InternLM/turbomind

benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
import itertools
3-
import time
43

54
import torch
65
import triton

docs/references/supported_models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
- XVERSE / XVERSE MoE
2929
- SmolLM
3030
- GLM-4
31+
- Phi-3 / Phi-4
3132
- Phi-3-Small
3233
- IBM Granite 3
3334

python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
import torch
44

55
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
6-
from sglang.srt.utils import is_cuda_available
6+
from sglang.srt.utils import get_compiler_backend
77

8-
is_cuda = is_cuda_available()
9-
if is_cuda:
10-
from sgl_kernel import sampling_scaling_penalties
8+
9+
@torch.compile(dynamic=True, backend=get_compiler_backend())
def apply_scaling_penalties(logits, scaling_penalties):
    """Scale ``logits`` in place by ``scaling_penalties``.

    Elements greater than zero are divided by the penalty; all other elements
    are multiplied by it. The result is written back into ``logits`` and
    nothing is returned.
    """
    is_positive = logits > 0
    divided = logits / scaling_penalties
    multiplied = logits * scaling_penalties
    logits[:] = torch.where(is_positive, divided, multiplied)
1116

1217

1318
class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -61,16 +66,7 @@ def _cumulate_output_tokens(self, output_ids: _TokenIDs):
6166
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
6267

6368
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
    """Apply the cumulated repetition penalties to ``logits``.

    ``apply_scaling_penalties`` rewrites ``logits`` in place and has no return
    value, so the same tensor is returned explicitly to honour the declared
    ``torch.Tensor`` return type.
    """
    apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
    return logits
7470

7571
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
7672
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]

python/sglang/srt/sampling/sampling_batch_info.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,11 @@
77

88
import torch
99

10-
from sglang.srt.utils import is_cuda_available
11-
12-
is_cuda = is_cuda_available()
13-
if is_cuda:
14-
from sgl_kernel import sampling_scaling_penalties
15-
1610
import sglang.srt.sampling.penaltylib as penaltylib
1711
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
12+
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
13+
apply_scaling_penalties,
14+
)
1815

1916
logger = logging.getLogger(__name__)
2017

@@ -386,14 +383,7 @@ def apply_logits_bias(self, logits: torch.Tensor):
386383

387384
# repetition
388385
if self.scaling_penalties is not None:
389-
if is_cuda:
390-
logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
391-
else:
392-
logits[:] = torch.where(
393-
logits > 0,
394-
logits / self.scaling_penalties,
395-
logits * self.scaling_penalties,
396-
)
386+
apply_scaling_penalties(logits, self.scaling_penalties)
397387

398388
# Apply regex vocab_mask
399389
if self.vocab_mask is not None:

scripts/update_kernel_whl_index.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py
#
# Appends one <a> entry per wheel found in sgl-kernel/dist to the cu118
# index page inside the sgl-whl checkout. Each link points at the release
# asset for the wheel's version and carries the wheel's sha256 as a URL
# fragment.

import hashlib
import pathlib
import re

# Loop-invariant values: where the index page lives and the release URL prefix.
index_dir = pathlib.Path("sgl-whl/cu118/sgl-kernel")
base_url = "https://github.com/sgl-project/whl/releases/download"

for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")):
    sha256 = hashlib.sha256(path.read_bytes()).hexdigest()
    # The version is the first group after "sgl_kernel-" in the wheel file
    # name; a name that does not match raises IndexError here.
    ver = re.findall(r"sgl_kernel-([0-9.]+(?:\.post[0-9]+)?)-", path.name)[0]
    # parents=True so a missing intermediate "cu118" directory is created
    # instead of raising FileNotFoundError.
    index_dir.mkdir(parents=True, exist_ok=True)
    full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}"
    # Append mode: existing index entries are kept.
    with (index_dir / "index.html").open("a") as f:
        f.write(f'<a href="{full_url}">{path.name}</a><br>\n')

sgl-kernel/3rdparty/turbomind

Submodule turbomind added at 0c9d0c7

sgl-kernel/Makefile

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.PHONY: tree ln submodule install build clean test format
1+
.PHONY: tree ln submodule install build clean rebuild test format
22

33
tree:
44
@tree --prune -I "__pycache__|*.egg-info|*.so|build|3rdparty|dist"
@@ -13,11 +13,14 @@ install: submodule
1313
@pip install -e .
1414

1515
build: submodule
16-
@export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel
16+
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel && pip3 install dist/*whl --force-reinstall --no-deps
1717

1818
clean:
1919
@rm -rf build dist *.egg-info
2020

21+
rebuild: clean submodule build
22+
@echo "Succeed to rebuild"
23+
2124
test:
2225
@find tests -name "test_*.py" | xargs -n 1 python3
2326

sgl-kernel/README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
# SGL Kernel
22

3-
Kernel Library for SGLang
3+
[Kernel Library](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) for SGLang
44

55
[![PyPI](https://img.shields.io/pypi/v/sgl-kernel)](https://pypi.org/project/sgl-kernel)
6+
7+
## Installation
8+
9+
For CUDA 11.8:
10+
11+
```bash
12+
pip3 install sgl-kernel -i https://docs.sglang.ai/whl/cu118
13+
```
14+
15+
For CUDA 12.1 or CUDA 12.4:
16+
17+
```bash
18+
pip3 install sgl-kernel
19+
```

sgl-kernel/build.sh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ PYTHON_VERSION=$1
44
CUDA_VERSION=$2
55
PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}
66

7+
if (( ${CUDA_VERSION%.*} < 12 )); then
8+
ENABLE_SM90A=0
9+
else
10+
ENABLE_SM90A=1
11+
fi
12+
713
docker run --rm \
814
-v "$(pwd)":/sgl-kernel \
915
pytorch/manylinux-builder:cuda${CUDA_VERSION} \
@@ -13,7 +19,7 @@ docker run --rm \
1319
export CUDA_VERSION=${CUDA_VERSION} && \
1420
export SGL_KERNEL_ENABLE_BF16=1 && \
1521
export SGL_KERNEL_ENABLE_FP8=1 && \
16-
export SGL_KERNEL_ENABLE_SM90A=1 && \
22+
export SGL_KERNEL_ENABLE_SM90A=${ENABLE_SM90A} && \
1723
mkdir -p /usr/lib/x86_64-linux-gnu/ && \
1824
ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
1925
cd /sgl-kernel && \

sgl-kernel/developer_guide.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,25 @@ Third-party libraries:
1919
- [CCCL](https://github.com/NVIDIA/cccl)
2020
- [CUTLASS](https://github.com/NVIDIA/cutlass)
2121
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
22+
- [TurboMind](https://github.com/InternLM/turbomind)
2223

2324
### Kernel Development
2425

2526
Steps to add a new kernel:
2627

2728
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
28-
2. Expose interface in [csrc/sgl_kernel_ops.cu](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu) with pybind11
29-
3. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
30-
4. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
31-
5. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
29+
2. Expose interface in [src/sgl-kernel/include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernel_ops.h)
30+
3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
31+
4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
32+
5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
33+
6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
3234

3335
### Build & Install
3436

3537
Development build:
3638

3739
```bash
3840
make build
39-
pip3 install dist/*whl --force-reinstall --no-deps
40-
# Or use: make install (runs pip install -e .)
4141
```
4242

4343
### Testing & Benchmarking

sgl-kernel/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "sgl-kernel"
7-
version = "0.0.2.post16"
7+
version = "0.0.2.post17"
88
description = "Kernel Library for SGLang"
99
readme = "README.md"
1010
requires-python = ">=3.9"
@@ -17,7 +17,7 @@ classifiers = [
1717
dependencies = []
1818

1919
[project.urls]
20-
"Homepage" = "https://github.com/sgl-project/sglang"
20+
"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel"
2121
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
2222

2323
[tool.setuptools]

0 commit comments

Comments
 (0)