Skip to content

Commit 7b208f5

Browse files
authored
Add support for the LLVM SPIR-V back-end. (#665)
1 parent 870fa83 commit 7b208f5

File tree

7 files changed

+97
-53
lines changed

7 files changed

+97
-53
lines changed

src/spirv.jl

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,45 @@
44
# https://github.com/KhronosGroup/LLVM-SPIRV-Backend/blob/master/llvm/docs/SPIR-V-Backend.rst
55
# https://github.com/KhronosGroup/SPIRV-LLVM-Translator/blob/master/docs/SPIRVRepresentationInLLVM.rst
66

7-
const SPIRV_LLVM_Translator_unified_jll = LazyModule("SPIRV_LLVM_Translator_unified_jll", UUID("85f0d8ed-5b39-5caa-b1ae-7472de402361"))
8-
const SPIRV_Tools_jll = LazyModule("SPIRV_Tools_jll", UUID("6ac6d60f-d740-5983-97d7-a4482c0689f4"))
7+
const SPIRV_LLVM_Backend_jll =
8+
LazyModule("SPIRV_LLVM_Backend_jll",
9+
UUID("4376b9bf-cff8-51b6-bb48-39421dff0d0c"))
10+
const SPIRV_LLVM_Translator_unified_jll =
11+
LazyModule("SPIRV_LLVM_Translator_unified_jll",
12+
UUID("85f0d8ed-5b39-5caa-b1ae-7472de402361"))
13+
const SPIRV_Tools_jll =
14+
LazyModule("SPIRV_Tools_jll",
15+
UUID("6ac6d60f-d740-5983-97d7-a4482c0689f4"))
916

1017

1118
## target
1219

1320
export SPIRVCompilerTarget
1421

1522
Base.@kwdef struct SPIRVCompilerTarget <: AbstractCompilerTarget
23+
version::Union{Nothing,VersionNumber} = nothing
1624
extensions::Vector{String} = []
1725
supports_fp16::Bool = true
1826
supports_fp64::Bool = true
27+
28+
backend::Symbol = isavailable(SPIRV_LLVM_Backend_jll) ? :llvm : :khronos
29+
# XXX: these don't really belong in the _target_ struct
30+
validate::Bool = false
31+
optimize::Bool = false
1932
end
2033

21-
llvm_triple(::SPIRVCompilerTarget) = Int===Int64 ? "spir64-unknown-unknown" : "spirv-unknown-unknown"
34+
function llvm_triple(target::SPIRVCompilerTarget)
35+
if target.backend == :llvm
36+
architecture = Int===Int64 ? "spirv64" : "spirv32" # could also be "spirv" for logical addressing
37+
subarchitecture = target.version === nothing ? "" : "v$(target.version.major).$(target.version.minor)"
38+
vendor = "unknown" # could also be AMD
39+
os = "unknown"
40+
environment = "unknown"
41+
return "$architecture$subarchitecture-$vendor-$os-$environment"
42+
elseif target.backend == :khronos
43+
return Int===Int64 ? "spir64-unknown-unknown" : "spirv-unknown-unknown"
44+
end
45+
end
2246

2347
# SPIRV is not supported by our LLVM builds, so we can't get a target machine
2448
llvm_machine(::SPIRVCompilerTarget) = nothing
@@ -32,7 +56,8 @@ llvm_datalayout(::SPIRVCompilerTarget) = Int===Int64 ?
3256

3357
# TODO: encode debug build or not in the compiler job
3458
# https://github.com/JuliaGPU/CUDAnative.jl/issues/368
35-
runtime_slug(job::CompilerJob{SPIRVCompilerTarget}) = "spirv"
59+
runtime_slug(job::CompilerJob{SPIRVCompilerTarget}) =
60+
"spirv-" * String(job.config.target.backend)
3661

3762
function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, entry::LLVM.Function)
3863
# update calling convention
@@ -90,47 +115,57 @@ end
90115
# (SPIRV-LLVM-Translator#1140)
91116
rm_freeze!(mod)
92117

