Skip to content

Commit 27808f0

Browse files
committed
Support decompression in one triangle
1 parent 904eb07 commit 27808f0

File tree

6 files changed

+172
-139
lines changed

6 files changed

+172
-139
lines changed

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ function SMC.StarSetColoringResult(
4848
ag::SMC.AdjacencyGraph{T},
4949
color::Vector{<:Integer},
5050
star_set::SMC.StarSet{<:Integer},
51+
decompression_uplo::Symbol,
5152
) where {T<:Integer}
5253
group = SMC.group_by_color(T, color)
53-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
54+
compressed_indices = SMC.star_csc_indices(ag, color, star_set, decompression_uplo)
5455
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
5556
return SMC.StarSetColoringResult(
56-
A, ag, color, group, compressed_indices, additional_info
57+
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
5758
)
5859
end
5960

@@ -86,12 +87,14 @@ function SMC.StarSetColoringResult(
8687
ag::SMC.AdjacencyGraph{T},
8788
color::Vector{<:Integer},
8889
star_set::SMC.StarSet{<:Integer},
90+
decompression_uplo::Symbol
8991
) where {T<:Integer}
9092
group = SMC.group_by_color(T, color)
91-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
93+
reversed_uplo = (decompression_uplo == :L) ? :U : (decompression_uplo == :U ? :L : decompression_uplo)
94+
compressed_indices = SMC.star_csc_indices(ag, color, star_set, reversed_uplo)
9295
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices))
9396
return SMC.StarSetColoringResult(
94-
A, ag, color, group, compressed_indices, additional_info
97+
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
9598
)
9699
end
97100

@@ -120,15 +123,7 @@ function SMC.decompress!(
120123
A::CuSparseMatrixCSC,
121124
B::CuMatrix,
122125
result::SMC.StarSetColoringResult{<:CuSparseMatrixCSC},
123-
uplo::Symbol=:F,
124126
)
125-
if uplo != :F
126-
throw(
127-
SMC.UnsupportedDecompressionError(
128-
"Single-triangle decompression is not supported on GPU matrices"
129-
),
130-
)
131-
end
132127
compressed_indices = result.additional_info.compressed_indices_gpu_csc
133128
copyto!(A.nzVal, view(B, compressed_indices))
134129
return A
@@ -138,15 +133,7 @@ function SMC.decompress!(
138133
A::CuSparseMatrixCSR,
139134
B::CuMatrix,
140135
result::SMC.StarSetColoringResult{<:CuSparseMatrixCSR},
141-
uplo::Symbol=:F,
142136
)
143-
if uplo != :F
144-
throw(
145-
SMC.UnsupportedDecompressionError(
146-
"Single-triangle decompression is not supported on GPU matrices"
147-
),
148-
)
149-
end
150137
compressed_indices = result.additional_info.compressed_indices_gpu_csr
151138
copyto!(A.nzVal, view(B, compressed_indices))
152139
return A

src/SparseMatrixColorings.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ using LinearAlgebra:
2626
issymmetric,
2727
ldiv!,
2828
parent,
29-
transpose
29+
transpose,
30+
tril,
31+
triu
3032
using PrecompileTools: @compile_workload
3133
using Random: Random, AbstractRNG, default_rng, randperm
3234
using SparseArrays:

src/decompression.jl

Lines changed: 56 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -175,19 +175,34 @@ function decompress(B::AbstractMatrix, result::AbstractColoringResult)
175175
return decompress!(A, B, result)
176176
end
177177

178+
function decompress(
179+
B::AbstractMatrix,
180+
result::AbstractColoringResult{:symmetric,:column})
181+
A = respectful_similar(result.A, eltype(B))
182+
if A isa SparseMatrixCSC && result.uplo != :F
183+
(result.uplo == :L) && (A = tril(A))
184+
(result.uplo == :U) && (A = triu(A))
185+
end
186+
return decompress!(A, B, result)
187+
end
188+
178189
function decompress(
179190
Br::AbstractMatrix,
180191
Bc::AbstractMatrix,
181192
result::AbstractColoringResult{structure,:bidirectional},
182193
) where {structure}
183194
A = respectful_similar(result.A, Base.promote_eltype(Br, Bc))
195+
if A isa SparseMatrixCSC && result.symmetric_result.uplo != :F
196+
(result.symmetric_result.uplo == :L) && (A = tril(A))
197+
(result.symmetric_result.uplo == :U) && (A = triu(A))
198+
end
184199
return decompress!(A, Br, Bc, result)
185200
end
186201

