Skip to content

Commit fe15276

Browse files
committed
gloo/cuda: use torch dtype bf16
1 parent 7a4c857 commit fe15276

11 files changed

+46
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ if(${USE_TCP_OPENSSL_LINK} AND ${USE_TCP_OPENSSL_LOAD})
4242
endif()
4343
# CUDA-related build switches. All are boolean cache options and default to OFF.

# Master switch for building the CUDA parts of the project.
option(USE_CUDA "Build with CUDA support" OFF)

# Select the FindCUDAToolkit.cmake / enable_language(CUDA) code path for CUDA builds.
option(GLOO_USE_CUDA_TOOLKIT "Build CUDA with FindCUDAToolkit.cmake and enable_language(CUDA)" OFF)

# Build the CUDA kernels for PyTorch dtypes as well. When enabled, the gloo_cuda
# target takes its extra include path from GLOO_TORCH_DIR.
option(GLOO_USE_TORCH_DTYPES "Build CUDA kernels with pytorch dtypes" OFF)
4546

4647
if(MSVC)
4748
message(STATUS "MSVC detected")

gloo/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,12 @@ target_link_libraries(gloo PRIVATE ${gloo_DEPENDENCY_LIBS})
171171
target_include_directories(gloo INTERFACE $<INSTALL_INTERFACE:include>)
172172
# Include-directory setup for the CUDA library target; skipped when CUDA is off.
if(USE_CUDA)
  # Consumers of an installed gloo_cuda pick up headers from <prefix>/include.
  target_include_directories(gloo_cuda INTERFACE $<INSTALL_INTERFACE:include>)

  # Report the torch-dtype configuration at configure time.
  message(STATUS "GLOO_USE_TORCH_DTYPES : ${GLOO_USE_TORCH_DTYPES} ${GLOO_TORCH_DIR}")

  if(GLOO_USE_TORCH_DTYPES)
    if(GLOO_TORCH_DIR)
      # PRIVATE: only needed to compile gloo_cuda itself, not propagated to
      # targets that link against it.
      target_include_directories(gloo_cuda PRIVATE ${GLOO_TORCH_DIR})
    else()
      # Without GLOO_TORCH_DIR no include path is added here. Warn rather than
      # fail: the header may still be reachable through include paths that are
      # configured elsewhere.
      message(WARNING
        "GLOO_USE_TORCH_DTYPES is enabled but GLOO_TORCH_DIR is not set; "
        "<c10/util/BFloat16.h> must be reachable through another include path "
        "or the gloo_cuda build will fail.")
    endif()
  endif()
endif()
175181
if(USE_ROCM)
176182
target_include_directories(gloo_hip INTERFACE $<INSTALL_INTERFACE:include>)

gloo/config.h.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,5 @@ static_assert(
3737
#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP_TLS
3838
#cmakedefine01 GLOO_HAVE_TRANSPORT_IBVERBS
3939
#cmakedefine01 GLOO_HAVE_TRANSPORT_UV
40+
41+
// Configured by CMake to 1 when the GLOO_USE_TORCH_DTYPES variable is true, else 0.
#cmakedefine01 GLOO_USE_TORCH_DTYPES

gloo/cuda.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,4 +391,13 @@ DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(double, cudaMax, >);
391391
DELEGATE_HALF_PRECISION_CUDA_BINARY_COMPARE(cudaMin, <);
392392
DELEGATE_HALF_PRECISION_CUDA_BINARY_COMPARE(cudaMax, >);
393393
394+
// Optional bfloat16 support: everything in this guard is compiled only when
// GLOO_USE_TORCH_DTYPES is non-zero.
#if GLOO_USE_TORCH_DTYPES
395+
// Unqualified alias for c10::BFloat16, used as the type argument of the macros below.
// NOTE(review): presumably needed because those macros paste the type name into
// identifiers, which a qualified name would break — confirm against the macro definitions.
using BFloat16 = c10::BFloat16;
396+
// Same macro set that is applied to the other element types earlier in this file:
// async copy plus the sum / product / min / max element-wise operations.
INSTANTIATE_COPY_ASYNC(BFloat16);
397+
DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(BFloat16, cudaSum, +);
398+
DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(BFloat16, cudaProduct, *);
399+
DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(BFloat16, cudaMin, <);
400+
DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(BFloat16, cudaMax, >);
401+
#endif
402+
394403
} // namespace gloo

gloo/cuda_allreduce_bcube.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,4 +514,8 @@ INSTANTIATE_TEMPLATE(float);
514514
INSTANTIATE_TEMPLATE(double);
515515
INSTANTIATE_TEMPLATE(float16);
516516

