Skip to content

Add support for F16 in linalg::transpose #2672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: branch-25.08
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion cpp/include/raft/linalg/transpose.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ template <typename T, typename IndexType, typename LayoutPolicy, typename Access
auto transpose(raft::resources const& handle,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
-> std::enable_if_t<std::is_floating_point_v<T>, void>
-> std::enable_if_t<std::is_floating_point_v<T> || std::is_same_v<T, half>, void>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the primitives in raft only support 32-bit types, and I think it's really great we're starting to add support for half precision. Though I think it'll be a good idea to let the user know (in the doxygen docs) which primitives support half AND full-precision. Can you add a quick note to the doc for this transpose() function ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @cjnolet for reviewing the code. I added a quick note.

{
RAFT_EXPECTS(out.extent(0) == in.extent(1), "Invalid shape for transpose.");
RAFT_EXPECTS(out.extent(1) == in.extent(0), "Invalid shape for transpose.");
Expand Down
10 changes: 3 additions & 7 deletions cpp/tests/linalg/transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@

#include <type_traits>

namespace std {
template <>
struct is_floating_point<half> : std::true_type {};
} // namespace std

namespace raft {
namespace linalg {

Expand Down Expand Up @@ -241,7 +236,7 @@ namespace transpose_extra_test {
template <typename T, typename IndexType, typename LayoutPolicy>
[[nodiscard]] auto transpose(raft::resources const& handle,
device_matrix_view<T, IndexType, LayoutPolicy> in)
-> std::enable_if_t<std::is_floating_point_v<T> &&
-> std::enable_if_t<(std::is_floating_point_v<T> || std::is_same_v<T, half>) &&
(std::is_same_v<LayoutPolicy, layout_c_contiguous> ||
std::is_same_v<LayoutPolicy, layout_f_contiguous>),
device_matrix<T, IndexType, LayoutPolicy>>
Expand All @@ -266,7 +261,8 @@ template <typename T, typename IndexType, typename LayoutPolicy>
template <typename T, typename IndexType>
[[nodiscard]] auto transpose(raft::resources const& handle,
device_matrix_view<T, IndexType, layout_stride> in)
-> std::enable_if_t<std::is_floating_point_v<T>, device_matrix<T, IndexType, layout_stride>>
-> std::enable_if_t<std::is_floating_point_v<T> || std::is_same_v<T, half>,
device_matrix<T, IndexType, layout_stride>>
{
matrix_extent<size_t> exts{in.extent(1), in.extent(0)};
using policy_type =
Expand Down