Skip to content

Commit 703246c

Browse files
gdalleamontoison
authored andcommitted
feat: allow full direct decompression into larger SparseMatrixCSC
1 parent 0553284 commit 703246c

4 files changed

Lines changed: 88 additions & 27 deletions

File tree

src/decompression.jl

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ The out-of-place alternative is [`decompress`](@ref).
189189
190190
!!! note
191191
In-place decompression is faster when `A isa SparseMatrixCSC`.
192+
- In general, this case requires the sparsity pattern of `A` to match the sparsity pattern `S` from which the coloring result was computed.
193+
- For a coloring result with `decompression=:direct`, we also allow _full_ decompression into an `A` whose sparsity pattern is a strict superset of `S`.
192194
193195
Compression means summing either the columns or the rows of `A` which share the same color.
194196
It is done by calling [`compress`](@ref).
@@ -356,10 +358,25 @@ end
356358
function decompress!(A::SparseMatrixCSC, B::AbstractMatrix, result::ColumnColoringResult)
357359
(; compressed_indices) = result
358360
S = result.bg.S2
359-
check_same_pattern(A, S)
361+
check_same_pattern(A, S; allow_superset=true)
360362
nzA = nonzeros(A)
361-
for k in eachindex(nzA, compressed_indices)
362-
nzA[k] = B[compressed_indices[k]]
363+
if nnz(A) == nnz(S)
364+
for k in eachindex(compressed_indices)
365+
nzA[k] = B[compressed_indices[k]]
366+
end
367+
else # nnz(A) > nnz(Z)
368+
fill!(nonzeros(A), zero(eltype(A)))
369+
rvA, rvS = rowvals(A), rowvals(S)
370+
shift = 0
371+
for j in axes(S, 2)
372+
for k in nzrange(S, j)
373+
i = rvS[k]
374+
while (k + shift) < A.colptr[j] || rvA[k + shift] < i
375+
shift += 1
376+
end
377+
nzA[k + shift] = B[compressed_indices[k]]
378+
end
379+
end
363380
end
364381
return A
365382
end
@@ -418,10 +435,25 @@ end
418435
function decompress!(A::SparseMatrixCSC, B::AbstractMatrix, result::RowColoringResult)
419436
(; compressed_indices) = result
420437
S = result.bg.S2
421-
check_same_pattern(A, S)
438+
check_same_pattern(A, S; allow_superset=true)
422439
nzA = nonzeros(A)
423-
for k in eachindex(nzA, compressed_indices)
424-
nzA[k] = B[compressed_indices[k]]
440+
if nnz(A) == nnz(S)
441+
for k in eachindex(nzA, compressed_indices)
442+
nzA[k] = B[compressed_indices[k]]
443+
end
444+
else # nnz(A) > nnz(S)
445+
fill!(nonzeros(A), zero(eltype(A)))
446+
rvA, rvS = rowvals(A), rowvals(S)
447+
shift = 0
448+
for j in axes(S, 2)
449+
for k in nzrange(S, j)
450+
i = rvS[k]
451+
while (k + shift) < A.colptr[j] || rvA[k + shift] < i
452+
shift += 1
453+
end
454+
nzA[k + shift] = B[compressed_indices[k]]
455+
end
456+
end
425457
end
426458
return A
427459
end
@@ -492,9 +524,24 @@ function decompress!(
492524
(; S) = ag
493525
nzA = nonzeros(A)
494526
if uplo == :F
495-
check_same_pattern(A, S)
496-
for k in eachindex(nzA, compressed_indices)
497-
nzA[k] = B[compressed_indices[k]]
527+
check_same_pattern(A, S; allow_superset=true)
528+
if nnz(A) == nnz(S)
529+
for k in eachindex(nzA, compressed_indices)
530+
nzA[k] = B[compressed_indices[k]]
531+
end
532+
else # nnz(A) > nnz(S)
533+
fill!(nonzeros(A), zero(eltype(A)))
534+
rvA, rvS = rowvals(A), rowvals(S)
535+
shift = 0
536+
for j in axes(S, 2)
537+
for k in nzrange(S, j)
538+
i = rvS[k]
539+
while (k + shift) < A.colptr[j] || rvA[k + shift] < i
540+
shift += 1
541+
end
542+
nzA[k + shift] = B[compressed_indices[k]]
543+
end
544+
end
498545
end
499546
else
500547
rvS = rowvals(S)

src/matrices.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,25 +61,24 @@ function respectful_similar(A::Union{Symmetric,Hermitian}, ::Type{T}) where {T}
6161
return respectful_similar(sparse(A), T)
6262
end
6363

64-
"""
65-
same_pattern(A, B)
66-
67-
Perform a partial equality check on the sparsity patterns of `A` and `B`:
68-
69-
- if the return is `true`, they might have the same sparsity pattern but we're not sure
70-
- if the return is `false`, they definitely don't have the same sparsity pattern
71-
"""
72-
same_pattern(A, B) = size(A) == size(B)
73-
74-
function same_pattern(
75-
A::Union{SparseMatrixCSC,SparsityPatternCSC},
76-
B::Union{SparseMatrixCSC,SparsityPatternCSC},
77-
)
78-
return size(A) == size(B) && nnz(A) == nnz(B)
64+
same_pattern(A::AbstractMatrix, S; allow_superset::Bool=false) = true
65+
function same_pattern(A::SparseMatrixCSC, S; allow_superset::Bool=false)
66+
return allow_superset ? nnz(A) >= nnz(S) : nnz(A) == nnz(S)
7967
end
8068

81-
function check_same_pattern(A, S)
82-
if !same_pattern(A, S)
83-
throw(DimensionMismatch("`A` and `S` must have the same sparsity pattern."))
69+
function check_same_pattern(A, S; allow_superset::Bool=false)
70+
if size(A) != size(S)
71+
throw(
72+
DimensionMismatch(
73+
"Decompression target must have the same size as sparsity pattern"
74+
),
75+
)
76+
elseif !same_pattern(A, S; allow_superset)
77+
throw(
78+
DimensionMismatch(
79+
"""Decompression target must $(allow_superset ? "contain the nonzeros of the sparsity pattern" : "be equal to the sparsity pattern") used for coloring""",
80+
),
81+
)
8482
end
83+
return true
8584
end

test/matrices.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,8 @@ end
4848
@test !same_pattern(A2, S)
4949
@test same_pattern(Matrix(A2), S)
5050

51+
@test_throws DimensionMismatch check_same_pattern(vcat(A1, A1), S)
5152
@test_throws DimensionMismatch check_same_pattern(A2, S)
53+
@test check_same_pattern(A1, S; allow_superset=true)
54+
@test check_same_pattern(A2, S; allow_superset=true)
5255
end

test/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,18 @@ function test_coloring_decompression(
8686
@test decompress(B, result) A0 # check result wasn't modified
8787
@test decompress!(respectful_similar(A, eltype(B)), B, result) A0
8888
@test decompress!(respectful_similar(A, eltype(B)), B, result) A0
89+
if decompression == :direct && A isa SparseMatrixCSC
90+
A_bigger = respectful_similar(A, eltype(B))
91+
for _ in 1:10
92+
nb_coeffs_added = rand(1:minimum(size(A)))
93+
for _ in nb_coeffs_added
94+
i = rand(axes(A, 1))
95+
j = rand(axes(A, 2))
96+
A_bigger[i, j] = one(eltype(B))
97+
end
98+
@test decompress!(A_bigger, B, result) A0
99+
end
100+
end
89101
end
90102

91103
@testset "Single-color decompression" begin

0 commit comments

Comments
 (0)