517+
// Additional instantiation for c10::BFloat16, compiled only when
// GLOO_USE_TORCH_DTYPES is non-zero.
// NOTE(review): INSTANTIATE_TEMPLATE is defined outside this view; presumably it
// explicitly instantiates this file's bcube allreduce template — confirm.
#if GLOO_USE_TORCH_DTYPES
518+
INSTANTIATE_TEMPLATE(c10::BFloat16);
519+
#endif
520+
517521
} // namespace gloo

gloo/cuda_allreduce_halving_doubling.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,4 +657,8 @@ INSTANTIATE_TEMPLATE(float);
657657
INSTANTIATE_TEMPLATE(double);
658658
INSTANTIATE_TEMPLATE(float16);
659659

660+
// Additional instantiation for c10::BFloat16, compiled only when
// GLOO_USE_TORCH_DTYPES is non-zero.
// NOTE(review): INSTANTIATE_TEMPLATE is defined outside this view; presumably it
// explicitly instantiates this file's halving-doubling allreduce template — confirm.
#if GLOO_USE_TORCH_DTYPES
661+
INSTANTIATE_TEMPLATE(c10::BFloat16);
662+
#endif
663+
660664
} // namespace gloo

gloo/cuda_allreduce_local.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,8 @@ INSTANTIATE_TEMPLATE(float);
7676
INSTANTIATE_TEMPLATE(double);
7777
INSTANTIATE_TEMPLATE(float16);
7878

79+
// Additional instantiation for c10::BFloat16, compiled only when
// GLOO_USE_TORCH_DTYPES is non-zero.
// NOTE(review): INSTANTIATE_TEMPLATE is defined outside this view; presumably it
// explicitly instantiates this file's local allreduce template — confirm.
#if GLOO_USE_TORCH_DTYPES
80+
INSTANTIATE_TEMPLATE(c10::BFloat16);
81+
#endif
82+
7983
} // namespace gloo

gloo/cuda_allreduce_ring.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,4 +188,8 @@ INSTANTIATE_TEMPLATE(float);
188188
INSTANTIATE_TEMPLATE(double);
189189
INSTANTIATE_TEMPLATE(float16);
190190

191+
// Additional instantiation for c10::BFloat16, compiled only when
// GLOO_USE_TORCH_DTYPES is non-zero.
// NOTE(review): INSTANTIATE_TEMPLATE is defined outside this view; presumably it
// explicitly instantiates this file's ring allreduce template — confirm.
#if GLOO_USE_TORCH_DTYPES
192+
INSTANTIATE_TEMPLATE(c10::BFloat16);
193+
#endif
194+
191195
} // namespace gloo

gloo/cuda_allreduce_ring_chunked.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,4 +365,8 @@ INSTANTIATE_TEMPLATE(float);
365365
INSTANTIATE_TEMPLATE(double);
366366
INSTANTIATE_TEMPLATE(float16);
367367

368+
// Additional instantiation for c10::BFloat16, compiled only when
// GLOO_USE_TORCH_DTYPES is non-zero.
// NOTE(review): INSTANTIATE_TEMPLATE is defined outside this view; presumably it
// explicitly instantiates this file's chunked ring allreduce template — confirm.
#if GLOO_USE_TORCH_DTYPES
369+
INSTANTIATE_TEMPLATE(c10::BFloat16);
370+
#endif
371+
368372
} // namespace gloo

gloo/cuda_broadcast_one_to_all.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,8 @@ INSTANTIATE_TEMPLATE(float);
197197
INSTANTIATE_TEMPLATE(double);
198198
INSTANTIATE_TEMPLATE(float16);
199199

200+
// Additional instantiation for c10::BFloat16, compiled only when
// GLOO_USE_TORCH_DTYPES is non-zero.
// NOTE(review): INSTANTIATE_TEMPLATE is defined outside this view; presumably it
// explicitly instantiates this file's one-to-all broadcast template — confirm.
#if GLOO_USE_TORCH_DTYPES
201+
INSTANTIATE_TEMPLATE(c10::BFloat16);
202+
#endif
203+
200204
} // namespace gloo

gloo/cuda_private.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
#include "gloo/cuda.h"
2121
#include "gloo/transport/device.h"
2222

23+
// Pull in the c10::BFloat16 definition only for builds where GLOO_USE_TORCH_DTYPES
// is non-zero.
// NOTE(review): relies on the macro already being defined by an earlier include
// (an undefined macro evaluates to 0 in #if and silently skips this) — confirm.
#if GLOO_USE_TORCH_DTYPES
24+
#include <c10/util/BFloat16.h>
25+
#endif
26+
2327
namespace gloo {
2428

2529
#define CUDA_CHECK(condition) \

0 commit comments

Comments
 (0)