187202
"""
188203
decompress!(
189204
A::AbstractMatrix, B::AbstractMatrix,
190-
result::AbstractColoringResult{_,:column/:row}, [uplo=:F]
205+
result::AbstractColoringResult{_,:column/:row},
191206
)
192207
193208
decompress!(
@@ -204,9 +219,6 @@ The out-of-place alternative is [`decompress`](@ref).
204219
Compression means summing either the columns or the rows of `A` which share the same color.
205220
It is done by calling [`compress`](@ref).
206221
207-
For `:symmetric` coloring results (and for those only), an optional positional argument `uplo in (:U, :L, :F)` can be passed to specify which part of the matrix `A` should be updated: the Upper triangle, the Lower triangle, or the Full matrix.
208-
When `A isa SparseMatrixCSC`, using the `uplo` argument requires a target matrix which only stores the relevant triangle(s).
209-
210222
!!! warning
211223
For some coloring variants, the `result` object is mutated during decompression.
212224
@@ -260,7 +272,7 @@ function decompress! end
260272
"""
261273
decompress_single_color!(
262274
A::AbstractMatrix, b::AbstractVector, c::Integer,
263-
result::AbstractColoringResult, [uplo=:F]
275+
result::AbstractColoringResult,
264276
)
265277
266278
Decompress the vector `b` corresponding to color `c` in-place into `A`, given a `:direct` coloring `result` of the sparsity pattern of `A` (it will not work with a `:substitution` coloring).
@@ -272,9 +284,6 @@ Decompress the vector `b` corresponding to color `c` in-place into `A`, given a
272284
!!! warning
273285
This function will only update some coefficients of `A`, without resetting the rest to zero.
274286
275-
For `:symmetric` coloring results (and for those only), an optional positional argument `uplo in (:U, :L, :F)` can be passed to specify which part of the matrix `A` should be updated: the Upper triangle, the Lower triangle, or the Full matrix.
276-
When `A isa SparseMatrixCSC`, using the `uplo` argument requires a target matrix which only stores the relevant triangle(s).
277-
278287
!!! warning
279288
For some coloring variants, the `result` object is mutated during decompression.
280289
@@ -446,95 +455,68 @@ end
446455
## StarSetColoringResult
447456

448457
function decompress!(
449-
A::AbstractMatrix, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
450-
)
451-
(; ag, compressed_indices) = result
458+
A::AbstractMatrix, B::AbstractMatrix, result::StarSetColoringResult)
459+
(; ag, compressed_indices, uplo) = result
452460
(; S) = ag
453461
uplo == :F && check_same_pattern(A, S)
454462
fill!(A, zero(eltype(A)))
455463

456-
rvS = rowvals(S)
464+
l = 0
465+
rvS = rowvals(A)
457466
for j in axes(S, 2)
458467
for k in nzrange(S, j)
459468
i = rvS[k]
460469
if in_triangle(i, j, uplo)
461-
A[i, j] = B[compressed_indices[k]]
470+
l += 1
471+
A[i, j] = B[compressed_indices[l]]
462472
end
463473
end
464474
end
465475
return A
466476
end
467477

468478
function decompress_single_color!(
469-
A::AbstractMatrix,
479+
A::SparseMatrixCSC,
470480
b::AbstractVector,
471481
c::Integer,
472482
result::StarSetColoringResult,
473-
uplo::Symbol=:F,
474483
)
475-
(; ag, compressed_indices, group) = result
484+
(; ag, compressed_indices, group, uplo) = result
476485
(; S) = ag
486+
println(uplo)
477487
uplo == :F && check_same_pattern(A, S)
478488

479489
lower_index = (c - 1) * S.n + 1
480490
upper_index = c * S.n
481-
rvS = rowvals(S)
491+
rvA = rowvals(A)
492+
nzA = nonzeros(A)
482493
for j in group[c]
483-
for k in nzrange(S, j)
484-
# Check if the color c is used to recover A[i,j] / A[j,i]
494+
for k in nzrange(A, j)
495+
# Check if the color c is used to recover A[i,j]
485496
if lower_index <= compressed_indices[k] <= upper_index
486-
i = rvS[k]
487-
if i == j
488-
# Recover the diagonal coefficients of A
489-
A[i, i] = b[i]
490-
else
491-
# Recover the off-diagonal coefficients of A
492-
if in_triangle(i, j, uplo)
493-
A[i, j] = b[i]
494-
end
495-
if in_triangle(j, i, uplo)
496-
A[j, i] = b[i]
497-
end
498-
end
497+
i = rvA[k]
498+
nzA[k] = b[i]
499499
end
500500
end
501501
end
502502
return A
503503
end
504504