93-
94118
# translate to SPIR-V
95119
input = tempname(cleanup=false) * ".bc"
96120
translated = tempname(cleanup=false) * ".spv"
97-
options = `--spirv-debug-info-version=ocl-100`
98-
if !isempty(job.config.target.extensions)
99-
str = join(map(ext->"+$ext", job.config.target.extensions), ",")
100-
options = `$options --spirv-ext=$str`
101-
end
102121
write(input, mod)
103-
let cmd = `$(SPIRV_LLVM_Translator_unified_jll.llvm_spirv()) $options -o $translated $input`
104-
proc = run(ignorestatus(cmd))
105-
if !success(proc)
106-
error("""Failed to translate LLVM code to SPIR-V.
107-
If you think this is a bug, please file an issue and attach $(input).""")
122+
if job.config.target.backend === :llvm
123+
cmd = `$(SPIRV_LLVM_Backend_jll.llc()) $input -filetype=obj -o $translated`
124+
125+
if !isempty(job.config.target.extensions)
126+
str = join(map(ext->"+$ext", job.config.target.extensions), ",")
127+
cmd = `$(cmd) -spirv-ext=$str`
128+
end
129+
elseif job.config.target.backend === :khronos
130+
cmd = `$(SPIRV_LLVM_Translator_unified_jll.llvm_spirv()) -o $translated $input --spirv-debug-info-version=ocl-100`
131+
132+
if !isempty(job.config.target.extensions)
133+
str = join(map(ext->"+$ext", job.config.target.extensions), ",")
134+
cmd = `$(cmd) --spirv-ext=$str`
135+
end
136+
137+
if job.config.target.version !== nothing
138+
cmd = `$(cmd) --spirv-max-version=$(job.config.target.version.major).$(job.config.target.version.minor)`
108139
end
109140
end
141+
proc = run(ignorestatus(cmd))
142+
if !success(proc)
143+
error("""Failed to translate LLVM code to SPIR-V.
144+
If you think this is a bug, please file an issue and attach $(input).""")
145+
end
110146

111147
# validate
112-
# XXX: parameterize this on the `validate` driver argument
113-
# XXX: our code currently doesn't pass the validator
114-
#if Base.JLOptions().debug_level >= 2
115-
# cmd = `$(SPIRV_Tools_jll.spirv_val()) $translated`
116-
# proc = run(ignorestatus(cmd))
117-
# if !success(proc)
118-
# error("""Failed to validate generated SPIR-V.
119-
# If you think this is a bug, please file an issue and attach $(input) and $(translated).""")
120-
# end
121-
#end
148+
if job.config.target.validate
149+
cmd = `$(SPIRV_Tools_jll.spirv_val()) $translated`
150+
proc = run(ignorestatus(cmd))
151+
if !success(proc)
152+
error("""Failed to validate generated SPIR-V.
153+
If you think this is a bug, please file an issue and attach $(input) and $(translated).""")
154+
end
155+
end
122156

123157
# optimize
124-
# XXX: parameterize this on the `optimize` driver argument
125-
# XXX: the optimizer segfaults on some of our code
126158
optimized = tempname(cleanup=false) * ".spv"
127-
#let cmd = `$(SPIRV_Tools_jll.spirv_opt()) -O --skip-validation $translated -o $optimized`
128-
# proc = run(ignorestatus(cmd))
129-
# if !success(proc)
130-
# error("""Failed to optimize generated SPIR-V.
131-
# If you think this is a bug, please file an issue and attach $(input) and $(translated).""")
132-
# end
133-
#end
159+
if job.config.target.optimize
160+
cmd = `$(SPIRV_Tools_jll.spirv_opt()) -O --skip-validation $translated -o $optimized`
161+
proc = run(ignorestatus(cmd))
162+
if !success(proc)
163+
error("""Failed to optimize generated SPIR-V.
164+
If you think this is a bug, please file an issue and attach $(input) and $(translated).""")
165+
end
166+
else
167+
cp(translated, optimized)
168+
end
134169

135170
output = if format == LLVM.API.LLVMObjectFile
136171
read(translated)
@@ -141,7 +176,7 @@ end
141176

142177
rm(input)
143178
rm(translated)
144-
#rm(optimized)
179+
rm(optimized)
145180

146181
return output
147182
end

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ struct LazyModule
3030
LazyModule(name, uuid) = new(Base.PkgId(uuid, name))
3131
end
3232

33+
isavailable(lazy_mod::LazyModule) = haskey(Base.loaded_modules, getfield(lazy_mod, :pkg))
34+
3335
function Base.getproperty(lazy_mod::LazyModule, sym::Symbol)
3436
pkg = getfield(lazy_mod, :pkg)
3537
mod = get(Base.loaded_modules, pkg, nothing)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
99
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1010
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
1111
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12+
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
1213
SPIRV_LLVM_Translator_unified_jll = "85f0d8ed-5b39-5caa-b1ae-7472de402361"
1314
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
1415
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

test/helpers/spirv.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ GPUCompiler.runtime_module(::CompilerJob{<:Any,CompilerParams}) = TestRuntime
88

99
function create_job(@nospecialize(func), @nospecialize(types);
1010
kernel::Bool=false, always_inline=false,
11-
supports_fp16=true, supports_fp64=true, kwargs...)
11+
supports_fp16=true, supports_fp64=true,
12+
backend::Symbol, kwargs...)
1213
source = methodinstance(typeof(func), Base.to_tuple_type(types), Base.get_world_counter())
13-
target = SPIRVCompilerTarget(; supports_fp16, supports_fp64)
14+
target = SPIRVCompilerTarget(; backend, validate=true, optimize=true,
15+
supports_fp16, supports_fp64)
1416
params = CompilerParams()
1517
config = CompilerConfig(target, params; kernel, always_inline)
1618
CompilerJob(source, config), kwargs

