Skip to content

Commit

Permalink
add correct addrspace to global constants for Metal
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Nov 17, 2024
1 parent 09b4708 commit aa03c15
Showing 1 changed file with 76 additions and 2 deletions.
78 changes: 76 additions & 2 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L

# add kernel metadata
if job.config.kernel
entry = add_address_spaces!(job, mod, entry)
entry = add_parameter_address_spaces!(job, mod, entry)
add_global_address_spaces!(job, mod)
entry = LLVM.functions(mod)[entry_fn]

add_argument_metadata!(job, mod, entry)

Expand Down Expand Up @@ -226,7 +228,7 @@ end
# NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to
# be executed after optimization (where Julia's address spaces are stripped). If we ever
# want to execute it earlier, adapt remapType to rewrite all pointer types.
function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
ft = function_type(f)

# find the byref parameters
Expand Down Expand Up @@ -332,6 +334,78 @@ function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
return new_f
end

# add addrspace 2 to global constants
function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
for gv in globals(mod)
if isconstant(gv) && addrspace(value_type(gv)) == 0
gv_ty = global_value_type(gv)
gv_name = LLVM.name(gv)

new_gv = GlobalVariable(mod, gv_ty, "", 2)

alignment!(new_gv, alignment(gv))
unnamed_addr!(new_gv, unnamed_addr(gv))
initializer!(new_gv, initializer(gv))
constant!(new_gv, true)
linkage!(new_gv, linkage(gv))
visibility!(new_gv, visibility(gv))

funcs = Set{LLVM.Function}()
for use in uses(gv)
inst = user(use)
bb = LLVM.parent(inst)
f = LLVM.parent(bb)

push!(funcs, f)
end

for f in funcs
ft = function_type(f)
new_f = LLVM.Function(mod, "h", ft)
linkage!(new_f, linkage(f))

for (param, new_param) in zip(parameters(f), parameters(new_f))
LLVM.name!(new_param, LLVM.name(param))
end

@dispose builder=IRBuilder() begin
entry = BasicBlock(new_f, "gv_conversion")
position!(builder, entry)

ptr = alloca!(builder, gv_ty)
val = load!(builder, gv_ty, new_gv)
store!(builder, val, ptr)

# map the arguments
value_map = Dict{LLVM.Value, LLVM.Value}(
param => new_param for (param, new_param) in zip(parameters(f), parameters(new_f))
)

value_map[gv] = ptr
value_map[f] = new_f
clone_into!(new_f, f; value_map,
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

br!(builder, blocks(new_f)[2])
end

f_name = LLVM.name(f)
replace_uses!(f, new_f)
replace_metadata_uses!(f, new_f)
erase!(f)
LLVM.name!(new_f, f_name)
end

@assert isempty(uses(gv))
replace_metadata_uses!(gv, new_gv)
erase!(gv)
LLVM.name!(new_gv, gv_name)
end
end

return
end


# value-to-reference conversion
#
Expand Down

0 comments on commit aa03c15

Please sign in to comment.