
[Multiple Card] Segmentation fault when running a Triton kernel on two devices that are on different cards #3641

Closed
@etaf

Description

Describe the bug

Hi team,
In a distributed scenario, we found that when a Triton kernel runs independently on two devices located on different PVC cards, a segmentation fault occurs. Even an empty kernel crashes in this setup, but if both devices are on the same card, it works fine. The issue does not occur with SYCL kernels. I have simplified the problem into the reproducer below, which is easy to reproduce on a multi-card PVC setup.

Reproducer:

import os
import multiprocessing
import torch
import triton
import triton.language as tl
 
# Tiles 0 and 1 are on the first card; with this mask the case passes.
# os.environ["ZE_AFFINITY_MASK"] = "0,1"

# Tile 1 is on the first card, while tile 2 is on the second card.
# With this mask the case fails even with an empty Triton kernel.
os.environ["ZE_AFFINITY_MASK"] = "1,2"

@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
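    # The rest of the kernel body is intentionally left commented out:
    # even this effectively empty kernel reproduces the segfault.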
    # tl.device_print("pid=", pid)
    # block_start = pid * BLOCK_SIZE
    # offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # mask = offsets < n_elements
    # x = tl.load(x_ptr + offsets, mask=mask)
    # y = tl.load(y_ptr + offsets, mask=mask)
    # output = x + y
    # tl.store(output_ptr + offsets, output, mask=mask)

def test_triton(world_size, rank):
    with torch.xpu._DeviceGuard(rank):
        torch.xpu.set_device(rank)
        # Set the device for the given rank.
        device = torch.device(f"xpu:{rank}")
        print(f"device is {device} while rank is {rank}", flush=True)
 
        x = torch.rand(1024, device=device)
        y = torch.rand(1024, device=device)
        output = torch.rand(1024, device=device)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
        print(output)

    print(f"Rank {rank} exiting.", flush=True)


# If test_triton is replaced with test_sycl, the test passes.
def test_sycl(world_size, rank):
    with torch.xpu._DeviceGuard(rank):
        torch.xpu.set_device(rank)
        device = torch.device(f"xpu:{rank}")
        print(f"device is {device} while rank is {rank}", flush=True)
        x = torch.rand(1024, device=device)
        y = torch.rand(1024, device=device)
        output = torch.rand(1024, device=device)
        n_elements = output.numel()
        torch.add(x, y, out=output)
 
 
if __name__ == "__main__":
    world_size = 2
 
    multiprocessing.set_start_method("spawn", force=True)
    processes = []
    for rank in range(world_size):
        p = multiprocessing.Process(
            target=test_triton,
            args=(world_size, rank)
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
        if p.exitcode != 0:
            raise RuntimeError(f"Process exited with code {p.exitcode}")

Error message:
device is xpu:1 while rank is 1
device is xpu:0 while rank is 0
Segmentation fault from GPU at 0x800000100000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 2 (PDP), access: 0 (Read), banned: 1, aborting.
Segmentation fault from GPU at 0x800000100000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 2 (PDP), access: 0 (Read), banned: 1, aborting.
Abort was called at 274 line in file:
/opt/src/compute-neo/shared/source/os_interface/linux/drm_neo.cpp
tensor([0.8241, 0.6043, 0.2742,  ..., 0.9886, 0.8789, 0.8474], device='xpu:0')
Rank 0 exiting.
Traceback (most recent call last):
  File "/home/xinanlin/xinanlin/cherry.py", line 78, in <module>
    raise RuntimeError(f"Process exited with code {p.exitcode}")
RuntimeError: Process exited with code -6

Environment details

Triton: cfb7d53
PyTorch: latest main (ff29791ed8f815bdbca1a5606de046380baca69d)
