Skip to content

Commit cc29f75

Browse files
committed
Support decompression in one triangle
1 parent 87c8016 commit cc29f75

File tree

6 files changed

+170
-145
lines changed

6 files changed

+170
-145
lines changed

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ function SMC.StarSetColoringResult(
4848
ag::SMC.AdjacencyGraph{T},
4949
color::Vector{<:Integer},
5050
star_set::SMC.StarSet{<:Integer},
51+
decompression_uplo::Symbol,
5152
) where {T<:Integer}
5253
group = SMC.group_by_color(T, color)
53-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
54+
compressed_indices = SMC.star_csc_indices(ag, color, star_set, decompression_uplo)
5455
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
5556
return SMC.StarSetColoringResult(
56-
A, ag, color, group, compressed_indices, additional_info
57+
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
5758
)
5859
end
5960

@@ -86,12 +87,18 @@ function SMC.StarSetColoringResult(
8687
ag::SMC.AdjacencyGraph{T},
8788
color::Vector{<:Integer},
8889
star_set::SMC.StarSet{<:Integer},
90+
decompression_uplo::Symbol,
8991
) where {T<:Integer}
9092
group = SMC.group_by_color(T, color)
91-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
93+
reversed_uplo = if (decompression_uplo == :L)
94+
:U
95+
else
96+
(decompression_uplo == :U ? :L : decompression_uplo)
97+
end
98+
compressed_indices = SMC.star_csc_indices(ag, color, star_set, reversed_uplo)
9299
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices))
93100
return SMC.StarSetColoringResult(
94-
A, ag, color, group, compressed_indices, additional_info
101+
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
95102
)
96103
end
97104

@@ -120,15 +127,7 @@ function SMC.decompress!(
120127
A::CuSparseMatrixCSC,
121128
B::CuMatrix,
122129
result::SMC.StarSetColoringResult{<:CuSparseMatrixCSC},
123-
uplo::Symbol=:F,
124130
)
125-
if uplo != :F
126-
throw(
127-
SMC.UnsupportedDecompressionError(
128-
"Single-triangle decompression is not supported on GPU matrices"
129-
),
130-
)
131-
end
132131
compressed_indices = result.additional_info.compressed_indices_gpu_csc
133132
copyto!(A.nzVal, view(B, compressed_indices))
134133
return A
@@ -138,15 +137,7 @@ function SMC.decompress!(
138137
A::CuSparseMatrixCSR,
139138
B::CuMatrix,
140139
result::SMC.StarSetColoringResult{<:CuSparseMatrixCSR},
141-
uplo::Symbol=:F,
142140
)
143-
if uplo != :F
144-
throw(
145-
SMC.UnsupportedDecompressionError(
146-
"Single-triangle decompression is not supported on GPU matrices"
147-
),
148-
)
149-
end
150141
compressed_indices = result.additional_info.compressed_indices_gpu_csr
151142
copyto!(A.nzVal, view(B, compressed_indices))
152143
return A

src/SparseMatrixColorings.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ using LinearAlgebra:
2626
issymmetric,
2727
ldiv!,
2828
parent,
29-
transpose
29+
transpose,
30+
tril,
31+
triu
3032
using PrecompileTools: @compile_workload
3133
using Random: Random, AbstractRNG, default_rng, randperm
3234
using SparseArrays:

src/decompression.jl

Lines changed: 45 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -175,19 +175,32 @@ function decompress(B::AbstractMatrix, result::AbstractColoringResult)
175175
return decompress!(A, B, result)
176176
end
177177

178+
function decompress(B::AbstractMatrix, result::AbstractColoringResult{:symmetric,:column})
179+
A = respectful_similar(result.A, eltype(B))
180+
if A isa SparseMatrixCSC && result.uplo != :F
181+
(result.uplo == :L) && (A = tril(A))
182+
(result.uplo == :U) && (A = triu(A))
183+
end
184+
return decompress!(A, B, result)
185+
end
186+
178187
function decompress(
179188
Br::AbstractMatrix,
180189
Bc::AbstractMatrix,
181190
result::AbstractColoringResult{structure,:bidirectional},
182191
) where {structure}
183192
A = respectful_similar(result.A, Base.promote_eltype(Br, Bc))
193+
if A isa SparseMatrixCSC && result.symmetric_result.uplo != :F
194+
(result.symmetric_result.uplo == :L) && (A = tril(A))
195+
(result.symmetric_result.uplo == :U) && (A = triu(A))
196+
end
184197
return decompress!(A, Br, Bc, result)
185198
end
186199

