Skip to content

Commit ea55dc1

Browse files
committed
[gloo] Enable using c10::Half for gloo cuda
[ghstack-poisoned]
1 parent cc44198 commit ea55dc1

8 files changed

+13
-0
lines changed

gloo/cuda.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,12 @@ DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(BFloat16, cudaSum, +);
398398
DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(BFloat16, cudaProduct, *);
399399
DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(BFloat16, cudaMin, <);
400400
DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(BFloat16, cudaMax, >);
401+
using Half = c10::Half;
402+
INSTANTIATE_COPY_ASYNC(Half);
403+
DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(Half, cudaSum, +);
404+
DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(Half, cudaProduct, *);
405+
DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(Half, cudaMin, <);
406+
DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(Half, cudaMax, >);
401407
#endif
402408
403409
} // namespace gloo

gloo/cuda_allreduce_bcube.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ INSTANTIATE_TEMPLATE(float16);
516516

517517
#if GLOO_USE_TORCH_DTYPES
518518
INSTANTIATE_TEMPLATE(c10::BFloat16);
519+
INSTANTIATE_TEMPLATE(c10::Half);
519520
#endif
520521

521522
} // namespace gloo

gloo/cuda_allreduce_halving_doubling.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,7 @@ INSTANTIATE_TEMPLATE(float16);
659659

660660
#if GLOO_USE_TORCH_DTYPES
661661
INSTANTIATE_TEMPLATE(c10::BFloat16);
662+
INSTANTIATE_TEMPLATE(c10::Half);
662663
#endif
663664

664665
} // namespace gloo

gloo/cuda_allreduce_local.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ INSTANTIATE_TEMPLATE(float16);
7878

7979
#if GLOO_USE_TORCH_DTYPES
8080
INSTANTIATE_TEMPLATE(c10::BFloat16);
81+
INSTANTIATE_TEMPLATE(c10::Half);
8182
#endif
8283

8384
} // namespace gloo

gloo/cuda_allreduce_ring.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ INSTANTIATE_TEMPLATE(float16);
190190

191191
#if GLOO_USE_TORCH_DTYPES
192192
INSTANTIATE_TEMPLATE(c10::BFloat16);
193+
INSTANTIATE_TEMPLATE(c10::Half);
193194
#endif
194195

195196
} // namespace gloo

gloo/cuda_allreduce_ring_chunked.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ INSTANTIATE_TEMPLATE(float16);
367367

368368
#if GLOO_USE_TORCH_DTYPES
369369
INSTANTIATE_TEMPLATE(c10::BFloat16);
370+
INSTANTIATE_TEMPLATE(c10::Half);
370371
#endif
371372

372373
} // namespace gloo

gloo/cuda_broadcast_one_to_all.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ INSTANTIATE_TEMPLATE(float16);
199199

200200
#if GLOO_USE_TORCH_DTYPES
201201
INSTANTIATE_TEMPLATE(c10::BFloat16);
202+
INSTANTIATE_TEMPLATE(c10::Half);
202203
#endif
203204

204205
} // namespace gloo

gloo/cuda_private.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#if GLOO_USE_TORCH_DTYPES
2424
#include <c10/util/BFloat16.h>
25+
#include <c10/util/Half.h>
2526
#endif
2627

2728
namespace gloo {

0 commit comments

Comments
 (0)