Skip to content

Move arguments to compile/codegen into the CompilerConfig struct #668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 18, 2025
Merged
97 changes: 49 additions & 48 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,64 +43,59 @@ end

export compile

# NOTE: the keyword arguments to compile/codegen control those aspects of compilation that
# might have to be changed (e.g. set libraries=false when recursing, or set
# strip=true for reflection). What remains defines the compilation job itself,
# and those values are contained in the CompilerJob struct.

# (::CompilerJob)
const compile_hook = Ref{Union{Nothing,Function}}(nothing)

"""
compile(target::Symbol, job::CompilerJob; kwargs...)

Compile a function `f` invoked with types `tt` for device capability `cap` to one of the
following formats as specified by the `target` argument: `:julia` for Julia IR, `:llvm` for
LLVM IR and `:asm` for machine code.

The following keyword arguments are supported:
- `toplevel`: indicates that this compilation is the outermost invocation of the compiler
(default: true)
- `libraries`: link the GPU runtime and `libdevice` libraries (default: true, if toplevel)
- `optimize`: optimize the code (default: true, if toplevel)
- `cleanup`: run cleanup passes on the code (default: true, if toplevel)
- `validate`: enable optional validation of input and outputs (default: true, if toplevel)
- `strip`: strip non-functional metadata and debug information (default: false)
- `only_entry`: only keep the entry function, remove all others (default: false).
This option is only for internal use, to implement reflection's `dump_module`.

Other keyword arguments can be found in the documentation of [`cufunction`](@ref).
compile(target::Symbol, job::CompilerJob)

Compile a `job` to one of the following formats as specified by the `target` argument:
`:julia` for Julia IR, `:llvm` for LLVM IR and `:asm` for machine code.
"""
function compile(target::Symbol, @nospecialize(job::CompilerJob); kwargs...)
# XXX: remove on next major version
if !isempty(kwargs)
Base.depwarn("The GPUCompiler `compile` API does not take keyword arguments anymore. Use CompilerConfig instead.", :compile)
config = CompilerConfig(job.config; kwargs...)
job = CompilerJob(job.source, config)
end

if compile_hook[] !== nothing
compile_hook[](job)
end

return codegen(target, job; kwargs...)
return compile_unhooked(target, job)
end

function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool=true,
libraries::Bool=toplevel, optimize::Bool=toplevel, cleanup::Bool=toplevel,
validate::Bool=toplevel, strip::Bool=false, only_entry::Bool=false,
parent_job::Union{Nothing, CompilerJob}=nothing)
# XXX: remove on next major version
function codegen(output::Symbol, @nospecialize(job::CompilerJob); kwargs...)
if !isempty(kwargs)
Base.depwarn("The GPUCompiler `codegen` function is an internal API. Use `GPUCompiler.compile` (with any kwargs passed to `CompilerConfig`) instead.", :codegen)
config = CompilerConfig(job.config; kwargs...)
job = CompilerJob(job.source, config)
end
compile_unhooked(output, job)
end

function compile_unhooked(output::Symbol, @nospecialize(job::CompilerJob); kwargs...)
if context(; throw_error=false) === nothing
error("No active LLVM context. Use `JuliaContext()` do-block syntax to create one.")
end

@timeit_debug to "Validation" begin
check_method(job) # not optional
validate && check_invocation(job)
job.config.validate && check_invocation(job)
end

prepare_job!(job)


## LLVM IR

ir, ir_meta = emit_llvm(job; libraries, toplevel, optimize, cleanup, only_entry, validate)
ir, ir_meta = emit_llvm(job)

if output == :llvm
if strip
if job.config.strip
@timeit_debug to "strip debug info" strip_debuginfo!(ir)
end

Expand All @@ -117,7 +112,7 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
else
error("Unknown assembly format $output")
end
asm, asm_meta = emit_asm(job, ir; strip, validate, format)
asm, asm_meta = emit_asm(job, ir, format)

if output == :asm || output == :obj
return asm, (; asm_meta..., ir_meta..., ir)
Expand Down Expand Up @@ -156,9 +151,14 @@ end

const __llvm_initialized = Ref(false)

@locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
libraries::Bool, optimize::Bool, cleanup::Bool,
validate::Bool, only_entry::Bool)
@locked function emit_llvm(@nospecialize(job::CompilerJob); kwargs...)
# XXX: remove on next major version
if !isempty(kwargs)
Base.depwarn("The GPUCompiler `emit_llvm` function is an internal API. Use `GPUCompiler.compile` (with any kwargs passed to `CompilerConfig`) instead.", :emit_llvm)
config = CompilerConfig(job.config; kwargs...)
job = CompilerJob(job.source, config)
end

if !__llvm_initialized[]
InitializeAllTargets()
InitializeAllTargetInfos()
Expand All @@ -183,7 +183,8 @@ const __llvm_initialized = Ref(false)
entry = finish_module!(job, ir, entry)

# deferred code generation
has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
has_deferred_jobs = job.config.toplevel && !job.config.only_entry &&
haskey(functions(ir), "deferred_codegen")
jobs = Dict{CompilerJob, String}(job => entry_fn)
if has_deferred_jobs
dyn_marker = functions(ir)["deferred_codegen"]
Expand Down Expand Up @@ -221,8 +222,8 @@ const __llvm_initialized = Ref(false)
for dyn_job in keys(worklist)
# cached compilation
dyn_entry_fn = get!(jobs, dyn_job) do
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false,
parent_job=job)
config = CompilerConfig(dyn_job.config; toplevel=false)
dyn_ir, dyn_meta = codegen(:llvm, CompilerJob(dyn_job; config))
Comment on lines +225 to +226
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is quite a delayed review, but the removal of parent_job here is tricky.

