1
+ module Enzyme
2
+
3
+ using .. GPUCompiler
4
+
5
+ struct EnzymeTarget{Target<: AbstractCompilerTarget } <: AbstractCompilerTarget
6
+ target:: Target
7
+ end
8
+
9
+ function EnzymeTarget (;kwargs... )
10
+ EnzymeTarget (GPUCompiler. NativeCompilerTarget (; jlruntime = true , kwargs... ))
11
+ end
12
+
13
+ GPUCompiler. llvm_triple (target:: EnzymeTarget ) = GPUCompiler. llvm_triple (target. target)
14
+ GPUCompiler. llvm_datalayout (target:: EnzymeTarget ) = GPUCompiler. llvm_datalayout (target. target)
15
+ GPUCompiler. llvm_machine (target:: EnzymeTarget ) = GPUCompiler. llvm_machine (target. target)
16
+ GPUCompiler. nest_target (:: EnzymeTarget , other:: AbstractCompilerTarget ) = EnzymeTarget (other)
17
+ GPUCompiler. have_fma (target:: EnzymeTarget , T:: Type ) = GPUCompiler. have_fma (target. target, T)
18
+ GPUCompiler. dwarf_version (target:: EnzymeTarget ) = GPUCompiler. dwarf_version (target. target)
19
+
20
+ abstract type AbstractEnzymeCompilerParams <: AbstractCompilerParams end
21
+ struct EnzymeCompilerParams{Params<: AbstractCompilerParams } <: AbstractEnzymeCompilerParams
22
+ params:: Params
23
+ end
24
+ struct PrimalCompilerParams <: AbstractEnzymeCompilerParams
25
+ end
26
+
27
+ EnzymeCompilerParams () = EnzymeCompilerParams (PrimalCompilerParams ())
28
+
29
+ GPUCompiler. nest_params (:: EnzymeCompilerParams , other:: AbstractCompilerParams ) = EnzymeCompilerParams (other)
30
+
31
+ function GPUCompiler. compile_unhooked (output:: Symbol , job:: CompilerJob{<:EnzymeTarget} )
32
+ config = job. config
33
+ primal_target = (job. config. target:: EnzymeTarget ). target
34
+ primal_params = (job. config. params:: EnzymeCompilerParams ). params
35
+
36
+ primal_config = CompilerConfig (
37
+ primal_target,
38
+ primal_params;
39
+ toplevel = config. toplevel,
40
+ always_inline = config. always_inline,
41
+ kernel = false ,
42
+ libraries = true ,
43
+ optimize = false ,
44
+ cleanup = false ,
45
+ only_entry = false ,
46
+ validate = false ,
47
+ # ??? entry_abi
48
+ )
49
+ primal_job = CompilerJob (job. source, primal_config, job. world)
50
+ return GPUCompiler. compile_unhooked (output, primal_job)
51
+
52
+ # Normally, Enzyme would run here and transform the output of the primal job.
53
+ end
54
+
55
+ import GPUCompiler: deferred_codegen_jobs
56
+ import Core. Compiler as CC
57
+
58
+ function deferred_codegen_id_generator (world:: UInt , source, self, ft:: Type , tt:: Type )
59
+ @nospecialize
60
+ @assert CC. isType (ft) && CC. isType (tt)
61
+ ft = ft. parameters[1 ]
62
+ tt = tt. parameters[1 ]
63
+
64
+ stub = Core. GeneratedFunctionStub (identity, Core. svec (:deferred_codegen_id , :ft , :tt ), Core. svec ())
65
+
66
+ # look up the method match
67
+ method_error = :(throw (MethodError (ft, tt, $ world)))
68
+ sig = Tuple{ft, tt. parameters... }
69
+ min_world = Ref {UInt} (typemin (UInt))
70
+ max_world = Ref {UInt} (typemax (UInt))
71
+ match = ccall (:jl_gf_invoke_lookup_worlds , Any,
72
+ (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
73
+ sig, #= mt=# nothing , world, min_world, max_world)
74
+ match === nothing && return stub (world, source, method_error)
75
+
76
+ # look up the method and code instance
77
+ mi = ccall (:jl_specializations_get_linfo , Ref{Core. MethodInstance},
78
+ (Any, Any, Any), match. method, match. spec_types, match. sparams)
79
+ ci = CC. retrieve_code_info (mi, world)
80
+
81
+ # prepare a new code info
82
+ # TODO : Can we create a new CI instead of copying a "wrong" one?
83
+ new_ci = copy (ci)
84
+ empty! (new_ci. code)
85
+ @static if isdefined (Core, :DebugInfo )
86
+ new_ci. debuginfo = Core. DebugInfo (:none )
87
+ else
88
+ empty! (new_ci. codelocs)
89
+ resize! (new_ci. linetable, 1 ) # see note below
90
+ end
91
+ empty! (new_ci. ssaflags)
92
+ new_ci. ssavaluetypes = 0
93
+
94
+ # propagate edge metadata
95
+ # new_ci.min_world = min_world[]
96
+ new_ci. min_world = world
97
+ new_ci. max_world = max_world[]
98
+ new_ci. edges = Core. MethodInstance[mi]
99
+
100
+ # prepare the slots
101
+ new_ci. slotnames = Symbol[Symbol (" #self#" ), :ft , :tt ]
102
+ new_ci. slotflags = UInt8[0x00 for i = 1 : 3 ]
103
+ @static if isdefined (Core, :DebugInfo )
104
+ new_ci. nargs = 3
105
+ end
106
+
107
+ # We don't know the caller's target so EnzymeTarget uses the default NativeCompilerTarget.
108
+ target = EnzymeTarget ()
109
+ params = EnzymeCompilerParams ()
110
+ config = CompilerConfig (target, params; kernel= false )
111
+ job = CompilerJob (mi, config, world)
112
+
113
+ id = length (deferred_codegen_jobs) + 1
114
+ deferred_codegen_jobs[id] = job
115
+
116
+ # return the deferred_codegen_id
117
+ push! (new_ci. code, CC. ReturnNode (id))
118
+ push! (new_ci. ssaflags, 0x00 )
119
+ @static if isdefined (Core, :DebugInfo )
120
+ else
121
+ push! (new_ci. codelocs, 1 ) # see note below
122
+ end
123
+ new_ci. ssavaluetypes += 1
124
+
125
+ # NOTE: we keep the first entry of the original linetable, and use it for location info
126
+ # on the call to check_cache. we can't not have a codeloc (using 0 causes
127
+ # corruption of the back trace), and reusing the target function's info
128
+ # has as advantage that we see the name of the kernel in the backtraces.
129
+
130
+ return new_ci
131
+ end
132
+
133
+ @eval function deferred_codegen_id (ft, tt)
134
+ $ (Expr (:meta , :generated_only ))
135
+ $ (Expr (:meta , :generated , deferred_codegen_id_generator))
136
+ end
137
+
138
+ @inline function deferred_codegen (f:: Type , tt:: Type )
139
+ id = deferred_codegen_id (f, tt)
140
+ ccall (" extern deferred_codegen" , llvmcall, Ptr{Cvoid}, (Int,), id)
141
+ end
142
+
143
+ end
0 commit comments