187200
"""
188201
decompress!(
189202
A::AbstractMatrix, B::AbstractMatrix,
190-
result::AbstractColoringResult{_,:column/:row}, [uplo=:F]
203+
result::AbstractColoringResult{_,:column/:row},
191204
)
192205
193206
decompress!(
@@ -204,9 +217,6 @@ The out-of-place alternative is [`decompress`](@ref).
204217
Compression means summing either the columns or the rows of `A` which share the same color.
205218
It is done by calling [`compress`](@ref).
206219
207-
For `:symmetric` coloring results (and for those only), an optional positional argument `uplo in (:U, :L, :F)` can be passed to specify which part of the matrix `A` should be updated: the Upper triangle, the Lower triangle, or the Full matrix.
208-
When `A isa SparseMatrixCSC`, using the `uplo` argument requires a target matrix which only stores the relevant triangle(s).
209-
210220
!!! warning
211221
For some coloring variants, the `result` object is mutated during decompression.
212222
@@ -260,7 +270,7 @@ function decompress! end
260270
"""
261271
decompress_single_color!(
262272
A::AbstractMatrix, b::AbstractVector, c::Integer,
263-
result::AbstractColoringResult, [uplo=:F]
273+
result::AbstractColoringResult,
264274
)
265275
266276
Decompress the vector `b` corresponding to color `c` in-place into `A`, given a `:direct` coloring `result` of the sparsity pattern of `A` (it will not work with a `:substitution` coloring).
@@ -272,9 +282,6 @@ Decompress the vector `b` corresponding to color `c` in-place into `A`, given a
272282
!!! warning
273283
This function will only update some coefficients of `A`, without resetting the rest to zero.
274284
275-
For `:symmetric` coloring results (and for those only), an optional positional argument `uplo in (:U, :L, :F)` can be passed to specify which part of the matrix `A` should be updated: the Upper triangle, the Lower triangle, or the Full matrix.
276-
When `A isa SparseMatrixCSC`, using the `uplo` argument requires a target matrix which only stores the relevant triangle(s).
277-
278285
!!! warning
279286
For some coloring variants, the `result` object is mutated during decompression.
280287
@@ -445,97 +452,66 @@ end
445452

446453
## StarSetColoringResult
447454

448-
function decompress!(
449-
A::AbstractMatrix, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
450-
)
451-
(; ag, compressed_indices) = result
455+
function decompress!(A::AbstractMatrix, B::AbstractMatrix, result::StarSetColoringResult)
456+
(; ag, compressed_indices, uplo) = result
452457
(; S) = ag
453458
uplo == :F && check_same_pattern(A, S)
454459
fill!(A, zero(eltype(A)))
455460

456-
rvS = rowvals(S)
461+
l = 0
462+
rvS = rowvals(A)
457463
for j in axes(S, 2)
458464
for k in nzrange(S, j)
459465
i = rvS[k]
460466
if in_triangle(i, j, uplo)
461-
A[i, j] = B[compressed_indices[k]]
467+
l += 1
468+
A[i, j] = B[compressed_indices[l]]
462469
end
463470
end
464471
end
465472
return A
466473
end
467474

468475
function decompress_single_color!(
469-
A::AbstractMatrix,
470-
b::AbstractVector,
471-
c::Integer,
472-
result::StarSetColoringResult,
473-
uplo::Symbol=:F,
476+
A::SparseMatrixCSC, b::AbstractVector, c::Integer, result::StarSetColoringResult
474477
)
475-
(; ag, compressed_indices, group) = result
478+
(; ag, compressed_indices, group, uplo) = result
476479
(; S) = ag
477480
uplo == :F && check_same_pattern(A, S)
478481

479482
lower_index = (c - 1) * S.n + 1
480483
upper_index = c * S.n
481-
rvS = rowvals(S)
484+
rvA = rowvals(A)
485+
nzA = nonzeros(A)
482486
for j in group[c]
483-
for k in nzrange(S, j)
484-
# Check if the color c is used to recover A[i,j] / A[j,i]
487+
for k in nzrange(A, j)
488+
# Check if the color c is used to recover A[i,j]
485489
if lower_index <= compressed_indices[k] <= upper_index
486-
i = rvS[k]
487-
if i == j
488-
# Recover the diagonal coefficients of A
489-
A[i, i] = b[i]
490-
else
491-
# Recover the off-diagonal coefficients of A
492-
if in_triangle(i, j, uplo)
493-
A[i, j] = b[i]
494-
end
495-
if in_triangle(j, i, uplo)
496-
A[j, i] = b[i]
497-
end
498-
end
490+
i = rvA[k]
491+
nzA[k] = b[i]
499492
end
500493
end
501494
end
502495
return A
503496
end
504497

