diff --git a/src/compiler.jl b/src/compiler.jl index bf7e62cc4a..5bec0dd255 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6740,19 +6740,31 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri mstr = if job.config.params.ABI <: InlineABI "" else + fixup_callconv!(mod, JIT.get_tm()) + for f in functions(mod) + for i in 1:length(parameters(f)) + for a in collect(parameter_attributes(f, i)) + if kind(a) == "enzyme_sret" + API.EnzymeDumpValueRef(f) + end + @assert kind(a) != "enzyme_sret" + @assert kind(a) != "enzyme_sret_v" + end + end + end string(mod) end if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI if DumpPrePostOpt[] API.EnzymeDumpModuleRef(mod.ref) end - post_optimize!(mod, JIT.get_tm()) + post_optimize!(mod, JIT.get_tm(); callconv=false) if DumpPostOpt[] API.EnzymeDumpModuleRef(mod.ref) end else propagate_returned!(mod) - Compiler.JIT.prepare!(mod) + Compiler.JIT.prepare!(mod) end mstr else diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 713e8440ad..cc2903f807 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -377,7 +377,7 @@ end const DumpPreCallConv = Ref(false) const DumpPostCallConv = Ref(false) -function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true) +function fixup_callconv!(mod::LLVM.Module, tm::LLVM.TargetMachine) addr13NoAlias(mod) removeDeadArgs!(mod, tm, #=post_gc_fixup=#false) @@ -430,6 +430,13 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool ), ) end + return +end + +function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true; callconv::Bool = true) + if callconv + fixup_callconv!(mod, tm) + end @dispose pb = NewPMPassBuilder() begin registerEnzymeAndPassPipeline!(pb) register!(pb, ReinsertGCMarkerPass()) diff --git a/src/utils.jl b/src/utils.jl index c0ca94de7e..be75f0ef12 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -562,11 +562,11 @@ function 
sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType end if ekind == "enzyme_sret" - ety = parse(UInt, LLVM.value(attr)) - ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety) - ety = LLVM.LLVMType(ety) + ety = parse(UInt, LLVM.value(attr)) + ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety) + ety = LLVM.LLVMType(ety) if !LLVM.is_opaque(vt) - @assert ety == eltype(vt) + @assert ety == eltype(vt) "Mismatched sret type $(string(fn))\nidx=$idx\nety ($(string(ety))) != eltype(vt) (vt = $(string(vt)))" end return ety diff --git a/test/ext/staticarrays.jl b/test/ext/staticarrays.jl index c2c55a9aa4..8d8a501935 100644 --- a/test/ext/staticarrays.jl +++ b/test/ext/staticarrays.jl @@ -89,4 +89,21 @@ end @test res ≈ [1.0, 0.0] res = Enzyme.gradient(Enzyme.Forward, unstable_fun, inp)[1] @test res ≈ [1.0, 0.0] -end \ No newline at end of file +end + +function inner_forhess(x) + return tanh.(x) +end + +function for_hess(x) + return sum(inner_forhess(x)) +end + +grad_forhess(x) = autodiff(Reverse, for_hess, Active, Active(x))[1][1] +hess(x) = jacobian(Forward, grad_forhess, x)[1] + +@testset "StaticArrays hessian" begin + x = @SVector zeros(10) + res = [-2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0] + @test jacobian(Forward, grad_forhess, x)[1] ≈ res +end diff --git a/test/hessian_mwe.jl b/test/hessian_mwe.jl new file mode 100644 index 0000000000..8d3016711b --- /dev/null +++ b/test/hessian_mwe.jl @@ -0,0 +1,56 @@ +#!/usr/bin/env julia +# Minimal Working Example (MWE) for testing enzyme_sret type handling +# This file isolates the core functionality being tested to help diagnose timeouts + +using Enzyme +using Test 
+using LinearAlgebra
+
+# Print the hang hint before running: output placed after a stalled test is never seen.
+println("Testing Forward-over-Reverse (Hessian) computation...")
+println("This test exercises sret type handling in nested AD")
+println("If this test hangs or times out, the issue is with Forward-over-Reverse AD")
+println("with BatchDuplicated arguments, which requires proper enzyme_sret handling.")
+
+@testset "Hessian MWE" begin
+    # Function: f(x) = x[1]^2 + x[1]*x[2]
+    # Gradient: ∇f = [2*x[1] + x[2], x[1]]
+    # Hessian: H = [[2, 1], [1, 0]]
+    function origf(x::Array{Float64}, y::Array{Float64})
+        y[1] = x[1] * x[1] + x[2] * x[1]
+        return nothing
+    end
+
+    # Reverse-mode pass over `origf`, pairing `x` with `dx` and `y` with `dy`.
+    function grad(x, dx, y, dy)
+        Enzyme.autodiff(Reverse, Const(origf), Duplicated(x, dx), DuplicatedNoNeed(y, dy))
+        return nothing
+    end
+
+    x = [2.0, 2.0]
+    y = Vector{Float64}(undef, 1)
+    dx = [0.0, 0.0]
+    dy = [1.0]
+    grad(x, dx, y, dy)
+
+    # Forward mode over `grad`, batched along `vx`, with `hess` as the shadow of `dx2`.
+    vx = ([1.0, 0.0], [0.0, 1.0])
+    hess = ([0.0, 0.0], [0.0, 0.0])
+    dx2 = [0.0, 0.0]
+    dy = [1.0]  # fresh seed; NOTE(review): presumably the first `grad` call consumed it
+    Enzyme.autodiff(
+        Enzyme.Forward, grad,
+        Enzyme.BatchDuplicated(x, vx),
+        Enzyme.BatchDuplicated(dx2, hess),
+        Const(y),
+        Const(dy)
+    )
+
+    @test dx ≈ dx2
+    @test hess[1][1] ≈ 2.0
+    @test hess[1][2] ≈ 1.0
+    @test hess[2][1] ≈ 1.0
+    # `≈` has no absolute tolerance by default, so a zero target needs an explicit `atol`.
+    @test hess[2][2] ≈ 0.0 atol=1.0e-12
+end
+println("\nTest completed successfully!")