Skip to content

Support Fuse Operation on MSCCLPP DSL #547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jun 17, 2025
265 changes: 211 additions & 54 deletions python/mscclpp/language/channel.py

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions python/mscclpp/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,22 @@ def init_buffers(self):
}
rank_buffers.append(buffers)
return rank_buffers


class AllReduce(Collective):
    """Collective description for an all-reduce across ``num_ranks`` ranks."""

    def __init__(self, num_ranks, chunk_factor, inplace):
        Collective.__init__(self, num_ranks, chunk_factor, inplace)
        self.name = "allreduce"

    # Initializes the per-rank buffers for an allreduce: every rank gets an
    # input and an output buffer, each holding num_ranks * chunk_factor chunks.
    def init_buffers(self):
        rank_buffers = []
        for rank in range(self.num_ranks):
            input_buffer_size = self.num_ranks * self.chunk_factor
            output_buffer_size = self.num_ranks * self.chunk_factor
            buffers = {
                BufferType.input: BaseBuffer(rank, BufferType.input, 0, input_buffer_size),
                BufferType.output: BaseBuffer(rank, BufferType.output, 0, output_buffer_size),
            }
            rank_buffers.append(buffers)
        return rank_buffers
1 change: 1 addition & 0 deletions python/mscclpp/language/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@


def JSON():
    """Serialize the current program to JSON.

    Fuses/optimizes the program's operations first, then returns the
    JSON representation produced by the program object.
    """
    program = get_program()
    program.optimize_operations()
    return program.to_json()
10 changes: 7 additions & 3 deletions python/mscclpp/language/internal/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def setup_channel(self, tb: int, channel) -> int:
return self.threadblocks[tb].add_channel(channel)

def add_remote_buffer(self, tb: int, remote_buffer: RemoteBuffer) -> int:
if (remote_buffer.rank, remote_buffer.type) not in self.remote_buffers:
if (remote_buffer.remote_rank, remote_buffer.type) not in self.remote_buffers:
remote_buffer.set_id()
self.remote_buffers[(remote_buffer.rank, remote_buffer.type)] = remote_buffer
self.remote_buffers[(remote_buffer.remote_rank, remote_buffer.type)] = remote_buffer
else:
gpu_remote_buffer = self.remote_buffers[(remote_buffer.rank, remote_buffer.type)]
gpu_remote_buffer = self.remote_buffers[(remote_buffer.remote_rank, remote_buffer.type)]
gpu_remote_buffer.channel_access.update(remote_buffer.channel_access)
remote_buffer = gpu_remote_buffer

Expand All @@ -53,6 +53,10 @@ def add_operation(self, tb: int, operation: BaseOperation):

self.threadblocks[tb].add_operation(operation)

def optimize_operations(self):
for tb in self.threadblocks:
tb.optimize_operations()

def to_json(self) -> dict:
return {
"id": self.id,
Expand Down
Loading