505-
function decompress!(
506-
A::SparseMatrixCSC, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
507-
)
508-
(; ag, compressed_indices) = result
498+
function decompress!(A::SparseMatrixCSC, B::AbstractMatrix, result::StarSetColoringResult)
499+
(; ag, compressed_indices, uplo) = result
509500
(; S) = ag
510501
nzA = nonzeros(A)
511-
if uplo == :F
512-
check_same_pattern(A, S)
513-
for k in eachindex(nzA, compressed_indices)
514-
nzA[k] = B[compressed_indices[k]]
515-
end
516-
else
517-
rvS = rowvals(S)
518-
l = 0 # assume A has the same pattern as the triangle
519-
for j in axes(S, 2)
520-
for k in nzrange(S, j)
521-
i = rvS[k]
522-
if in_triangle(i, j, uplo)
523-
l += 1
524-
nzA[l] = B[compressed_indices[k]]
525-
end
526-
end
527-
end
502+
uplo == :F && check_same_pattern(A, S)
503+
for k in eachindex(nzA, compressed_indices)
504+
nzA[k] = B[compressed_indices[k]]
528505
end
529506
return A
530507
end
531508

532509
## TreeSetColoringResult
533510

534-
function decompress!(
535-
A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult, uplo::Symbol=:F
536-
)
537-
(; ag, color, reverse_bfs_orders, tree_edge_indices, nt, diagonal_indices, buffer) =
538-
result
511+
function decompress!(A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult)
512+
(;
513+
ag, color, reverse_bfs_orders, tree_edge_indices, nt, diagonal_indices, buffer, uplo
514+
) = result
539515
(; S) = ag
540516
uplo == :F && check_same_pattern(A, S)
541517
R = eltype(A)
@@ -586,10 +562,7 @@ function decompress!(
586562
end
587563

588564
function decompress!(
589-
A::SparseMatrixCSC{R},
590-
B::AbstractMatrix{R},
591-
result::TreeSetColoringResult,
592-
uplo::Symbol=:F,
565+
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::TreeSetColoringResult
593566
) where {R<:Real}
594567
(;
595568
ag,
@@ -602,6 +575,7 @@ function decompress!(
602575
lower_triangle_offsets,
603576
upper_triangle_offsets,
604577
buffer,
578+
uplo,
605579
) = result
606580
(; S) = ag
607581
A_colptr = A.colptr
@@ -702,12 +676,10 @@ end
702676
## MatrixInverseColoringResult
703677

704678
function decompress!(
705-
A::AbstractMatrix,
706-
B::AbstractMatrix,
707-
result::LinearSystemColoringResult,
708-
uplo::Symbol=:F,
679+
A::AbstractMatrix, B::AbstractMatrix, result::LinearSystemColoringResult
709680
)
710-
(; color, strict_upper_nonzero_inds, M_factorization, strict_upper_nonzeros_A) = result
681+
(; color, strict_upper_nonzero_inds, M_factorization, strict_upper_nonzeros_A, uplo) =
682+
result
711683
S = result.ag.S
712684
uplo == :F && check_same_pattern(A, S)
713685

@@ -776,7 +748,7 @@ function decompress!(
776748
nzval = Vector{R}(undef, length(large_rowval))
777749
A_and_noAᵀ = SparseMatrixCSC(m + n, m + n, large_colptr, large_rowval, nzval)
778750
Br_and_Bc = _join_compressed!(result, Br, Bc)
779-
decompress!(A_and_noAᵀ, Br_and_Bc, symmetric_result, :L)
751+
decompress!(A_and_noAᵀ, Br_and_Bc, symmetric_result)
780752
rvA = rowvals(A_and_noAᵀ)
781753
nzA = nonzeros(A_and_noAᵀ)
782754
for j in 1:n
@@ -797,6 +769,6 @@ function decompress!(
797769
A_and_noAᵀ = SparseMatrixCSC(m + n, m + n, large_colptr, large_rowval, A.nzval)
798770
# decompress lower triangle only
799771
Br_and_Bc = _join_compressed!(result, Br, Bc)
800-
decompress!(A_and_noAᵀ, Br_and_Bc, symmetric_result, :L)
772+
decompress!(A_and_noAᵀ, Br_and_Bc, symmetric_result)
801773
return A
802774
end

0 commit comments

Comments
 (0)