Skip to content

Commit 72b20b1

Browse files
committed
tag doubles
1 parent 6ba1672 commit 72b20b1

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

src/device/intrinsics/output.jl

+31-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ const __METAL_OS_LOG_TYPE_DEFAULT__ = Int32(0)
77
const __METAL_OS_LOG_TYPE_ERROR__ = Int32(16)
88
const __METAL_OS_LOG_TYPE_FAULT__ = Int32(17)
99

10+
const ALLOW_DOUBLE_META = "allowdouble"
11+
1012
export @mtlprintf
1113

1214
@generated function promote_c_argument(arg)
@@ -18,13 +20,39 @@ export @mtlprintf
1820

1921
if arg == Cchar || arg == Cshort
2022
return :(Cint(arg))
21-
elseif arg == Cfloat
22-
return :(Cdouble(arg))
2323
else
2424
return :(arg)
2525
end
2626
end
2727

28+
@generated function tag_doubles(arg)
29+
@dispose ctx=Context() begin
30+
ret = arg == Cfloat ? Cdouble : arg
31+
T_arg = convert(LLVMType, arg)
32+
T_ret = convert(LLVMType, ret)
33+
34+
f, ft = create_function(T_ret, [T_arg])
35+
36+
@dispose builder=IRBuilder() begin
37+
entry = BasicBlock(f, "entry")
38+
position!(builder, entry)
39+
40+
p1 = parameters(f)[1]
41+
42+
if arg == Cfloat
43+
res = fpext!(builder, p1, LLVM.DoubleType())
44+
metadata(res)["ir_check_ignore"] = MDNode([])
45+
ret!(builder, res)
46+
else
47+
ret!(builder, p1)
48+
end
49+
end
50+
51+
call_function(f, ret, Tuple{arg}, :arg)
52+
end
53+
end
54+
55+
2856
"""
2957
@mtlprintf("%Fmt", args...)
3058
@@ -33,7 +61,7 @@ Print a formatted string in device context on the host standard output.
3361
macro mtlprintf(fmt::String, args...)
3462
fmt_val = Val(Symbol(fmt))
3563

36-
return :(_mtlprintf($fmt_val, $(map(arg -> :(promote_c_argument($arg)), esc.(args))...)))
64+
return :(_mtlprintf($fmt_val, $(map(arg -> :(tag_doubles(promote_c_argument($arg))), esc.(args))...)))
3765
end
3866

3967
@generated function _mtlprintf(::Val{fmt}, argspec...) where {fmt}

0 commit comments

Comments
 (0)