Description
Describe the bug
Hi team:
In a distributed scenario, we found that when a Triton kernel runs independently in two processes, each on a device located on a different PVC card, a segmentation fault occurs. I have tested that even an empty kernel crashes in this setup, but if both devices are on the same card, it works fine. Additionally, the issue does not occur with SYCL kernels. I have simplified the problem into the reproducer below, which is easy to reproduce on a multi-card PVC setup.
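For context, a quick way to see which XPU devices are exposed under a given ZE_AFFINITY_MASK is a small check like the sketch below. This is not part of the reproducer; it assumes the torch.xpu device-query helpers available in recent PyTorch builds and uses the same mask as the failing case.

import os

# ZE_AFFINITY_MASK is read when the Level Zero driver initializes, so set it
# before anything touches the GPU. "1,2" is the failing mask from the reproducer.
os.environ["ZE_AFFINITY_MASK"] = "1,2"

import torch

# Each exposed tile shows up as one XPU device; list what is visible under the mask.
for i in range(torch.xpu.device_count()):
    print(f"xpu:{i} -> {torch.xpu.get_device_name(i)}")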
Reproducer:
import os
import multiprocessing

import torch
import triton
import triton.language as tl

# Tiles 0 and 1 are on the first card; with this mask the case passes.
# os.environ["ZE_AFFINITY_MASK"] = "0,1"
# Tile 1 is on the first card while tile 2 is on the second card;
# with this mask the case fails even with an empty Triton kernel.
os.environ["ZE_AFFINITY_MASK"] = "1,2"


@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # tl.device_print("pid=", pid)
    # block_start = pid * BLOCK_SIZE
    # offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # mask = offsets < n_elements
    # x = tl.load(x_ptr + offsets, mask=mask)
    # y = tl.load(y_ptr + offsets, mask=mask)
    # output = x + y
    # tl.store(output_ptr + offsets, output, mask=mask)


def test_triton(world_size, rank):
    with torch.xpu._DeviceGuard(rank):
        torch.xpu.set_device(rank)
        # Set the device for the given rank.
        device = torch.device(f"xpu:{rank}")
        print(f"device is {device} while rank is {rank}", flush=True)
        x = torch.rand(1024, device=device)
        y = torch.rand(1024, device=device)
        output = torch.rand(1024, device=device)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
        print(output)
        print(f"Rank {rank} exiting.", flush=True)


# If test_triton is replaced with test_sycl (see the driver variant after the
# reproducer), the test passes.
def test_sycl(world_size, rank):
    with torch.xpu._DeviceGuard(rank):
        torch.xpu.set_device(rank)
        device = torch.device(f"xpu:{rank}")
        print(f"device is {device} while rank is {rank}", flush=True)
        x = torch.rand(1024, device=device)
        y = torch.rand(1024, device=device)
        output = torch.rand(1024, device=device)
        n_elements = output.numel()
        torch.add(x, y, out=output)


if __name__ == "__main__":
    world_size = 2
    multiprocessing.set_start_method("spawn", force=True)
    processes = []
    for rank in range(world_size):
        p = multiprocessing.Process(
            target=test_triton,
            args=(world_size, rank)
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
        if p.exitcode != 0:
            raise RuntimeError(f"Process exited with code {p.exitcode}")
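For reference, the passing control run mentioned in the comment on test_sycl only changes the process target; a minimal sketch of that variant of the driver block (everything else in the reproducer stays the same):

if __name__ == "__main__":
    world_size = 2
    multiprocessing.set_start_method("spawn", force=True)
    processes = []
    for rank in range(world_size):
        # Same driver as above, but launching the eager/SYCL path, which the
        # comment above reports as passing even with the "1,2" affinity mask.
        p = multiprocessing.Process(target=test_sycl, args=(world_size, rank))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
        if p.exitcode != 0:
            raise RuntimeError(f"Process exited with code {p.exitcode}")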
Error message:
device is xpu:1 while rank is 1
device is xpu:0 while rank is 0
Segmentation fault from GPU at 0x800000100000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 2 (PDP), access: 0 (Read), banned: 1, aborting.
Segmentation fault from GPU at 0x800000100000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 2 (PDP), access: 0 (Read), banned: 1, aborting.
Abort was called at 274 line in file:
/opt/src/compute-neo/shared/source/os_interface/linux/drm_neo.cpp
tensor([0.8241, 0.6043, 0.2742, ..., 0.9886, 0.8789, 0.8474], device='xpu:0')
Rank 0 exiting.
Traceback (most recent call last):
File "/home/xinanlin/xinanlin/cherry.py", line 78, in <module>
raise RuntimeError(f"Process exited with code {p.exitcode}")
RuntimeError: Process exited with code -6
Environment details
Triton: cfb7d53
PyTorch: latest main (ff29791ed8f815bdbca1a5606de046380baca69d)