Enzyme places a fully featured CompilerJob into the jobs list (https://github.com/EnzymeAD/Enzyme.jl/blob/b9b12e5d6f5597d70bacbe7a83361244253b84bd/src/compiler.jl#L6057) with the target being EnzymeTarget().

Arguably, EnzymeTarget() should probably capture the target of the primal, but at that moment we are executing a generator function (I know, I know) and don't have information about the context in which this job will be used. We somehow need to know the config (or at least the target) of the parent so that we can generate the correct primal code; otherwise we default to the CPU target.

We were using parent_job to obtain this information just before we went and generated the primal code.

dyn_entry_fn = LLVM.name(dyn_meta.entry)
merge!(compiled, dyn_meta.compiled)
@assert context(dyn_ir) == context(ir)
Expand Down Expand Up @@ -258,7 +259,7 @@ const __llvm_initialized = Ref(false)
erase!(dyn_marker)
end

if libraries
if job.config.toplevel && job.config.libraries
# load the runtime outside of a timing block (because it recurses into the compiler)
if !uses_julia_runtime(job)
runtime = load_runtime(job)
Expand All @@ -284,7 +285,7 @@ const __llvm_initialized = Ref(false)
# mark everything internal except for entrypoints and any exported
# global variables. this makes sure that the optimizer can, e.g.,
# rewrite function signatures.
if toplevel
if job.config.toplevel
preserved_gvs = collect(values(jobs))
for gvar in globals(ir)
if linkage(gvar) == LLVM.API.LLVMExternalLinkage
Expand All @@ -310,7 +311,7 @@ const __llvm_initialized = Ref(false)
# so that we can reconstruct the CompileJob instead of setting it globally
end

if optimize
if job.config.toplevel && job.config.optimize
@timeit_debug to "optimization" begin
optimize!(job, ir; job.config.opt_level)

Expand All @@ -337,7 +338,7 @@ const __llvm_initialized = Ref(false)
entry = functions(ir)[entry_fn]
end

if cleanup
if job.config.toplevel && job.config.cleanup
@timeit_debug to "clean-up" begin
@dispose pb=NewPMPassBuilder() begin
add!(pb, RecomputeGlobalsAAPass())
Expand All @@ -355,7 +356,7 @@ const __llvm_initialized = Ref(false)
# we want to finish the module after optimization, so we cannot do so
# during deferred code generation. instead, process the deferred jobs
# here.
if toplevel
if job.config.toplevel
entry = finish_ir!(job, ir, entry)

for (job′, fn′) in jobs
Expand All @@ -367,7 +368,7 @@ const __llvm_initialized = Ref(false)
# replace non-entry function definitions with a declaration
# NOTE: we can't do this before optimization, because the definitions of called
# functions may affect optimization.
if only_entry
if job.config.only_entry
for f in functions(ir)
f == entry && continue
isdeclaration(f) && continue
Expand All @@ -377,7 +378,7 @@ const __llvm_initialized = Ref(false)
end
end

if validate
if job.config.toplevel && job.config.validate
@timeit_debug to "Validation" begin
check_ir(job, ir)
end
Expand All @@ -390,10 +391,10 @@ const __llvm_initialized = Ref(false)
return ir, (; entry, compiled)
end

@locked function emit_asm(@nospecialize(job::CompilerJob), ir::LLVM.Module;
strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
@locked function emit_asm(@nospecialize(job::CompilerJob), ir::LLVM.Module,
format::LLVM.API.LLVMCodeGenFileType)
# NOTE: strip after validation to get better errors
if strip
if job.config.strip
@timeit_debug to "Debug info removal" strip_debuginfo!(ir)
end

Expand Down
25 changes: 21 additions & 4 deletions src/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,20 @@ export split_kwargs, assign_args!
# split keyword arguments expressions into groups. returns vectors of keyword argument
# values, one more than the number of groups (unmatched keywords in the last vector).
# intended for use in macros; the resulting groups can be used in expressions.
# can be used at run time, but not in performance critical code.
function split_kwargs(kwargs, kw_groups...)
kwarg_groups = ntuple(_->[], length(kw_groups) + 1)
for kwarg in kwargs
# decode
Meta.isexpr(kwarg, :(=)) || throw(ArgumentError("non-keyword argument like option '$kwarg'"))
key, val = kwarg.args
if Meta.isexpr(kwarg, :(=))
# use in macros
key, val = kwarg.args
elseif kwarg isa Pair{Symbol,<:Any}
# use in functions
key, val = kwarg
else
throw(ArgumentError("non-keyword argument like option '$kwarg'"))
end
isa(key, Symbol) || throw(ArgumentError("non-symbolic keyword '$key'"))

# find a matching group
Expand Down Expand Up @@ -182,7 +190,7 @@ end
end

struct DiskCacheEntry
src::Type # Originally MethodInstance, but upon deserialize they were not uniqued...
src::Type # Originally MethodInstance, but upon deserialize they were not uniqued...
cfg::CompilerConfig
asm
end
Expand Down Expand Up @@ -262,7 +270,16 @@ end
obj = linker(job, asm)

if ci === nothing
ci = ci_cache_lookup(ci_cache(job), src, world, world)::CodeInstance
ci = ci_cache_lookup(ci_cache(job), src, world, world)
if ci === nothing
error("""Did not find CodeInstance for $job.

Pleaase make sure that the `compiler` function passed to `cached_compilation`
invokes GPUCompiler with exactly the same configuration as passed to the API.

Note that you should do this by calling `GPUCompiler.compile`, and not by
using reflection functions (which alter the compiler configuration).""")
end
key = (ci, cfg)
end
cache[key] = obj
Expand Down
Loading
Loading