Skip to content

Commit

Permalink
Save julia types on sret (#2127)
Browse files Browse the repository at this point in the history
* Save julia types on sret

* fix

* lig
  • Loading branch information
wsmoses authored Nov 28, 2024
1 parent e9d303b commit 8b65381
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 3 deletions.
10 changes: 10 additions & 0 deletions src/absint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,16 @@ function abs_typeof(
end
end

if isa(arg, LLVM.AllocaInst) || isa(arg, LLVM.CallInst)
if haskey(metadata(arg), "enzymejl_allocart")
mds = operands(metadata(arg)["enzymejl_allocart"])[1]::MDString
mds = Base.convert(String, mds)
ptr = reinterpret(Ptr{Cvoid}, parse(UInt, mds))
RT = Base.unsafe_pointer_to_objref(ptr)
return (true, RT, GPUCompiler.MUT_REF)
end
end

if isa(arg, LLVM.CallInst)
fn = LLVM.called_operand(arg)
nm = ""
Expand Down
48 changes: 45 additions & 3 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1727,6 +1727,7 @@ end
else
"Unknown object of type" * " " * string(TT)
end
@assert !illegal
illegalVal = cur
illegal = true
return make_batched(ncur, prevbb)
Expand Down Expand Up @@ -1770,6 +1771,7 @@ end
end

cur2 = if changed
@assert !illegal
illegalVal = cur
illegal = true
# TODO replace with correct insertions/splats
Expand Down Expand Up @@ -1942,8 +1944,10 @@ end
return make_batched(ncur, prevbb)
end

illegal = true
illegalVal = cur
if !illegal
illegal = true
illegalVal = cur
end
return ncur
end

Expand Down Expand Up @@ -7070,10 +7074,48 @@ end
ctx = LLVM.context(mod)
for f in functions(mod), bb in blocks(f), inst in instructions(bb)
fn = isa(inst, LLVM.CallInst) ? LLVM.called_operand(inst) : nothing

if !API.HasFromStack(inst) && isa(inst, LLVM.AllocaInst)

calluse = nothing
for u in LLVM.uses(inst)
u = LLVM.user(u)
if isa(u, LLVM.CallInst) && operands(u)[1] == inst

sretkind = kind(if LLVM.version().major >= 12
TypeAttribute("sret", LLVM.Int32Type())
else
EnumAttribute("sret")
end)
hassret = false
llvmfn = LLVM.called_operand(u)
if llvmfn isa LLVM.Function
for attr in collect(parameter_attributes(llvmfn, 1))
if kind(attr) == sretkind
hassret = true
break
end
end
end
if hassret
calluse = u
end
end
end
if calluse isa LLVM.CallInst
_, RT = enzyme_custom_extract_mi(calluse, false)
if RT !== nothing
llrt, sret, returnRoots = get_return_info(RT)
if !(sret isa Nothing) && !is_sret_union(RT)
metadata(inst)["enzymejl_allocart"] = MDNode(LLVM.Metadata[MDString(string(convert(UInt, unsafe_to_pointer(RT))))])
end
end
end
end

if !API.HasFromStack(inst) &&
((isa(inst, LLVM.CallInst) &&
(!isa(fn, LLVM.Function) || isempty(blocks(fn))) ) || isa(inst, LLVM.LoadInst))
(!isa(fn, LLVM.Function) || isempty(blocks(fn))) ) || isa(inst, LLVM.LoadInst) || isa(inst, LLVM.AllocaInst))
legal, source_typ, byref = abs_typeof(inst)
codegen_typ = value_type(inst)
if legal
Expand Down

0 comments on commit 8b65381

Please sign in to comment.