Skip to content

Commit ffa13fc

Browse files
committed
wip
1 parent 3223b51 commit ffa13fc

16 files changed

+23
-25
lines changed

python/mscclpp/language/channel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
@dataclass
9-
class Channel:
9+
class MemoryChannel:
1010
__channel_counts = defaultdict(int)
1111

1212
def __init__(self, dst_rank: int, src_rank: int):
@@ -16,8 +16,8 @@ def __init__(self, dst_rank: int, src_rank: int):
1616
if dst_rank >= num_ranks:
1717
raise RuntimeError(f"Destination rank {dst_rank} is out of bounds. Number of ranks: {num_ranks}")
1818

19-
self.channel_id = Channel.__channel_counts[src_rank]
20-
Channel.__channel_counts[src_rank] += 1
19+
self.channel_id = MemoryChannel.__channel_counts[src_rank]
20+
MemoryChannel.__channel_counts[src_rank] += 1
2121

2222
self.dst_rank = dst_rank
2323
self.src_rank = src_rank

python/mscclpp/language/tests/unit_tests/get/get_fuse_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_test(num_threads_per_block, min_message_size, max_message_size):
2929
if src_rank != dst_rank:
3030
rank = Rank(dst_rank)
3131
dst_buff = rank.get_input_buffer()
32-
ch = Channel(dst_rank, src_rank)
32+
ch = MemoryChannel(dst_rank, src_rank)
3333
ch.signal(tb=0, relaxed=True)
3434
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
3535
ch.get(src_buff[0:1], dst_buff[1:2], tb=0)

python/mscclpp/language/tests/unit_tests/get/get_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_test(num_threads_per_block, min_message_size, max_message_size):
2929
if src_rank != dst_rank:
3030
rank = Rank(dst_rank)
3131
dst_buff = rank.get_input_buffer()
32-
ch = Channel(dst_rank, src_rank)
32+
ch = MemoryChannel(dst_rank, src_rank)
3333
ch.signal(tb=0, relaxed=True)
3434
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
3535
ch.get(src_buff[0:1], dst_buff[1:2], tb=0)

python/mscclpp/language/tests/unit_tests/put/put_fuse_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,11 @@ def put_test(num_threads_per_block, min_message_size, max_message_size):
2929
if src_rank != dst_rank:
3030
rank = Rank(dst_rank)
3131
dst_buff = rank.get_input_buffer()
32-
output_buff = rank.get_output_buffer()
33-
ch = Channel(dst_rank, src_rank)
32+
ch = MemoryChannel(dst_rank, src_rank)
3433
ch.signal(tb=0, relaxed=True)
3534
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
3635
ch.put(dst_buff[1:2], src_buff[0:1], tb=0)
37-
ch = Channel(dst_rank, src_rank)
36+
ch = MemoryChannel(dst_rank, src_rank)
3837
ch.put(dst_buff[0:1], src_buff[1:2], tb=0)
3938
ch.signal(tb=0, data_sync=SyncType.before)
4039
ch.wait(tb=0, data_sync=SyncType.after)

python/mscclpp/language/tests/unit_tests/put/put_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ def put_test(num_threads_per_block, min_message_size, max_message_size):
2929
if src_rank != dst_rank:
3030
rank = Rank(dst_rank)
3131
dst_buff = rank.get_input_buffer()
32-
output_buff = rank.get_output_buffer()
33-
ch = Channel(dst_rank, src_rank)
32+
ch = MemoryChannel(dst_rank, src_rank)
3433
ch.signal(tb=0, relaxed=True)
3534
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
3635
ch.put(dst_buff[1:2], src_buff[0:1], tb=0)

python/mscclpp/language/tests/unit_tests/put_packet_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def put_packet_test(num_threads_per_block, min_message_size, max_message_size):
2828
for dst_rank in range(gpus):
2929
if src_rank != dst_rank:
3030
dst_buff = Buffer(dst_rank, 1)
31-
ch = Channel(dst_rank, src_rank)
31+
ch = MemoryChannel(dst_rank, src_rank)
3232
ch.put_packet(dst_buff[0:1], src_buff[0:1], tb=0)
3333

3434
print(JSON())

