Skip to content

Commit

Permalink
Consider jl_new_array (#1074)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 24, 2023
1 parent 8d8963e commit d5361e1
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1009,10 +1009,7 @@ function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV

anti = call_samefunc_with_inverted_bundles!(b, gutils, orig, vals, valTys, #=lookup=#false)

prod = LLVM.Value(unsafe_load(Args, 2))
for i = 3:numArgs
prod = LLVM.mul!(b, prod, LLVM.Value(unsafe_load(Args, i)))
end
prod = get_array_len(b, anti)

isunboxed = allocatedinline(typ)

Expand Down Expand Up @@ -1099,6 +1096,27 @@ function get_array_elsz(B, array)
end

function get_array_len(B, array)
if isa(array, LLVM.CallInst)
fn = LLVM.called_operand(array)
nm = ""
if isa(fn, LLVM.Function)
nm = LLVM.name(fn)
end

for (fname, num) in (
("jl_alloc_array_1d", 1), ("ijl_alloc_array_1d", 1),
("jl_alloc_array_2d", 2), ("jl_alloc_array_2d", 2),
("jl_alloc_array_2d", 3), ("jl_alloc_array_2d", 3),
)
if nm == fname
res = operands(array)[2]
for i in 2:num
res = mul!(B, res, operands(array)[1+i])
end
return res
end
end
end
ST = get_array_struct()
array = LLVM.pointercast!(B, array, LLVM.PointerType(ST, LLVM.addrspace(LLVM.value_type(array))))
v = inbounds_gep!(B, ST, array, LLVM.Value[LLVM.ConstantInt(Int32(0)), LLVM.ConstantInt(Int32(1))])
Expand Down Expand Up @@ -6763,6 +6781,11 @@ function __init__()
@cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)),
@cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef))
)
register_alloc_handler!(
("jl_new_array", "ijl_new_array"),
@cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)),
@cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef))
)
register_handler!(
("julia.call",),
@augfunc(jlcall_augfwd),
Expand Down

0 comments on commit d5361e1

Please sign in to comment.