Some important derivative rules are missing, which makes differentiating standard linear algebra operations difficult. A summary:
Here is code that reproduces these problems.
# Missing LAPACK rules: getrf! (dgetrf_64_), getrs! (dgetrs_64_)
# BLAS routines handled only by Enzyme's generic fallback (emits a warning):
#   trsv! (dtrsv_64_), trsm! (dtrsm_64_)
#
# LU solve workaround: factorize outside AD, then differentiate only the solve
# (done with trsv!) inside AD.
# Note: cholesky!/potrf! and ldiv!(::Cholesky)/potrs! have native rules.
using LinearAlgebra, LinearAlgebra.BLAS, LinearAlgebra.LAPACK
using Random, Test, Enzyme, EnzymeTestUtils
# Reproducible test problem: a dense N×N system, its pivoted LU factorization,
# and a standalone copy of the U factor with the strictly-lower part zeroed.
Random.seed!(42)
const N = 6
A, b = randn(N, N), randn(N)   # tuple elements evaluate left to right: A first, then b
F = lu(A)
U = triu(F.factors)
# ─── Passing tests ───────────────────────────────────────────────────────
@testset "trsv! (fallback warning)" begin
    # Upper-triangular solve U \ b through BLAS.trsv!, reduced to a scalar so it
    # can be checked in both forward and reverse mode.
    function f!(x, U, b)
        copyto!(x, b)
        BLAS.trsv!('U', 'N', 'N', U, x)
        return sum(x)
    end
    for (tester, ret) in ((test_forward, Duplicated), (test_reverse, Active))
        tester(f!, ret,
               (zeros(N), Duplicated), (copy(U), Duplicated), (copy(b), Duplicated))
    end
end
@testset "trsm! (fallback warning)" begin
    B = randn(N, 3)
    # Left-side upper-triangular solve with three right-hand sides via
    # BLAS.trsm! (alpha = 1), reduced to a scalar.
    function f!(X, U, B)
        copyto!(X, B)
        BLAS.trsm!('L', 'U', 'N', 'N', 1.0, U, X)
        return sum(X)
    end
    for (tester, ret) in ((test_forward, Duplicated), (test_reverse, Active))
        tester(f!, ret,
               (zeros(N, 3), Duplicated), (copy(U), Duplicated), (copy(B), Duplicated))
    end
end
"""
    _apply_ipiv!(b, ipiv)

Apply the LAPACK-style row interchanges encoded in `ipiv` to the vector `b` in
place and return `b`.

The swaps are applied sequentially for each `i` in `eachindex(ipiv)`: entry `i`
of `b` is exchanged with entry `ipiv[i]`. This follows the convention of the
`ipiv` produced by `getrf!`, where row `i` was interchanged with row `ipiv[i]`.

Throws a `BoundsError` if `ipiv` refers to an index outside `b`.
"""
function _apply_ipiv!(b, ipiv)
    # No `@inbounds`: `j` is a *value* read from `ipiv`, not an index drawn from
    # `eachindex(b)`. A malformed or mis-sized pivot vector must raise an error
    # rather than silently read or write out of bounds.
    for i in eachindex(ipiv)
        j = ipiv[i]
        if j != i
            b[i], b[j] = b[j], b[i]
        end
    end
    return b
end
"""
    lu_trsv!(x, factors, ipiv, b)

Solve `A * x = b` into `x` using a pre-computed pivoted LU factorization of `A`:
`factors` holds the packed unit-diagonal lower factor and the upper factor, and
`ipiv` holds the row interchanges. Return `sum(x)` so the solve can be
differentiated as a scalar function. `b` is only read (it is copied into `x`).
"""
function lu_trsv!(x, factors, ipiv, b)
    rhs = copyto!(x, b)                       # rhs aliases x; b stays untouched
    _apply_ipiv!(rhs, ipiv)                   # permute the right-hand side
    BLAS.trsv!('L', 'N', 'U', factors, rhs)   # forward substitution, unit diagonal
    BLAS.trsv!('U', 'N', 'N', factors, rhs)   # back substitution, non-unit diagonal
    return sum(rhs)