505-
function decompress!(
506-
A::SparseMatrixCSC, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
507-
)
508-
(; ag, compressed_indices) = result
505+
function decompress!(A::SparseMatrixCSC, B::AbstractMatrix, result::StarSetColoringResult)
506+
(; ag, compressed_indices, uplo) = result
509507
(; S) = ag
510508
nzA = nonzeros(A)
511-
if uplo == :F
512-
check_same_pattern(A, S)
513-
for k in eachindex(nzA, compressed_indices)
514-
nzA[k] = B[compressed_indices[k]]
515-
end
516-
else
517-
rvS = rowvals(S)
518-
l = 0 # assume A has the same pattern as the triangle
519-
for j in axes(S, 2)
520-
for k in nzrange(S, j)
521-
i = rvS[k]
522-
if in_triangle(i, j, uplo)
523-
l += 1
524-
nzA[l] = B[compressed_indices[k]]
525-
end
526-
end
527-
end
509+
uplo == :F && check_same_pattern(A, S)
510+
for k in eachindex(nzA, compressed_indices)
511+
nzA[k] = B[compressed_indices[k]]
528512
end
529513
return A
530514
end
531515

532516
## TreeSetColoringResult
533517

534-
function decompress!(
535-
A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult, uplo::Symbol=:F
536-
)
537-
(; ag, color, reverse_bfs_orders, tree_edge_indices, nt, buffer) = result
518+
function decompress!(A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult)
519+
(; ag, color, reverse_bfs_orders, tree_edge_indices, nt, diagonal_indices, buffer, uplo) = result
538520
(; S) = ag
539521
uplo == :F && check_same_pattern(A, S)
540522
R = eltype(A)
@@ -548,10 +530,8 @@ function decompress!(
548530

549531
# Recover the diagonal coefficients of A
550532
if !augmented_graph(ag)
551-
for i in axes(S, 1)
552-
if !iszero(S[i, i])
553-
A[i, i] = B[i, color[i]]
554-
end
533+
for i in diagonal_indices
534+
A[i, i] = B[i, color[i]]
555535
end
556536
end
557537

@@ -590,7 +570,6 @@ function decompress!(
590570
A::SparseMatrixCSC{R},
591571
B::AbstractMatrix{R},
592572
result::TreeSetColoringResult,
593-
uplo::Symbol=:F,
594573
) where {R<:Real}
595574
(;
596575
ag,
@@ -603,6 +582,7 @@ function decompress!(
603582
lower_triangle_offsets,
604583
upper_triangle_offsets,
605584
buffer,
585+
uplo,
606586
) = result
607587
(; S) = ag
608588
A_colptr = A.colptr
@@ -706,9 +686,8 @@ function decompress!(
706686
A::AbstractMatrix,
707687
B::AbstractMatrix,
708688
result::LinearSystemColoringResult,
709-
uplo::Symbol=:F,
710689
)
711-
(; color, strict_upper_nonzero_inds, M_factorization, strict_upper_nonzeros_A) = result
690+
(; color, strict_upper_nonzero_inds, M_factorization, strict_upper_nonzeros_A, uplo) = result
712691
S = result.ag.S
713692
uplo == :F && check_same_pattern(A, S)
714693

@@ -770,10 +749,20 @@ end
770749
function decompress!(
771750
A::AbstractMatrix, Br::AbstractMatrix, Bc::AbstractMatrix, result::BicoloringResult
772751
)
752+
(; large_colptr, large_rowval, symmetric_result) = result
773753
m, n = size(A)
774754
Br_and_Bc = _join_compressed!(result, Br, Bc)
775-
A_and_Aᵀ = decompress(Br_and_Bc, result.symmetric_result)
776-
copyto!(A, A_and_Aᵀ[(n + 1):(n + m), 1:n]) # original matrix in bottom left corner
755+
nzval = Vector{eltype(A)}(undef, length(large_rowval))
756+
A_and_noAᵀ = SparseMatrixCSC(m + n, m + n, large_colptr, large_rowval, nzval)
757+
decompress!(A_and_noAᵀ, Br_and_Bc, symmetric_result)
758+
rvA = rowvals(A_and_noAᵀ)
759+
nzA = nonzeros(A_and_noAᵀ)
760+
for j in axes(A_and_noAᵀ, 2)
761+
for k in nzrange(A_and_noAᵀ, j)
762+
i = rvA[k]
763+
A[i-n, j] = nzA[k]
764+
end
765+
end
777766
return A
778767
end
779768

@@ -786,6 +775,6 @@ function decompress!(
786775
# pretend A is larger
787776
A_and_noAᵀ = SparseMatrixCSC(m + n, m + n, large_colptr, large_rowval, A.nzval)
788777
# decompress lower triangle only
789-
decompress!(A_and_noAᵀ, Br_and_Bc, symmetric_result, :L)
778+
decompress!(A_and_noAᵀ, Br_and_Bc, symmetric_result)
790779
return A
791780
end

0 commit comments

Comments
 (0)