test/runtests.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,6 @@ end
116116
push!(skip_tests, "metal")
117117
end
118118
end
119-
if !(SPIRV_LLVM_Translator_unified_jll.is_available() && SPIRV_Tools_jll.is_available())
120-
# SPIRV needs it's tools to be available
121-
push!(skip_tests, "spirv")
122-
end
123119
if VERSION < v"1.11"
124120
append!(skip_tests, ["ptx/precompile", "native/precompile"])
125121
end

test/setup.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Distributed, Test, GPUCompiler, LLVM
22

3-
using SPIRV_LLVM_Translator_unified_jll, SPIRV_Tools_jll
3+
using SPIRV_LLVM_Backend_jll, SPIRV_LLVM_Translator_unified_jll, SPIRV_Tools_jll
44

55
# include all helpers
66
include(joinpath(@__DIR__, "helpers", "runtime.jl"))

test/spirv.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1+
for backend in (:khronos, :llvm)
2+
13
@testset "IR" begin
24

35
@testset "kernel functions" begin
46
@testset "calling convention" begin
57
kernel() = return
68

7-
ir = sprint(io->SPIRV.code_llvm(io, kernel, Tuple{}; dump_module=true))
9+
ir = sprint(io->SPIRV.code_llvm(io, kernel, Tuple{}; backend, dump_module=true))
810
@test !occursin("spir_kernel", ir)
911

1012
ir = sprint(io->SPIRV.code_llvm(io, kernel, Tuple{};
11-
dump_module=true, kernel=true))
13+
backend, dump_module=true, kernel=true))
1214
@test occursin("spir_kernel", ir)
1315
end
1416

@@ -18,19 +20,20 @@ end
1820
kernel(x) = return
1921
end
2022

21-
ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Tuple{Int}}))
23+
ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Tuple{Int}}; backend))
2224
@test occursin(r"@\w*kernel\w*\(({ i64 }|\[1 x i64\])\*", ir) ||
2325
occursin(r"@\w*kernel\w*\(ptr", ir)
2426

25-
ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Tuple{Int}}; kernel=true))
27+
ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Tuple{Int}};
28+
backend, kernel=true))
2629
@test occursin(r"@\w*kernel\w*\(.*{ ({ i64 }|\[1 x i64\]) }\*.+byval", ir) ||
2730
occursin(r"@\w*kernel\w*\(ptr byval", ir)
2831
end
2932

3033
@testset "byval bug" begin
3134
# byval added alwaysinline, which could conflict with noinline and fail verification
3235
@noinline kernel() = return
33-
SPIRV.code_llvm(devnull, kernel, Tuple{}; kernel=true)
36+
SPIRV.code_llvm(devnull, kernel, Tuple{}; backend, kernel=true)
3437
@test "We did not crash!" != ""
3538
end
3639
end
@@ -44,26 +47,29 @@ end
4447
end
4548
end
4649

47-
ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float16}, Float16}; validate=true))
50+
ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float16}, Float16};
51+
backend, validate=true))
4852
@test occursin("store half", ir)
4953

50-
ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float32}, Float32}; validate=true))
54+
ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float32}, Float32};
55+
backend, validate=true))
5156
@test occursin("store float", ir)
5257

53-
ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float64}, Float64}; validate=true))
58+
ir = sprint(io->SPIRV.code_llvm(io, mod.kernel, Tuple{Ptr{Float64}, Float64};
59+
backend, validate=true))
5460
@test occursin("store double", ir)
5561

5662
@test_throws_message(InvalidIRError,
5763
SPIRV.code_llvm(devnull, mod.kernel, Tuple{Ptr{Float16}, Float16};
58-
supports_fp16=false, validate=true)) do msg
64+
backend, supports_fp16=false, validate=true)) do msg
5965
occursin("unsupported use of half value", msg) &&
6066
occursin("[1] unsafe_store!", msg) &&
6167
occursin("[2] kernel", msg)
6268
end
6369

6470
@test_throws_message(InvalidIRError,
6571
SPIRV.code_llvm(devnull, mod.kernel, Tuple{Ptr{Float64}, Float64};
66-
supports_fp64=false, validate=true)) do msg
72+
backend, supports_fp64=false, validate=true)) do msg
6773
occursin("unsupported use of double value", msg) &&
6874
occursin("[1] unsafe_store!", msg) &&
6975
occursin("[2] kernel", msg)
@@ -82,8 +88,10 @@ end
8288
return
8389
end
8490

85-
asm = sprint(io->SPIRV.code_native(io, kernel, Tuple{Bool}; kernel=true))
91+
asm = sprint(io->SPIRV.code_native(io, kernel, Tuple{Bool}; backend, kernel=true))
8692
@test occursin(r"OpFunctionCall %void %(julia|j)_error", asm)
8793
end
8894

8995
end
96+
97+
end

0 commit comments

Comments
 (0)