Skip to content

Commit afa9599

Browse files
authored
Allow for nested targets (#696)
1 parent 8b8c73f commit afa9599

File tree

6 files changed

+189
-1
lines changed

6 files changed

+189
-1
lines changed

src/driver.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ const __llvm_initialized = Ref(false)
222222
for dyn_job in keys(worklist)
223223
# cached compilation
224224
dyn_entry_fn = get!(jobs, dyn_job) do
225-
config = CompilerConfig(dyn_job.config; toplevel=false)
225+
target = nest_target(dyn_job.config.target, job.config.target)
226+
params = nest_params(dyn_job.config.params, job.config.params)
227+
config = CompilerConfig(dyn_job.config; toplevel=false, target, params)
226228
dyn_ir, dyn_meta = codegen(:llvm, CompilerJob(dyn_job; config))
227229
dyn_entry_fn = LLVM.name(dyn_meta.entry)
228230
merge!(compiled, dyn_meta.compiled)

src/interface.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ have_fma(@nospecialize(target::AbstractCompilerTarget), T::Type) = false
4848

4949
# Generic fallback for the DWARF version of a target: always version 4.
# (Original note: it seems every target supports v4 except CUDA.)
function dwarf_version(target::AbstractCompilerTarget)
    return Int32(4)
end
5050

51+
# If your target performs nested compilation, this function should reconstruct your target with a new inner target
52+
# Default behaviour: a target that does not nest is handed back unchanged and
# `parent` is ignored.
function nest_target(target::AbstractCompilerTarget, parent::AbstractCompilerTarget)
    return target
end
53+
5154
## params
5255

5356
export AbstractCompilerParams
@@ -56,6 +59,8 @@ export AbstractCompilerParams
5659

5760
# Exported abstract supertype for compiler parameter types.
abstract type AbstractCompilerParams end
5861

62+
# Parameter counterpart of `nest_target`. Default behaviour: `params` is handed
# back unchanged and `parent` is ignored.
function nest_params(params::AbstractCompilerParams, parent::AbstractCompilerParams)
    return params
end
63+
5964

6065
## config
6166

test/helpers/enzyme.jl

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
module Enzyme
2+
3+
using ..GPUCompiler
4+
5+
# Compiler target that wraps another compiler target, stored in `target`.
# The type parameter records the concrete type of the wrapped target.
struct EnzymeTarget{Target<:AbstractCompilerTarget} <: AbstractCompilerTarget
6+
target::Target
7+
end
8+
9+
# Keyword-only constructor: wraps a `GPUCompiler.NativeCompilerTarget` created
# with `jlruntime = true`; all other keyword arguments are forwarded to it.
function EnzymeTarget(;kwargs...)
10+
EnzymeTarget(GPUCompiler.NativeCompilerTarget(; jlruntime = true, kwargs...))
11+
end
12+
13+
# Delegate the LLVM triple query to the wrapped target.
function GPUCompiler.llvm_triple(target::EnzymeTarget)
    return GPUCompiler.llvm_triple(target.target)
end
14+
# Delegate the LLVM data layout query to the wrapped target.
function GPUCompiler.llvm_datalayout(target::EnzymeTarget)
    return GPUCompiler.llvm_datalayout(target.target)
end
15+
# Delegate the LLVM machine query to the wrapped target.
function GPUCompiler.llvm_machine(target::EnzymeTarget)
    return GPUCompiler.llvm_machine(target.target)
end
16+
# Nesting hook: the existing wrapper is ignored and a new `EnzymeTarget` is
# built around `other` (the parent target passed in by the caller).
function GPUCompiler.nest_target(::EnzymeTarget, other::AbstractCompilerTarget)
    return EnzymeTarget(other)
end
17+
# Delegate the FMA-support query for type `T` to the wrapped target.
function GPUCompiler.have_fma(target::EnzymeTarget, T::Type)
    return GPUCompiler.have_fma(target.target, T)
end
18+
# Delegate the DWARF version query to the wrapped target.
function GPUCompiler.dwarf_version(target::EnzymeTarget)
    return GPUCompiler.dwarf_version(target.target)
end
19+
20+
# Common supertype of the compiler parameter types defined in this module.
abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end
21+
# Compiler parameters that wrap another set of compiler parameters, stored in
# `params`. The type parameter records the concrete type of the wrapped params.
struct EnzymeCompilerParams{Params<:AbstractCompilerParams} <: AbstractEnzymeCompilerParams
22+
params::Params
23+
end
24+
# Field-less parameter type; the zero-argument `EnzymeCompilerParams()`
# constructor wraps an instance of it.
struct PrimalCompilerParams <: AbstractEnzymeCompilerParams
25+
end
26+
27+
# Zero-argument constructor: wrap a fresh `PrimalCompilerParams`.
function EnzymeCompilerParams()
    return EnzymeCompilerParams(PrimalCompilerParams())
end
28+
29+
# Nesting hook: the existing wrapper is ignored and a new
# `EnzymeCompilerParams` is built around `other` (the parent params).
function GPUCompiler.nest_params(::EnzymeCompilerParams, other::AbstractCompilerParams)
    return EnzymeCompilerParams(other)
end
30+
31+
# Compile a job whose target is an `EnzymeTarget` by unwrapping the inner
# ("primal") target and params, building a fresh config/job around them, and
# delegating to `GPUCompiler.compile_unhooked` for that primal job.
#
# Arguments:
# - `output`: requested output kind, passed through unchanged to the primal call.
# - `job`: the compiler job to unwrap.
#
# Returns whatever the primal `compile_unhooked` call returns.
function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeTarget})
32+
config = job.config
33+
# Unwrap the inner target/params. The type assertions make a config that does
# not carry the Enzyme wrappers fail here rather than further down.
primal_target = (job.config.target::EnzymeTarget).target
34+
primal_params = (job.config.params::EnzymeCompilerParams).params
35+
36+
# Only `toplevel` and `always_inline` are taken from the incoming config; the
# remaining options are fixed to the literal values below.
primal_config = CompilerConfig(
37+
primal_target,
38+
primal_params;
39+
toplevel = config.toplevel,
40+
always_inline = config.always_inline,
41+
kernel = false,
42+
libraries = true,
43+
optimize = false,
44+
cleanup = false,
45+
only_entry = false,
46+
validate = false,
47+
# ??? entry_abi (not forwarded here; left as an open question in the original)
48+
)
49+
# Same source and world as the incoming job, but with the primal config.
primal_job = CompilerJob(job.source, primal_config, job.world)
50+
return GPUCompiler.compile_unhooked(output, primal_job)
51+
52+
# Normally, Enzyme would run here and transform the output of the primal job.
53+
end
54+
55+
import GPUCompiler: deferred_codegen_jobs
56+
import Core.Compiler as CC
57+
58+
# Generator backing `deferred_codegen_id`.
#
# Looks up the method instance for the call signature described by `ft`/`tt`
# in `world`, registers a `CompilerJob` for it in `deferred_codegen_jobs`, and
# returns a `CodeInfo` whose single statement returns the integer id under
# which the job was registered.
#
# Arguments:
# - `world`: world age used for the method lookup and for the created job.
# - `source`: only forwarded to the error stub.
# - `self`: not used in this body.
# - `ft`, `tt`: `Type{...}` wrappers around the function type and the
#   argument tuple type (checked by the `@assert` below).
#
# If no method matches, the returned stub body throws a `MethodError`.
function deferred_codegen_id_generator(world::UInt, source, self, ft::Type, tt::Type)
59+
@nospecialize
60+
@assert CC.isType(ft) && CC.isType(tt)
61+
# unwrap `Type{T}` to `T`
ft = ft.parameters[1]
62+
tt = tt.parameters[1]
63+
64+
stub = Core.GeneratedFunctionStub(identity, Core.svec(:deferred_codegen_id, :ft, :tt), Core.svec())
65+
66+
# look up the method match
67+
method_error = :(throw(MethodError(ft, tt, $world)))
68+
sig = Tuple{ft, tt.parameters...}
69+
# in/out world bounds passed by reference to the lookup below
min_world = Ref{UInt}(typemin(UInt))
70+
max_world = Ref{UInt}(typemax(UInt))
71+
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
72+
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
73+
sig, #=mt=# nothing, world, min_world, max_world)
74+
# no matching method: return a stub whose body throws the MethodError
match === nothing && return stub(world, source, method_error)
75+
76+
# look up the method and code instance
77+
mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
78+
(Any, Any, Any), match.method, match.spec_types, match.sparams)
79+
ci = CC.retrieve_code_info(mi, world)
80+
81+
# prepare a new code info
82+
# TODO: Can we create a new CI instead of copying a "wrong" one?
83+
new_ci = copy(ci)
84+
empty!(new_ci.code)
85+
# the location-info fields differ depending on whether `Core.DebugInfo` exists
@static if isdefined(Core, :DebugInfo)
86+
new_ci.debuginfo = Core.DebugInfo(:none)
87+
else
88+
empty!(new_ci.codelocs)
89+
resize!(new_ci.linetable, 1) # see note below
90+
end
91+
empty!(new_ci.ssaflags)
92+
new_ci.ssavaluetypes = 0
93+
94+
# propagate edge metadata
95+
# new_ci.min_world = min_world[]
96+
# (the active code uses `world` as the lower bound instead of the looked-up one)
new_ci.min_world = world
97+
new_ci.max_world = max_world[]
98+
new_ci.edges = Core.MethodInstance[mi]
99+
100+
# prepare the slots
101+
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
102+
new_ci.slotflags = UInt8[0x00 for i = 1:3]
103+
@static if isdefined(Core, :DebugInfo)
104+
new_ci.nargs = 3
105+
end
106+
107+
# We don't know the caller's target so EnzymeTarget uses the default NativeCompilerTarget.
108+
target = EnzymeTarget()
109+
params = EnzymeCompilerParams()
110+
config = CompilerConfig(target, params; kernel=false)
111+
job = CompilerJob(mi, config, world)
112+
113+
# the new id is one past the current number of registered jobs
# NOTE(review): this assumes entries are never removed from
# `deferred_codegen_jobs`; otherwise ids could collide. Confirm.
id = length(deferred_codegen_jobs) + 1
114+
deferred_codegen_jobs[id] = job
115+
116+
# return the deferred_codegen_id
117+
push!(new_ci.code, CC.ReturnNode(id))
118+
push!(new_ci.ssaflags, 0x00)
119+
@static if isdefined(Core, :DebugInfo)
120+
else
121+
push!(new_ci.codelocs, 1) # see note below
122+
end
123+
new_ci.ssavaluetypes += 1
124+
125+
# NOTE: we keep the first entry of the original linetable, and use it for location info
126+
# on the call to check_cache. We must have a codeloc (using 0 causes
127+
# corruption of the back trace), and reusing the target function's info
128+
# has the advantage that we see the name of the kernel in the backtraces.
129+
130+
return new_ci
131+
end
132+
133+
# Generated-only function whose body is produced by
# `deferred_codegen_id_generator`. `@eval` is used so the two `:meta`
# expressions can be spliced directly into the function body.
@eval function deferred_codegen_id(ft, tt)
134+
$(Expr(:meta, :generated_only))
135+
$(Expr(:meta, :generated, deferred_codegen_id_generator))
136+
end
137+
138+
# Obtain the job id for the function type `f` and argument tuple type `tt` via
# `deferred_codegen_id`, then pass it to the external `deferred_codegen`
# symbol through an `llvmcall`-convention `ccall`. Returns a `Ptr{Cvoid}`.
@inline function deferred_codegen(f::Type, tt::Type)
139+
id = deferred_codegen_id(f, tt)
140+
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), id)
141+
end
142+
143+
end

