Skip to content

Support Fuse Operation on MSCCPP DSL #547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jun 17, 2025
Merged
15 changes: 9 additions & 6 deletions python/mscclpp/language/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,48 +25,51 @@ def __init__(self, dst_rank: int, src_rank: int, channel_type: ChannelType):
get_program().add_channel(self)

def signal(self, tb: int, sync: SyncType = SyncType.none, relaxed=False):
if sync == SyncType.before or sync == SyncType.both:
if sync == SyncType.before:
sync_op = SyncOperation()
get_program().add_operation(self.src_rank, tb, sync_op)

tb_channel_ids = get_program().setup_channel(tb, self)
op = SignalOperation(tb_channel_ids, self.channel_type, relaxed)
get_program().add_operation(self.src_rank, tb, op)

if sync == SyncType.after or sync == SyncType.both:
if sync == SyncType.after:
sync_op = SyncOperation()
get_program().add_operation(self.src_rank, tb, sync_op)

def wait(self, tb: int, sync: SyncType = SyncType.none, relaxed=False):
if sync == SyncType.before or sync == SyncType.both:
if sync == SyncType.before:
sync_op = SyncOperation()
get_program().add_operation(self.src_rank, tb, sync_op)

tb_channel_ids = get_program().setup_channel(tb, self)
op = WaitOperation(tb_channel_ids, self.channel_type, relaxed)
get_program().add_operation(self.src_rank, tb, op)

if sync == SyncType.after or sync == SyncType.both:
if sync == SyncType.after:
sync_op = SyncOperation()
get_program().add_operation(self.src_rank, tb, sync_op)

def flush(self, tb: int, sync: SyncType = SyncType.none):
if self.channel_type != ChannelType.port:
raise RuntimeError(f"Flush operation is only supported for ChannelType.port.")

if sync == SyncType.before or sync == SyncType.both:
if sync == SyncType.before:
sync_op = SyncOperation()
get_program().add_operation(self.src_rank, tb, sync_op)

tb_channel_ids = get_program().setup_channel(tb, self)
op = FlushOperation(tb_channel_ids, self.channel_type)
get_program().add_operation(self.src_rank, tb, op)

if sync == SyncType.after or sync == SyncType.both:
if sync == SyncType.after:
sync_op = SyncOperation()
get_program().add_operation(self.src_rank, tb, sync_op)

def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
if self.channel_type != ChannelType.memory:
raise RuntimeError(f"Get operation is only supported for ChannelType.memory.")