python/mscclpp/language/tests/unit_tests/read_put_packet_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def read_put_packet_test(num_threads_per_block, min_message_size, max_message_si
2929
for src_rank in range(gpus):
3030
for dst_rank in range(gpus):
3131
if src_rank != dst_rank:
32-
ch = Channel(dst_rank, src_rank)
32+
ch = MemoryChannel(dst_rank, src_rank)
3333
ch.put_packet(
3434
scratch_buffers[dst_rank][1:2], scratch_buffers[src_rank][0:1], tb=0, from_packet=True
3535
)

python/mscclpp/language/tests/unit_tests/read_reduce/read_reduce_fuse_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def read_reduce_test(num_threads_per_block, min_message_size, max_message_size):
3030
if src_rank != dst_rank:
3131
peer_rank = Rank(dst_rank)
3232
peer_input_buff = peer_rank.get_input_buffer()
33-
ch = Channel(dst_rank, src_rank)
33+
ch = MemoryChannel(dst_rank, src_rank)
3434
ch.signal(tb=0, relaxed=True)
3535
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
3636
ch.reduce(input_buff[0:1], [peer_input_buff[0:1]], tb=0, local_dst_chunk=output_buff[0:1])

python/mscclpp/language/tests/unit_tests/read_reduce/read_reduce_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def read_reduce_test(num_threads_per_block, min_message_size, max_message_size):
3030
if src_rank != dst_rank:
3131
peer_rank = Rank(dst_rank)
3232
peer_input_buff = peer_rank.get_input_buffer()
33-
ch = Channel(dst_rank, src_rank)
33+
ch = MemoryChannel(dst_rank, src_rank)
3434
ch.signal(tb=0, relaxed=True)
3535
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
3636
ch.reduce(input_buff[0:1], [peer_input_buff[1:2]], tb=0, local_dst_chunk=output_buff[0:1])

python/mscclpp/language/tests/unit_tests/read_reduce_send_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def put_test(num_threads_per_block, min_message_size, max_message_size):
3131
peer_rank = Rank(dst_rank)
3232
peer_input_buff = peer_rank.get_input_buffer()
3333
peer_output_buff = peer_rank.get_output_buffer()
34-
ch = Channel(dst_rank, src_rank)
34+
ch = MemoryChannel(dst_rank, src_rank)
3535
ch.signal(tb=0, relaxed=True)
3636
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
3737
ch.reduce(input_buff[0:1], [peer_input_buff[1:2]], tb=0, local_dst_chunk=output_buff[0:1])

python/mscclpp/language/tests/unit_tests/reduce_send_packet_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def reduce_send_packet_test(num_threads_per_block, min_message_size, max_message
3030
rank = Rank(src_rank)
3131
for dst_rank in range(gpus):
3232
if src_rank != dst_rank:
33-
ch = Channel(dst_rank, src_rank)
33+
ch = MemoryChannel(dst_rank, src_rank)
3434
rank.reduce(
3535
scratch_buffers[src_rank][0:1],
3636
[scratch_buffers[src_rank][1:2]],

python/mscclpp/language/tests/unit_tests/reduce_send_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def reduce_send_test(num_threads_per_block, min_message_size, max_message_size):
3030
if src_rank != dst_rank:
3131
peer_rank = Rank(dst_rank)
3232
peer_output_buff = peer_rank.get_output_buffer()
33-
ch = Channel(dst_rank, src_rank)
33+
ch = MemoryChannel(dst_rank, src_rank)
3434
ch.signal(tb=0, relaxed=True)
3535
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
3636
rank.reduce(input_buff[0:1], [input_buff[1:2]], tb=0, dst_chunk=output_buff[0:1])

python/mscclpp/language/tests/unit_tests/signal_wait/relaxed_signal_wait_fuse_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ def signal_wait_test(num_threads_per_block, min_message_size, max_message_size):
2525
for src_rank in range(gpus):
2626
for dst_rank in range(gpus):
2727
if src_rank != dst_rank:
28-
ch = Channel(dst_rank, src_rank)
28+
ch = MemoryChannel(dst_rank, src_rank)
2929
ch.signal(tb=0, data_sync=SyncType.before, relaxed=True)
3030
ch.signal(tb=0, data_sync=SyncType.before, relaxed=True)
31-
ch = Channel(dst_rank, src_rank)
31+
ch = MemoryChannel(dst_rank, src_rank)
3232
ch.signal(tb=0, data_sync=SyncType.before, relaxed=True)
3333
ch.signal(tb=0, data_sync=SyncType.before, relaxed=True)
3434

3535
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
3636
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
37-
ch = Channel(dst_rank, src_rank)
37+
ch = MemoryChannel(dst_rank, src_rank)
3838
ch.wait(tb=0, data_sync=SyncType.after, relaxed=True)
3939
ch.wait(tb=0, data_sync=SyncType.before, relaxed=True)
4040

python/mscclpp/language/tests/unit_tests/signal_wait/relaxed_signal_wait_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def relaxed_signal_wait_test(num_threads_per_block, min_message_size, max_messag
2525
for src_rank in range(gpus):
2626
for dst_rank in range(gpus):
2727
if src_rank != dst_rank:
28-
ch = Channel(dst_rank, src_rank)
28+
ch = MemoryChannel(dst_rank, src_rank)
2929
ch.signal(tb=0, relaxed=True)
3030
ch.wait(tb=0, relaxed=True)
3131

python/mscclpp/language/tests/unit_tests/signal_wait/signal_wait_fuse_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ def signal_wait_test(num_threads_per_block, min_message_size, max_message_size):
2525
for src_rank in range(gpus):
2626
for dst_rank in range(gpus):
2727
if src_rank != dst_rank:
28-
ch = Channel(dst_rank, src_rank)
28+
ch = MemoryChannel(dst_rank, src_rank)
2929
ch.signal(tb=0, data_sync=SyncType.before)
3030
ch.signal(tb=0, data_sync=SyncType.before)
31-
ch = Channel(dst_rank, src_rank)
31+
ch = MemoryChannel(dst_rank, src_rank)
3232
ch.signal(tb=0, data_sync=SyncType.before)
3333
ch.signal(tb=0, data_sync=SyncType.before)
3434

3535
ch.wait(tb=0, data_sync=SyncType.after)
3636
ch.wait(tb=0, data_sync=SyncType.after)
37-
ch = Channel(dst_rank, src_rank)
37+
ch = MemoryChannel(dst_rank, src_rank)
3838
ch.wait(tb=0, data_sync=SyncType.after)
3939
ch.wait(tb=0, data_sync=SyncType.before)
4040

python/mscclpp/language/tests/unit_tests/signal_wait/signal_wait_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def signal_wait_test(num_threads_per_block, min_message_size, max_message_size):
2525
for src_rank in range(gpus):
2626
for dst_rank in range(gpus):
2727
if src_rank != dst_rank:
28-
ch = Channel(dst_rank, src_rank)
28+
ch = MemoryChannel(dst_rank, src_rank)
2929
ch.signal(tb=0, data_sync=SyncType.before)
3030
ch.wait(tb=0, data_sync=SyncType.after)
3131

0 commit comments

Comments
 (0)