Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6740,19 +6740,31 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
mstr = if job.config.params.ABI <: InlineABI
""
else
fixup_callconv!(mod, JIT.get_tm())
for f in functions(mod)
for i in 1:length(parameters(f))
for a in collect(parameter_attributes(f, i))
if kind(a) == "enzyme_sret"
API.EnzymeDumpValueRef(f)
end
@assert kind(a) != "enzyme_sret"
@assert kind(a) != "enzyme_sret_v"
end
end
end
string(mod)
end
if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI
if DumpPrePostOpt[]
API.EnzymeDumpModuleRef(mod.ref)
end
post_optimize!(mod, JIT.get_tm())
post_optimize!(mod, JIT.get_tm(); callconv=false)
if DumpPostOpt[]
API.EnzymeDumpModuleRef(mod.ref)
end
else
propagate_returned!(mod)
Compiler.JIT.prepare!(mod)
Compiler.JIT.prepare!(mod)
end
mstr
else
Expand Down
9 changes: 8 additions & 1 deletion src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ end
const DumpPreCallConv = Ref(false)
const DumpPostCallConv = Ref(false)

function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
function fixup_callconv!(mod::LLVM.Module, tm::LLVM.TargetMachine)
addr13NoAlias(mod)

removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)
Expand Down Expand Up @@ -430,6 +430,13 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool
),
)
end
return
end

function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true; callconv::Bool = true)
if callconv
fixup_callconv!(mod, tm)
end
@dispose pb = NewPMPassBuilder() begin
registerEnzymeAndPassPipeline!(pb)
register!(pb, ReinsertGCMarkerPass())
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,11 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
end

if ekind == "enzyme_sret"
ety = parse(UInt, LLVM.value(attr))
ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety)
ety = LLVM.LLVMType(ety)
ety = parse(UInt, LLVM.value(attr))
ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety)
ety = LLVM.LLVMType(ety)
if !LLVM.is_opaque(vt)
@assert ety == eltype(vt)
@assert ety == eltype(vt) "Mismatched sret type $(string(fn))\nidx=$idx\nety ($(string(ety))) != eltype(vt) (vt = $(string(vt)))"
end

return ety
Expand Down
19 changes: 18 additions & 1 deletion test/ext/staticarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,21 @@ end
@test res ≈ [1.0, 0.0]
res = Enzyme.gradient(Enzyme.Forward, unstable_fun, inp)[1]
@test res ≈ [1.0, 0.0]
end
end

function inner_forhess(x)
return tanh.(x)
end

function for_hess(x)
return sum(inner_forhess(x))
end

grad_forhess(x) = autodiff(Reverse, for_hess, Active, Active(x))[1][1]
hess(x) = jacobian(Forward, grad_forhess, x)[1]

@testset "StaticArrays hessian" begin
x = @SVector zeros(10)
res = [-2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0]
@test jacobian(Forward, grad_forhess, x)[1] ≈ res
end
56 changes: 56 additions & 0 deletions test/hessian_mwe.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/env julia
# Minimal Working Example (MWE) for testing enzyme_sret type handling
# This file isolates the core functionality being tested to help diagnose timeouts

using Enzyme
using Test
using LinearAlgebra

println("Testing Forward-over-Reverse (Hessian) computation...")
println("This test exercises sret type handling in nested AD")

@testset "Hessian MWE" begin
# Function: f(x) = x[1]^2 + x[1]*x[2]
# Gradient: ∇f = [2*x[1] + x[2], x[1]]
# Hessian: H = [[2, 1], [1, 0]]

function origf(x::Array{Float64}, y::Array{Float64})
y[1] = x[1] * x[1] + x[2] * x[1]
return nothing
end

function grad(x, dx, y, dy)
Enzyme.autodiff(Reverse, Const(origf), Duplicated(x, dx), DuplicatedNoNeed(y, dy))
nothing
end

x = [2.0, 2.0]
y = Vector{Float64}(undef, 1)
dx = [0.0, 0.0]
dy = [1.0]

grad(x, dx, y, dy)

vx = ([1.0, 0.0], [0.0, 1.0])
hess = ([0.0, 0.0], [0.0, 0.0])
dx2 = [0.0, 0.0]
dy = [1.0]

Enzyme.autodiff(
Enzyme.Forward, grad,
Enzyme.BatchDuplicated(x, vx),
Enzyme.BatchDuplicated(dx2, hess),
Const(y),
Const(dy)
)

@test dx ≈ dx2
@test hess[1][1] ≈ 2.0
@test hess[1][2] ≈ 1.0
@test hess[2][1] ≈ 1.0
@test hess[2][2] ≈ 0.0
end

println("\nTest completed successfully!")
println("If this test hangs or times out, the issue is with Forward-over-Reverse AD")
println("with BatchDuplicated arguments, which requires proper enzyme_sret handling.")