if dst_chunk.rank != self.src_rank:
raise RuntimeError(
f"Source chunk rank {dst_chunk.rank} does not match current channel source rank {self.src_rank}."
Expand Down
1 change: 1 addition & 0 deletions python/mscclpp/language/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@


def JSON():
get_program().optimize_operations()
return get_program().to_json()
4 changes: 4 additions & 0 deletions python/mscclpp/language/internal/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def add_operation(self, tb: int, operation: BaseOperation):

self.threadblocks[tb].add_operation(operation)

def optimize_operations(self):
for tb in self.threadblocks:
tb.optimize_operations()

def to_json(self) -> dict:
return {
"id": self.id,
Expand Down
143 changes: 137 additions & 6 deletions python/mscclpp/language/internal/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ class SyncOperation(BaseOperation):
def __init__(self):
self.name = Instruction.nop

def __add__(self, other):
fused_operation = None
if isinstance(other, SyncOperation):
fused_operation = SyncOperation()

return fused_operation

def to_json(self):
result = {"name": self.name.value}
return result
Expand All @@ -59,6 +66,9 @@ def __init__(
self.src_buff = src_buff
self.dst_buff = dst_buff

def __add__(self, other):
return None

def to_json(self):
result = {"name": self.name.value}
result["src_buff"] = []
Expand All @@ -77,12 +87,30 @@ def __init__(self, channels_ids: List[int], channel_type: ChannelType, relaxed:
self.name = Instruction.relaxed_signal
else:
self.name = Instruction.signal
self.channel_ids = channels_ids
self.channel_ids = set(channels_ids)
self.channel_type = channel_type

def __add__(self, other):
fused_operation = None
if (
isinstance(other, SignalOperation)
and self.channel_type == other.channel_type
and self.name == other.name
and not self.channel_ids & other.channel_ids
):
fused_operation = SignalOperation(
channels_ids=self.channel_ids | other.channel_ids,
channel_type=self.channel_type,
relaxed=(self.name == Instruction.relaxed_signal),
)
if isinstance(other, SyncOperation):
fused_operation = self

return fused_operation

def to_json(self):
result = {"name": self.name.value}
result["channel_ids"] = self.channel_ids
result["channel_ids"] = list(self.channel_ids)
result["channel_type"] = self.channel_type.value
return result

Expand All @@ -94,12 +122,30 @@ def __init__(self, channels_ids: List[int], channel_type: ChannelType, relaxed:
self.name = Instruction.relaxed_wait
else:
self.name = Instruction.wait
self.channel_ids = channels_ids
self.channel_ids = set(channels_ids)
self.channel_type = channel_type

def __add__(self, other):
fused_operation = None
if (
isinstance(other, WaitOperation)
and self.name == other.name
and not self.channel_ids & other.channel_ids
and self.channel_type == other.channel_type
):
fused_operation = WaitOperation(
channels_ids=self.channel_ids | other.channel_ids,
channel_type=self.channel_type,
relaxed=(self.name == Instruction.relaxed_wait),
)
if isinstance(other, SyncOperation):
fused_operation = self

return fused_operation

def to_json(self):
result = {"name": self.name.value}
result["channel_ids"] = self.channel_ids
result["channel_ids"] = list(self.channel_ids)
result["channel_type"] = self.channel_type.value
return result

Expand All @@ -122,6 +168,9 @@ def __init__(self, rank: int, tb_list: List[int]):
self.name = Instruction.barrier
self.barrier_info = barrier_info

def __add__(self, other):
return None

def to_json(self):
result = {"name": self.name.value}
result["barrier_id"] = self.barrier_id
Expand All @@ -144,12 +193,28 @@ def __hash__(self):
class FlushOperation(BaseOperation):
def __init__(self, channels_ids: List[int], channel_type: ChannelType):
self.name = Instruction.flush
self.channel_ids = channels_ids
self.channel_ids = set(channels_ids)
self.channel_type = channel_type

def __add__(self, other):
fused_operation = None
if (
isinstance(other, FlushOperation)
and self.channel_type == other.channel_type
and not self.channel_ids & other.channel_ids
):
fused_operation = FlushOperation(
channels_ids=self.channel_ids | other.channel_ids,
channel_type=self.channel_type
)
if isinstance(other, SyncOperation):
fused_operation = self

return fused_operation

def to_json(self):
result = {"name": self.name.value}
result["channel_ids"] = self.channel_ids
result["channel_ids"] = list(self.channel_ids)
result["channel_type"] = self.channel_type.value
return result

Expand All @@ -169,6 +234,22 @@ def __init__(
self.channel_ids = channel_ids
self.channel_type = channel_type

def __add__(self, other):
fused_operation = None
if (
isinstance(other, GetOperation)
and self.src_buff[0].size == other.src_buff[0].size
and self.channel_type == other.channel_type
):
fused_operation = GetOperation(
src_buff=self.src_buff + other.src_buff,
dst_buff=self.dst_buff + other.dst_buff,
channel_ids=self.channel_ids + other.channel_ids,
channel_type=self.channel_type,
)

return fused_operation

def to_json(self):
result = {"name": self.name.value}
result["src_buff"] = []
Expand Down Expand Up @@ -217,6 +298,26 @@ def __init__(
self.channel_ids = channel_ids
self.channel_type = channel_type

def __add__(self, other):
fused_operation = None
if (
isinstance(other, PutOperation)
and self.name == Instruction.put or self.name == Instruction.put_with_signal or self.name == Instruction.put_with_signal_and_flush
and self.name == other.name
and self.src_buff[0].size == other.src_buff[0].size
and self.channel_type == other.channel_type
):
fused_operation = PutOperation(
src_buff=self.src_buff + other.src_buff,
dst_buff=self.dst_buff + other.dst_buff,
channel_ids=self.channel_ids + other.channel_ids,
channel_type=self.channel_type,
with_signal=self.with_signal,
with_signal_and_flush=self.with_signal_and_flush,
)

return fused_operation

def to_json(self):
result = {"name": self.name.value}
result["src_buff"] = []
Expand Down Expand Up @@ -267,6 +368,30 @@ def __init__(
self.channel_ids = channel_ids
self.channel_type = channel_type
self.reduce_operation = reduce_operation
self.packet = packet

def __add__(self, other):
fused_operation = None
if (
isinstance(other, ReduceOperation)
and (self.name == Instruction.reduce_copy or self.name == Instruction.reduce_copy_packet or self.name == Instruction.read_reduce_copy)
and self.name == other.name
and self.local_src_buff[0] == other.local_src_buff[0]
and self.local_dst_buff == other.local_dst_buff
and self.channel_type == other.channel_type
and self.reduce_operation == other.reduce_operation
):
fused_operation = ReduceOperation(
self.local_src_buff + other.local_src_buff[1:],
self.local_dst_buff,
remote_src_buff=self.remote_src_buff + other.remote_src_buff,
channel_ids=self.channel_ids + other.channel_ids,
channel_type=self.channel_type,
reduce_operation=self.reduce_operation,
packet=self.packet,
)

return fused_operation

def to_json(self):
result = {"name": self.name.value}
Expand Down Expand Up @@ -313,6 +438,9 @@ def __init__(
self.channel_type = channel_type
self.reduce_operation = reduce_operation

def __add__(self, other):
return None

def to_json(self):
result = {"name": self.name.value}
result["buffer_type"] = self.buffer_type.value
Expand Down Expand Up @@ -346,6 +474,9 @@ def __init__(
self.channel_type = channel_type
self.reduce_operation = reduce_operation

def __add__(self, other):
return None

def to_json(self):
result = {"name": self.name.value}
result["src_chunk"] = self.src_chunk.to_json()
Expand Down
26 changes: 26 additions & 0 deletions python/mscclpp/language/internal/optmizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from mscclpp.language.internal.operations import *

def fuse_instructions(operations):
operation_index = 0
fused_operations = []
while operation_index < len(operations):
next_operation_index = operation_index + 1
current_operation = operations[operation_index]
previous_operation = None
fused_operation = current_operation

while next_operation_index < len(operations):
next_operation = operations[next_operation_index]
fused_operation = current_operation + next_operation
if fused_operation is None:
if previous_operation is not None and previous_operation.name == Instruction.nop:
next_operation_index -=1
break
current_operation = fused_operation
previous_operation = next_operation
next_operation_index += 1

fused_operations.append(current_operation)
operation_index = next_operation_index

return fused_operations
4 changes: 4 additions & 0 deletions python/mscclpp/language/internal/threadblock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from mscclpp.language.internal.types import ChannelType, RemoteBuffer, BufferType
from mscclpp.language.internal.optmizer import *
from dataclasses import dataclass, field


Expand Down Expand Up @@ -58,6 +59,9 @@ def add_remote_buffer(self, remote_buffer: RemoteBuffer):
def add_operation(self, op):
self.ops.append(op)

def optimize_operations(self):
self.ops = fuse_instructions(self.ops)

def to_json(self) -> dict:
return {
"id": self.id,
Expand Down
4 changes: 4 additions & 0 deletions python/mscclpp/language/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def setup_remote_chunk(self, rank, tb, remote_chunk: RemoteBuffer):
def add_operation(self, rank, tb, operation):
self.gpus[rank].add_operation(tb, operation)

def optimize_operations(self):
for gpu in self.gpus:
gpu.optimize_operations()

def to_json(self):
json_obj = {
"name": self.name,
Expand Down
Loading
Loading