end
@testset "LU solve: pre-factor + trsv!" begin
    # Sanity check: the hand-rolled LU solve agrees with the dense backslash solve.
    x = zeros(N)
    lu_trsv!(x, F.factors, F.ipiv, b)
    @test x ≈ A \ b
    # Differentiate w.r.t. the packed factors and the RHS; pivots are held Const.
    for (tester, ret) in ((test_forward, Duplicated), (test_reverse, Active))
        tester(lu_trsv!, ret,
               (zeros(N), Duplicated), (copy(F.factors), Duplicated),
               (F.ipiv, Const), (copy(b), Duplicated))
    end
end
# ─── Expected failures ───────────────────────────────────────────────────
# Each probe below is expected to error; " PASS (unexpected)" means it did not.
println(stdout, "\n--- Expected failures ---")
"""
    f_getrf(buf, A)

Copy `A` into `buf`, LU-factorize `buf` in place with `LAPACK.getrf!`, and
return the sum of the resulting packed factors. Used as a scalar probe for
differentiating `getrf!`.
"""
function f_getrf(buf, A)
    copyto!(buf, A)
    # getrf! stores one pivot per min(m, n). Size ipiv from `buf` itself rather
    # than the global `N`, so a buffer larger than N×N cannot overrun it.
    ipiv = zeros(LinearAlgebra.BlasInt, minimum(size(buf)))
    LAPACK.getrf!(buf, ipiv)
    return sum(buf)
end
"""
    f_getrs(x, fac, ipiv, b)

Copy `b` into `x`, solve in place with `LAPACK.getrs!('N', fac, ipiv, x)` using
the pre-computed LU factors `fac` and pivots `ipiv`, and return `sum(x)`.
"""
function f_getrs(x, fac, ipiv, b)
    LAPACK.getrs!('N', fac, ipiv, copyto!(x, b))   # copyto! returns x
    return sum(x)
end
"""
    _report_expected_failure(thunk, label)

Print `label` followed by a colon, then run `thunk()`.

- If it returns normally, print `" PASS (unexpected)"`.
- If it throws, print the exception's type name instead.
- An `InterruptException` is rethrown so the script can still be stopped with Ctrl-C.

`thunk` comes first so the helper can be called with `do`-block syntax.
"""
function _report_expected_failure(thunk, label::AbstractString)
    println(label, ":")
    try
        thunk()
        println(" PASS (unexpected)")
    catch e
        e isa InterruptException && rethrow()
        println(" ", nameof(typeof(e)))
    end
    return nothing
end

_report_expected_failure("A\\b forward") do
    autodiff(Forward, (A, b) -> sum(A \ b), Duplicated,
             Duplicated(copy(A), ones(N, N)), Duplicated(copy(b), ones(N)))
end

_report_expected_failure("getrf! forward") do
    autodiff(Forward, f_getrf, Duplicated,
             Duplicated(zeros(N, N), ones(N, N)), Duplicated(copy(A), ones(N, N)))
end

_report_expected_failure("getrs! forward") do
    autodiff(Forward, f_getrs, Duplicated,
             Duplicated(zeros(N), ones(N)), Duplicated(copy(F.factors), ones(N, N)),
             Const(F.ipiv), Duplicated(copy(b), ones(N)))
end

_report_expected_failure("getrs! reverse") do
    autodiff(Reverse, f_getrs, Active,
             Duplicated(zeros(N), zeros(N)), Duplicated(copy(F.factors), zeros(N, N)),
             Const(F.ipiv), Duplicated(copy(b), zeros(N)))
end
This is a partial duplicate of #1820.
Some important derivative rules are missing, which makes differentiating standard linear algebra operations difficult. A summary:

- `A \ b`
- `LAPACK.getrf!`
- `LAPACK.getrs!`
- `BLAS.trsv!`
- `BLAS.trsm!`

Tested on Julia 1.12.5 with Enzyme.jl version "0.13.138".

Here is code that reproduces these problems.