From bb9f4f8ee451f95775d48e7c61f49279173192d8 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 30 Nov 2025 20:38:44 -0500 Subject: [PATCH 1/8] Fix hessian sret type --- src/compiler.jl | 13 +++++++++++-- src/compiler/optimize.jl | 9 ++++++++- src/utils.jl | 8 ++++---- test/ext/staticarrays.jl | 19 ++++++++++++++++++- 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 4e2a1422db..40d41a4c4c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6795,19 +6795,28 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri mstr = if job.config.params.ABI <: InlineABI "" else + fixup_callconv!(mod, JIT.get_tm()) + for f in functions(mod) + for i in 1:length(parameters(f)) + for a in collect(parameter_attributes(f, i)) + @assert kind(a) != "enzyme_sret" + @assert kind(a) != "enzyme_sret_v" + end + end + end string(mod) end if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI if DumpPrePostOpt[] API.EnzymeDumpModuleRef(mod.ref) end - post_optimize!(mod, JIT.get_tm()) + post_optimize!(mod, JIT.get_tm(); callconv=false) if DumpPostOpt[] API.EnzymeDumpModuleRef(mod.ref) end else propagate_returned!(mod) - Compiler.JIT.prepare!(mod) + Compiler.JIT.prepare!(mod) end mstr else diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index d45761fc4f..acfe23c0b4 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -392,7 +392,7 @@ end const DumpPreCallConv = Ref(false) const DumpPostCallConv = Ref(false) -function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true) +function fixup_callconv!(mod::LLVM.Module, tm::LLVM.TargetMachine) addr13NoAlias(mod) removeDeadArgs!(mod, tm, #=post_gc_fixup=#false) @@ -451,6 +451,13 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool ), ) end + return +end + +function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true; callconv::Bool = true) + if callconv + fixup_callconv!(mod, tm) + end @dispose pb = NewPMPassBuilder() begin registerEnzymeAndPassPipeline!(pb) register!(pb, ReinsertGCMarkerPass()) diff --git a/src/utils.jl b/src/utils.jl index d2c5544c2b..fe869d7623 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -568,11 +568,11 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType end if ekind == "enzyme_sret" - ety = parse(UInt, LLVM.value(attr)) - ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety) - ety = LLVM.LLVMType(ety) + ety = parse(UInt, LLVM.value(attr)) + ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety) + ety = LLVM.LLVMType(ety) if !LLVM.is_opaque(vt) - @assert ety == eltype(vt) + @assert ety == eltype(vt) "Mismatched sret type $(string(fn))\nidx=$idx\nety ($(string(ety))) != eltype(vt) (vt = $(string(vt)))" end return ety diff --git a/test/ext/staticarrays.jl b/test/ext/staticarrays.jl index c2c55a9aa4..8d8a501935 100644 --- a/test/ext/staticarrays.jl +++ b/test/ext/staticarrays.jl @@ -89,4 +89,21 @@ end @test res ≈ [1.0, 0.0] res = Enzyme.gradient(Enzyme.Forward, unstable_fun, inp)[1] @test res ≈ [1.0, 0.0] -end \ No newline at end of file +end + +function inner_forhess(x) + return tanh.(x) +end + +function for_hess(x) + return sum(inner_forhess(x)) +end + +grad_forhess(x) = autodiff(Reverse, for_hess, Active, Active(x))[1][1] +hess(x) = jacobian(Forward, grad_forhess, x)[1] + +@testset "StaticArrays hessian" begin + x = @SVector zeros(10) + res = [-2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0] + @test jacobian(Forward, grad_forhess, x)[1] ≈ res +end From afdc8c5da63860f04afbf018a638c4bb674b4107 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 1 Dec 2025 10:56:18 -0500 Subject: [PATCH 2/8] fix --- src/compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 40d41a4c4c..0bfdbc2d08 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6799,6 +6799,9 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri for f in functions(mod) for i in 1:length(parameters(f)) for a in collect(parameter_attributes(f, i)) + if kind(a) == "enzyme_sret" + API.EnzymeDumpValueRef(f) + end @assert kind(a) != "enzyme_sret" @assert kind(a) != "enzyme_sret_v" end From 6d5c8cdad4b0b5680858be6aa389113f6517ab08 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 14 Dec 2025 12:34:29 -0600 Subject: [PATCH 3/8] fix --- src/compiler.jl | 5 ++++- src/compiler/optimize.jl | 13 ++++++++++++- test/advanced.jl | 10 ++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 0bfdbc2d08..dc10fd7b91 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1111,12 +1111,15 @@ end EnumAttribute("willreturn"), EnumAttribute("nosync"), EnumAttribute("nofree"), + StringAttribute("enzyme_preserve_primal", "*"), ] else LLVM.Attribute[EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute"), EnumAttribute("willreturn"), EnumAttribute("nosync"), - EnumAttribute("nofree")] + EnumAttribute("nofree"), + StringAttribute("enzyme_preserve_primal", "*"), + ] end handleCustom(state, custom, k_name, llvmfn, name, attrs) return diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index acfe23c0b4..d4927a31a2 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -396,7 +396,6 @@ function fixup_callconv!(mod::LLVM.Module, tm::LLVM.TargetMachine) addr13NoAlias(mod) removeDeadArgs!(mod, tm, #=post_gc_fixup=#false) - memcpy_sret_split!(mod) # if we did the move_sret_tofrom_roots, we will have loaded out of the sret, then stored into the rooted. @@ -458,6 +457,18 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool if callconv fixup_callconv!(mod, tm) end + + for f in functions(mod) + if isempty(blocks(f)) + continue + end + if has_fn_attr(f, StringAttribute("enzyme_preserve_primal")) + delete!(LLVM.function_attributes(f), StringAttribute("enzyme_preserve_primal")) + end + end + + removeDeadArgs!(mod, tm, #=post_gc_fixup=#true) + @dispose pb = NewPMPassBuilder() begin registerEnzymeAndPassPipeline!(pb) register!(pb, ReinsertGCMarkerPass()) diff --git a/test/advanced.jl b/test/advanced.jl index bf60925ad3..228edb5331 100644 --- a/test/advanced.jl +++ b/test/advanced.jl @@ -546,6 +546,16 @@ end sqr(x) = x * x power(x, n) = x^n + + function objective1(x) + objvar = power(x, 2) + return objvar + end + + x0 = 2.1 + res = Enzyme.jacobian(Forward, Const(Enzyme.gradient), Const(Reverse), Const(objective1), x0) + @test res[3][1] ≈ 2.0 + function objective(x) (x1, x2, x3, x4) = x objvar = -4 - -(((((((((((((sqr(x1) + sqr(x2)) + sqr(x3 + x4)) + x3) + sqr(sin(x3))) + sqr(x1) * sqr(x2)) + x4) + sqr(sin(x3))) + sqr(-1 + x4)) + sqr(sqr(x2))) + sqr(sqr(x3) + sqr(x1 + x4))) + sqr(((-4 + sqr(sin(x4))) + sqr(x2) * sqr(x3)) + x1)) + power(sin(x4), 4))) From 123421daefa6b1dcc7984dfa4f53b2b476e5ef1f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 14 Dec 2025 16:42:21 -0600 Subject: [PATCH 4/8] unknown user --- src/llvm/transforms.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index 167fc250b1..0bbc7b1434 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -2637,7 +2637,14 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup ) for u in LLVM.uses(fn) u = LLVM.user(u) - @assert isa(u, LLVM.CallInst) + if !isa(u, LLVM.CallInst) + msg = sprint() do io + println("Unknown user of fn: ", string(u)) + println("fn: ", string(fn)) + println("mod: ", string(parent(fn))) + end + throw(AssertionError(msg)) + end B = IRBuilder() nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) position!(B, nextInst) From 1627ab465dfbe48f97725f0ccf01e4a4916fe2e1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 15 Dec 2025 09:35:24 -0600 Subject: [PATCH 5/8] Refactor error logging for unknown function users --- src/llvm/transforms.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index 0bbc7b1434..f3d630df0f 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -2639,9 +2639,9 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup u = LLVM.user(u) if !isa(u, LLVM.CallInst) msg = sprint() do io - println("Unknown user of fn: ", string(u)) - println("fn: ", string(fn)) - println("mod: ", string(parent(fn))) + println(io, "Unknown user of fn: ", string(u)) + println(io, "fn: ", string(fn)) + println(io, "mod: ", string(parent(fn))) end throw(AssertionError(msg)) end From dfe37ff2a9fd406f3e4c1c002e75516a1863f608 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 25 Dec 2025 16:09:42 -0500 Subject: [PATCH 6/8] fix --- src/llvm/transforms.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index f3d630df0f..91fb69a2d5 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -2641,7 +2641,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup msg = sprint() do io println(io, "Unknown user of fn: ", string(u)) println(io, "fn: ", string(fn)) - println(io, "mod: ", string(parent(fn))) + println(io, "mod: ", string(LLVM.parent(fn))) end throw(AssertionError(msg)) end From 17728945db241aeca6fd107d44bcdcc687405a37 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 25 Dec 2025 16:58:24 -0600 Subject: [PATCH 7/8] fix --- src/llvm/transforms.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index 91fb69a2d5..7807405802 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -2638,7 +2638,9 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup for u in LLVM.uses(fn) u = LLVM.user(u) if !isa(u, LLVM.CallInst) - msg = sprint() do io + # TODO investigate if the inttoptr store that comes from reference caller poses an issue. + continue + msg = sprint() do io println(io, "Unknown user of fn: ", string(u)) println(io, "fn: ", string(fn)) println(io, "mod: ", string(LLVM.parent(fn))) From a7abf046b676374d1daab23733f8fe5b2d139a17 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 25 Dec 2025 18:57:37 -0600 Subject: [PATCH 8/8] mf --- src/llvm/transforms.jl | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index 7807405802..e82e905520 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -2683,7 +2683,26 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup for u in LLVM.uses(fn) u = LLVM.user(u) if isa(u, LLVM.ConstantExpr) - u = LLVM.user(only(LLVM.uses(u))) + for u in LLVM.uses(u) + u = LLVM.user(u) + if !isa(u, LLVM.CallInst) + continue + end + @assert isa(u, LLVM.CallInst) + B = IRBuilder() + nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) + position!(B, nextInst) + inp = operands(u)[idx] + cl = call!(B, funcT, sfunc, LLVM.Value[inp]) + if isa(value_type(inp), LLVM.PointerType) + LLVM.API.LLVMAddCallSiteAttribute( + cl, + LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nocapture"), + ) + end + end + continue end if !isa(u, LLVM.CallInst) continue