test/native.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,3 +653,20 @@ end
653653
Native.code_llvm(mod.parent, Tuple{}; debuginfo=:none, mod.method_table)
654654
end
655655
end
656+
657+
# Checks nested compilation through the mock Enzyme helper: `dkernel` obtains a
# pointer for `kernel` via `Enzyme.deferred_codegen` and calls through it. The
# generated IR must no longer mention `deferred_codegen` and must contain a
# direct call to the compiled kernel.
@testset "Mock Enzyme" begin
658+
function kernel(a)
659+
a[1] = a[1]^2
660+
return
661+
end
662+
663+
function dkernel(a)
664+
ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Vector{Float64}})
665+
ccall(ptr, Cvoid, (Vector{Float64},), a)
666+
return
667+
end
668+
669+
ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Vector{Float64}}; debuginfo=:none))
670+
# the deferred call must have been resolved ...
@test !occursin("deferred_codegen", ir)
671+
# ... into a direct call to the kernel
@test occursin("call void @julia_kernel", ir)
672+
end

test/ptx.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,23 @@ end
152152
end
153153
end
154154

155+
# Same scenario as the native "Mock Enzyme" test, but the kernel works on a
# `Ptr{Float64}` via `unsafe_load`/`unsafe_store!` instead of indexing a Vector.
# NOTE(review): this testset lives in test/ptx.jl yet generates IR through
# `Native.code_llvm`; confirm that the native driver is intended here.
@testset "Mock Enzyme" begin
156+
function kernel(a)
157+
unsafe_store!(a, unsafe_load(a)^2)
158+
return
159+
end
160+
161+
function dkernel(a)
162+
ptr = Enzyme.deferred_codegen(typeof(kernel), Tuple{Ptr{Float64}})
163+
ccall(ptr, Cvoid, (Ptr{Float64},), a)
164+
return
165+
end
166+
167+
ir = sprint(io->Native.code_llvm(io, dkernel, Tuple{Ptr{Float64}}; debuginfo=:none))
168+
# the deferred call must have been resolved ...
@test !occursin("deferred_codegen", ir)
169+
# ... into a direct call to some `julia_`-prefixed function
@test occursin("call void @julia_", ir)
170+
end
171+
155172
end
156173

157174
############################################################################################

test/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,7 @@ end
187187
error("errors")
188188
end
189189
end
190+
191+
# Smoke test: there is no `@test` here, so this only fails if calling the
# generated function `Enzyme.deferred_codegen_id` throws.
@testset "Mock Enzyme" begin
192+
Enzyme.deferred_codegen_id(typeof(identity), Tuple{Vector{Float64}})
193+
end

0 commit comments

Comments
 (0)