Skip to content

LU and triangular solve rules #3039

@jlperla

Description

@jlperla

A partial duplicate of #1820.

Several important derivative rules are missing, which makes differentiating standard linear-algebra operations difficult. Summary:

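| Call            | Forward | Reverse | Notes                   |
|-----------------|---------|---------|-------------------------|
| `A \ b`         | ✗       |         | fwd: TypeAnalysis error |
| `LAPACK.getrf!` | ✗       |         | NoDerivativeError       |
| `LAPACK.getrs!` | ✗       | ✗       | NoDerivativeError       |
| `BLAS.trsv!`    | ✓†      | ✓†      | † fallback BLAS warning |
| `BLAS.trsm!`    | ✓†      | ✓†      | † fallback BLAS warning |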

Tested on Julia 1.12.5 with Enzyme.jl 0.13.138.

The following script reproduces these problems:

# Missing LAPACK rules:  getrf! (dgetrf_64_), getrs! (dgetrs_64_)
# Fallback BLAS:         trsv! (dtrsv_64_), trsm! (dtrsm_64_)
#
# LU-solve workaround: factor outside AD, then differentiate only the solve, done via trsv!.
# Note: cholesky!/potrf! and ldiv!(::Cholesky)/potrs! already have native rules.

using LinearAlgebra, LinearAlgebra.BLAS, LinearAlgebra.LAPACK
using Random, Test, Enzyme, EnzymeTestUtils

# Reproducible problem data: a random N×N system A x = b and its pivoted LU factorization.
Random.seed!(42)
const N = 6
A = randn(N, N)   # drawn before b so the RNG stream order is fixed
b = randn(N)
F = lu(A)
# Dense copy of the upper-triangular factor (entries below the diagonal set to zero).
U = triu(F.factors)

# ─── Passing tests ───────────────────────────────────────────────────────

@testset "trsv! (fallback warning)" begin
    # Copy the right-hand side into `x`, solve U * x = b in place, reduce to a scalar.
    function f!(x, U, b)
        copyto!(x, b)
        BLAS.trsv!('U', 'N', 'N', U, x)
        return sum(x)
    end
    # Fresh buffers for every call so forward and reverse tests never share arrays.
    fresh_args() = ((zeros(N), Duplicated), (copy(U), Duplicated), (copy(b), Duplicated))
    test_forward(f!, Duplicated, fresh_args()...)
    test_reverse(f!, Active, fresh_args()...)
end

@testset "trsm! (fallback warning)" begin
    B = randn(N, 3)
    # Copy the right-hand sides into `X`, solve U * X = 1.0 * B in place (U on the left),
    # and reduce to a scalar.
    function f!(X, U, B)
        copyto!(X, B)
        BLAS.trsm!('L', 'U', 'N', 'N', 1.0, U, X)
        return sum(X)
    end
    # Fresh buffers for every call so forward and reverse tests never share arrays.
    fresh_args() = ((zeros(N, 3), Duplicated), (copy(U), Duplicated), (copy(B), Duplicated))
    test_forward(f!, Duplicated, fresh_args()...)
    test_reverse(f!, Active, fresh_args()...)
end

"""
    _apply_ipiv!(b, ipiv)

Apply the LAPACK-style row interchanges recorded in `ipiv` to `b` in place and return `b`.

The swaps are applied sequentially for `i` in `eachindex(ipiv)`, exchanging `b[i]` with
`b[ipiv[i]]`. For a factorization `F = lu(A)` this yields `b[F.p]`, i.e. `P * b`.

Throws `DimensionMismatch` if `ipiv` is longer than `b`. Indexing is bounds-checked because
the pivot values come from data and are not proven to lie within `eachindex(b)`.
"""
function _apply_ipiv!(b, ipiv)
    length(ipiv) <= length(b) || throw(DimensionMismatch(
        "ipiv has length $(length(ipiv)), but b has only length $(length(b))"))
    for i in eachindex(ipiv)
        p = ipiv[i]
        if p != i
            b[i], b[p] = b[p], b[i]
        end
    end
    return b
end

"""
    lu_trsv!(x, factors, ipiv, b)

Solve `A * x = b` in place in `x`, given the packed LU factors and LAPACK pivot vector of
`A` (for example `F.factors` and `F.ipiv` from `F = lu(A)`). Returns `sum(x)`, which serves
as a scalar objective for the AD tests.

Only `x` is written; `factors`, `ipiv`, and `b` are only read.
"""
function lu_trsv!(x, factors, ipiv, b)
    # x ← P * b, where P is the row permutation encoded by ipiv.
    _apply_ipiv!(copyto!(x, b), ipiv)
    # Forward substitution with unit-diagonal L, then back substitution with U. Both
    # triangles share `factors`; the uplo/diag flags select which part each solve reads.
    for (uplo, diag) in (('L', 'U'), ('U', 'N'))
        BLAS.trsv!(uplo, 'N', diag, factors, x)
    end
    return sum(x)
end

@testset "LU solve: pre-factor + trsv!" begin
    # Sanity check: the hand-rolled pivoted LU solve must agree with Julia's backslash.
    x = zeros(N)
    lu_trsv!(x, F.factors, F.ipiv, b)
    @test x ≈ A \ b
    # ipiv is passed as Const: it is integer pivot data that lu_trsv! only reads.
    test_forward(lu_trsv!, Duplicated,
        (zeros(N), Duplicated), (copy(F.factors), Duplicated), (F.ipiv, Const),
        (copy(b), Duplicated))
    test_reverse(lu_trsv!, Active,
        (zeros(N), Duplicated), (copy(F.factors), Duplicated), (F.ipiv, Const),
        (copy(b), Duplicated))
end

# ─── Expected failures ───────────────────────────────────────────────────

println("\n--- Expected failures ---")

"""
    f_getrf(buf, A)

Copy `A` into `buf`, LU-factorize `buf` in place with `LAPACK.getrf!`, and return
`sum(buf)` (the sum of the packed L/U factors) as a scalar AD objective.

The pivot buffer has length `minimum(size(buf))` = min(m, n), as getrf! expects, so the
function works for any matrix shape and does not depend on a global size.
"""
function f_getrf(buf, A)
    copyto!(buf, A)
    LAPACK.getrf!(buf, zeros(LinearAlgebra.BlasInt, minimum(size(buf))))
    return sum(buf)
end

"""
    f_getrs(x, fac, ipiv, b)

Copy `b` into `x`, solve in place with `LAPACK.getrs!('N', fac, ipiv, x)` using packed LU
factors `fac` and pivots `ipiv` (e.g. from `lu`), and return `sum(x)`.
"""
function f_getrs(x, fac, ipiv, b)
    copyto!(x, b)
    LAPACK.getrs!('N', fac, ipiv, x)
    return sum(x)
end

"""
    report_expected_failure(thunk, label)

Print `label`, run `thunk()`, and report the outcome on the next line: `PASS (unexpected)`
if it returns normally, otherwise the name of the thrown exception's type (for example
`NoDerivativeError`). Exceptions are reported rather than propagated so that every probe
runs, except `InterruptException`, which is rethrown so Ctrl-C still stops the script.

`thunk` comes first so call sites can use `do`-block syntax.
"""
function report_expected_failure(thunk, label::AbstractString)
    println(label)
    try
        thunk()
        println("  PASS (unexpected)")
    catch e
        e isa InterruptException && rethrow()
        println("  ", nameof(typeof(e)))
    end
    return nothing
end

# Forward mode through the generic dense solve `A \ b`.
report_expected_failure("A\\b forward:") do
    autodiff(Forward, (A, b) -> sum(A \ b), Duplicated,
        Duplicated(copy(A), ones(N, N)), Duplicated(copy(b), ones(N)))
end

# Forward mode through an in-place LAPACK LU factorization.
report_expected_failure("getrf! forward:") do
    autodiff(Forward, f_getrf, Duplicated,
        Duplicated(zeros(N, N), ones(N, N)), Duplicated(copy(A), ones(N, N)))
end

# Forward and reverse mode through the LAPACK LU solve; pivots are constant integer data.
report_expected_failure("getrs! forward:") do
    autodiff(Forward, f_getrs, Duplicated,
        Duplicated(zeros(N), ones(N)), Duplicated(copy(F.factors), ones(N, N)),
        Const(F.ipiv), Duplicated(copy(b), ones(N)))
end

report_expected_failure("getrs! reverse:") do
    autodiff(Reverse, f_getrs, Active,
        Duplicated(zeros(N), zeros(N)), Duplicated(copy(F.factors), zeros(N, N)),
        Const(F.ipiv), Duplicated(copy(b), zeros(N)))
end

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions